Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix interval join to work with Bytewax 0.20.1 #5

Merged
merged 2 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytests/test_join_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def test_join_interval_complete() -> None:
run_main(flow)
assert out_down == [
("left", "right1"),
("left", "right2"),
]


Expand Down
1 change: 1 addition & 0 deletions src/bytewax/interval_join/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Interval joins for Bytewax."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's your thinking on the premium package name structure? bytewax.interval_join.operators.interval seems kind of long and repetitive, but we can only have the namespace package at the top level. Put everything in bytewax.interval.operators? Or just bytewax.interval?

Also, maybe just call this specific module interval, since it provides more functionality than just joins.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bytewax.interval makes sense to me. As you point out, we provide more than interval_join.

I was originally thinking that bytewax.interval.operators made sense, as other premium packages may provide things other than operators, but I could be convinced otherwise.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Interval joins for Bytewax."""
"""Operators that find items on different streams close in time."""

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops. I'll apply this in another PR. I missed this suggestion before I merged this PR.

37 changes: 17 additions & 20 deletions src/bytewax/interval_join/operators/interval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,19 +459,20 @@ def shim_builder(


@dataclass
class _JoinIntervalCompleteLogic(IntervalLogic[Tuple[str, V], _JoinState, _JoinState]):
class _JoinIntervalCompleteLogic(IntervalLogic[Tuple[int, V], _JoinState, _JoinState]):
state: _JoinState

@override
def on_value(self, side: LeftRight, value: Tuple[str, V]) -> Iterable[_JoinState]:
def on_value(self, side: LeftRight, value: Tuple[int, V]) -> Iterable[_JoinState]:
join_side, join_value = value
self.state.set_val(join_side, join_value)

if self.state.all_set():
state = copy.deepcopy(self.state)
# Only reset right side since we'll never see left side
# Only reset all right sides since we'll never see left side
# again by definition in an interval.
self.state.seen["right"] = []
for i in range(1, len(self.state.seen)):
self.state.seen[i] = []
return (state,)
else:
return _EMPTY
Expand All @@ -486,11 +487,11 @@ def snapshot(self) -> _JoinState:


@dataclass
class _JoinIntervalFinalLogic(IntervalLogic[Tuple[str, V], _JoinState, _JoinState]):
class _JoinIntervalFinalLogic(IntervalLogic[Tuple[int, V], _JoinState, _JoinState]):
state: _JoinState

@override
def on_value(self, side: LeftRight, value: Tuple[str, V]) -> Iterable[_JoinState]:
def on_value(self, side: LeftRight, value: Tuple[int, V]) -> Iterable[_JoinState]:
join_side, join_value = value
self.state.set_val(join_side, join_value)
return _EMPTY
Expand All @@ -506,11 +507,11 @@ def snapshot(self) -> _JoinState:


@dataclass
class _JoinIntervalRunningLogic(IntervalLogic[Tuple[str, V], _JoinState, _JoinState]):
class _JoinIntervalRunningLogic(IntervalLogic[Tuple[int, V], _JoinState, _JoinState]):
state: _JoinState

@override
def on_value(self, side: LeftRight, value: Tuple[str, V]) -> Iterable[_JoinState]:
def on_value(self, side: LeftRight, value: Tuple[int, V]) -> Iterable[_JoinState]:
join_side, join_value = value
self.state.set_val(join_side, join_value)
return (copy.deepcopy(self.state),)
Expand All @@ -525,11 +526,11 @@ def snapshot(self) -> _JoinState:


@dataclass
class _JoinIntervalProductLogic(IntervalLogic[Tuple[str, V], _JoinState, _JoinState]):
class _JoinIntervalProductLogic(IntervalLogic[Tuple[int, V], _JoinState, _JoinState]):
state: _JoinState

@override
def on_value(self, side: LeftRight, value: Tuple[str, V]) -> Iterable[_JoinState]:
def on_value(self, side: LeftRight, value: Tuple[int, V]) -> Iterable[_JoinState]:
join_side, join_value = value
self.state.add_val(join_side, join_value)
return _EMPTY
Expand All @@ -544,11 +545,9 @@ def snapshot(self) -> _JoinState:
return copy.deepcopy(self.state)


def _add_side_builder(i: int) -> Callable[[V], Tuple[str, V]]:
s = str(i)

def add_side(v: V) -> Tuple[str, V]:
return (s, v)
def _add_side_builder(i: int) -> Callable[[V], Tuple[int, V]]:
def add_side(v: V) -> Tuple[int, V]:
return (i, v)

return add_side

Expand Down Expand Up @@ -711,10 +710,8 @@ def shim_getter(i_v: Tuple[str, V]) -> datetime:
wait_for_system_duration=clock.wait_for_system_duration,
)

names = [str(i) for i in range(len(rights) + 1)]

logic_class: Callable[
[_JoinState], IntervalLogic[Tuple[str, V], _JoinState, _JoinState]
[_JoinState], IntervalLogic[Tuple[int, V], _JoinState, _JoinState]
]
if mode == "complete":
logic_class = _JoinIntervalCompleteLogic
Expand All @@ -730,8 +727,8 @@ def shim_getter(i_v: Tuple[str, V]) -> datetime:

def shim_builder(
resume_state: Optional[_JoinState],
) -> IntervalLogic[Tuple[str, V], _JoinState, _JoinState]:
state = _JoinState.for_names(names)
) -> IntervalLogic[Tuple[int, V], _JoinState, _JoinState]:
state = _JoinState.for_side_count(len(rights) + 1)
return logic_class(state)

interval_out = interval(
Expand Down
Loading