Skip to content

Commit

Permalink
Add EventTimeFilter and BaseRelation.render_event_time_filtered (#285)
Browse files Browse the repository at this point in the history
Co-authored-by: Colin Rogers <[email protected]>
  • Loading branch information
MichelleArk and colin-rogers-dbt authored Sep 10, 2024
1 parent dadd0f2 commit cffa724
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 4 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20240905-180956.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Add EventTimeFilter to BaseRelation, which renders a filtered relation when
start or end is set
time: 2024-09-05T18:09:56.159385-04:00
custom:
Author: 'michelleark QMalcolm'
Issue: "294"
54 changes: 51 additions & 3 deletions dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Hashable
from dataclasses import dataclass, field
from datetime import datetime
from typing import (
Any,
Dict,
Expand Down Expand Up @@ -36,6 +37,13 @@
SerializableIterable = Union[Tuple, FrozenSet]


# Optional time-window filter attached to a relation: it names the column to
# filter on plus optional bounds that are rendered into a WHERE predicate.
@dataclass(frozen=True, eq=False, repr=False)
class EventTimeFilter(FakeAPIObject, Hashable):
# Column name; interpolated as-is (unquoted) into the rendered predicate.
field_name: str
# Lower bound, rendered as "<field_name> >= '<start>'". None means no lower bound.
start: Optional[datetime] = None
# Upper bound, rendered as "<field_name> < '<end>'". None means no upper bound.
end: Optional[datetime] = None


@dataclass(frozen=True, eq=False, repr=False)
class BaseRelation(FakeAPIObject, Hashable):
path: Path
Expand All @@ -47,6 +55,7 @@ class BaseRelation(FakeAPIObject, Hashable):
quote_policy: Policy = field(default_factory=lambda: Policy())
dbt_created: bool = False
limit: Optional[int] = None
event_time_filter: Optional[EventTimeFilter] = None
require_alias: bool = (
True # used to govern whether to add an alias when render_limited is called
)
Expand Down Expand Up @@ -208,14 +217,19 @@ def render(self) -> str:
# if there is nothing set, this will return the empty string.
return ".".join(part for _, part in self._render_iterator() if part is not None)

def _render_limited_alias(self) -> str:
def _render_subquery_alias(self, namespace: str) -> str:
"""Some databases require an alias for subqueries (postgres, mysql) for all others we want to avoid adding
an alias as it has the potential to introduce issues with the query if the user also defines an alias.
"""
if self.require_alias:
return f" _dbt_limit_subq_{self.table}"
return f" _dbt_{namespace}_subq_{self.table}"
return ""

def _render_limited_alias(
self,
) -> str:
return self._render_subquery_alias(namespace="limit")

def render_limited(self) -> str:
rendered = self.render()
if self.limit is None:
Expand All @@ -225,6 +239,31 @@ def render_limited(self) -> str:
else:
return f"(select * from {rendered} limit {self.limit}){self._render_limited_alias()}"

def render_event_time_filtered(self, rendered: Optional[str] = None) -> str:
    """Render the relation wrapped in an event-time filter subquery.

    ``rendered`` is the SQL to wrap; when it is None or empty, the result of
    ``self.render()`` is used instead.

    Returns the input unchanged when no ``event_time_filter`` is set or the
    filter produces an empty predicate. Otherwise returns
    "(select * from <rendered> where <predicate>)" followed by the
    "et_filter" subquery alias.
    """
    base = rendered or self.render()

    event_filter = self.event_time_filter
    if event_filter is None:
        return base

    predicate = self._render_event_time_filtered(event_filter)
    if predicate:
        alias = self._render_subquery_alias(namespace="et_filter")
        return f"(select * from {base} where {predicate}){alias}"

    # A filter with neither bound set yields no predicate, so nothing to wrap.
    return base

def _render_event_time_filtered(self, event_time_filter: EventTimeFilter) -> str:
"""
Returns "" if start and end are both None
"""
filter = ""
if event_time_filter.start and event_time_filter.end:
filter = f"{event_time_filter.field_name} >= '{event_time_filter.start}' and {event_time_filter.field_name} < '{event_time_filter.end}'"
elif event_time_filter.start:
filter = f"{event_time_filter.field_name} >= '{event_time_filter.start}'"
elif event_time_filter.end:
filter = f"{event_time_filter.field_name} < '{event_time_filter.end}'"

return filter

def quoted(self, identifier):
return "{quote_char}{identifier}{quote_char}".format(
quote_char=self.quote_character,
Expand All @@ -240,6 +279,7 @@ def create_ephemeral_from(
cls: Type[Self],
relation_config: RelationConfig,
limit: Optional[int] = None,
event_time_filter: Optional[EventTimeFilter] = None,
) -> Self:
# Note that ephemeral models are based on the identifier, which will
# point to the model's alias if one exists and otherwise fall back to
Expand All @@ -250,6 +290,7 @@ def create_ephemeral_from(
type=cls.CTE,
identifier=identifier,
limit=limit,
event_time_filter=event_time_filter,
).quote(identifier=False)

@classmethod
Expand Down Expand Up @@ -315,7 +356,14 @@ def __hash__(self) -> int:
return hash(self.render())

def __str__(self) -> str:
return self.render() if self.limit is None else self.render_limited()
rendered = self.render() if self.limit is None else self.render_limited()

# Limited subquery is wrapped by the event time filter subquery, and not the other way around.
# This is because in the context of resolving limited refs, we care more about performance than reliably producing a sample of a certain size.
if self.event_time_filter:
rendered = self.render_event_time_filtered(rendered)

return rendered

@property
def database(self) -> Optional[str]:
Expand Down
77 changes: 76 additions & 1 deletion tests/unit/test_relation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from dataclasses import dataclass, replace

from datetime import datetime
import pytest

from dbt.adapters.base import BaseRelation
from dbt.adapters.base.relation import EventTimeFilter
from dbt.adapters.contracts.relation import RelationType


Expand Down Expand Up @@ -81,6 +82,80 @@ def test_render_limited(limit, require_alias, expected_result):
assert str(my_relation) == expected_result


# Shared fixtures for the parametrized cases below.
_RENDERED_RELATION = '"test_database"."test_schema"."test_identifier"'
_JAN_1 = datetime(2020, 1, 1)
_JAN_2 = datetime(2020, 1, 2)


@pytest.mark.parametrize(
    "event_time_filter,require_alias,expected_result",
    [
        # No filter, or a filter with neither bound set: rendered unchanged,
        # regardless of require_alias.
        (None, False, _RENDERED_RELATION),
        (EventTimeFilter(field_name="column"), False, _RENDERED_RELATION),
        (None, True, _RENDERED_RELATION),
        (EventTimeFilter(field_name="column"), True, _RENDERED_RELATION),
        # Lower bound only, without and with the subquery alias.
        (
            EventTimeFilter(field_name="column", start=_JAN_1),
            False,
            f"(select * from {_RENDERED_RELATION} where column >= '2020-01-01 00:00:00')",
        ),
        (
            EventTimeFilter(field_name="column", start=_JAN_1),
            True,
            f"(select * from {_RENDERED_RELATION} where column >= '2020-01-01 00:00:00')"
            " _dbt_et_filter_subq_test_identifier",
        ),
        # Upper bound only.
        (
            EventTimeFilter(field_name="column", end=_JAN_1),
            False,
            f"(select * from {_RENDERED_RELATION} where column < '2020-01-01 00:00:00')",
        ),
        # Both bounds.
        (
            EventTimeFilter(field_name="column", start=_JAN_1, end=_JAN_2),
            False,
            f"(select * from {_RENDERED_RELATION} where column >= '2020-01-01 00:00:00'"
            " and column < '2020-01-02 00:00:00')",
        ),
    ],
)
def test_render_event_time_filtered(event_time_filter, require_alias, expected_result):
    """The explicit render method and ``str()`` agree on the filtered SQL."""
    relation = BaseRelation.create(
        database="test_database",
        schema="test_schema",
        identifier="test_identifier",
        event_time_filter=event_time_filter,
        require_alias=require_alias,
    )

    assert relation.render_event_time_filtered() == expected_result
    assert str(relation) == expected_result


def test_render_event_time_filtered_and_limited():
    """With both a limit and an event-time filter, the filter wraps the limited subquery."""
    window = EventTimeFilter(
        field_name="column",
        start=datetime(2020, 1, 1),
        end=datetime(2020, 1, 2),
    )
    relation = BaseRelation.create(
        database="test_database",
        schema="test_schema",
        identifier="test_identifier",
        event_time_filter=window,
        limit=0,
        require_alias=False,
    )

    # Expected SQL: the limit-0 subquery on the inside, the time filter outside.
    limited_sql = (
        '(select * from "test_database"."test_schema"."test_identifier" where false limit 0)'
    )
    expected_sql = (
        f"(select * from {limited_sql}"
        " where column >= '2020-01-01 00:00:00' and column < '2020-01-02 00:00:00')"
    )

    assert relation.render_event_time_filtered(relation.render_limited()) == expected_sql
    assert str(relation) == expected_sql


def test_create_ephemeral_from_uses_identifier():
@dataclass
class Node:
Expand Down

0 comments on commit cffa724

Please sign in to comment.