Commit bcb40cf

Change caching of global window inputs to be guarded by experiment (apache#31013)

* Change caching of global window inputs to be guarded by experiment
disable_global_windowed_args_caching
scwhittle authored Apr 18, 2024
1 parent 4f964bf commit bcb40cf
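
The caching introduced here is on by default for globally windowed inputs; setting the experiment named in the commit opts a pipeline out of it. A minimal sketch of how that experiment would typically be supplied through pipeline options (the pipeline itself is illustrative and not part of this commit):

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

# Opt out of caching of globally windowed process arguments.
options = PipelineOptions(['--experiments=disable_global_windowed_args_caching'])

with beam.Pipeline(options=options) as p:
  _ = p | beam.Create([1, 2, 3]) | beam.Map(lambda x: x + 1)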
Showing 2 changed files with 54 additions and 25 deletions.
4 changes: 3 additions & 1 deletion sdks/python/apache_beam/runners/common.pxd
@@ -100,7 +100,9 @@ cdef class PerWindowInvoker(DoFnInvoker):
   cdef dict kwargs_for_process_batch
   cdef list placeholders_for_process_batch
   cdef bint has_windowed_inputs
-  cdef bint cache_globally_windowed_args
+  cdef bint recalculate_window_args
+  cdef bint has_cached_window_args
+  cdef bint has_cached_window_batch_args
   cdef object process_method
   cdef object process_batch_method
   cdef bint is_splittable
75 changes: 51 additions & 24 deletions sdks/python/apache_beam/runners/common.py
@@ -761,6 +761,17 @@ def __init__(self,
     self.current_window_index = None
     self.stop_window_index = None
 
+    # TODO(https://github.com/apache/beam/issues/28776): Remove caching after
+    # fully rolling out.
+    # If true, always recalculate window args. If false, has_cached_window_args
+    # and has_cached_window_batch_args will be set to true if the corresponding
+    # self.args_for_process have been updated and should be reused directly.
+    self.recalculate_window_args = (
+        self.has_windowed_inputs or 'disable_global_windowed_args_caching' in
+        RuntimeValueProvider.experiments)
+    self.has_cached_window_args = False
+    self.has_cached_window_batch_args = False
+
     # Try to prepare all the arguments that can just be filled in
     # without any additional work in the process function.
     # Also cache all the placeholders needed in the process function.
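
Taken on its own, the condition added above is a small predicate: window arguments are recalculated per element whenever inputs are actually windowed or the opt-out experiment is set, and cached otherwise. A standalone sketch of that gating logic (plain Python, not Beam internals):

def should_recalculate_window_args(has_windowed_inputs, experiments):
  return (
      has_windowed_inputs or
      'disable_global_windowed_args_caching' in experiments)

# Globally windowed inputs and no experiment: the cache is used.
assert not should_recalculate_window_args(False, set())
# Windowed inputs always force per-window recalculation.
assert should_recalculate_window_args(True, set())
# The experiment disables caching even for globally windowed inputs.
assert should_recalculate_window_args(
    False, {'disable_global_windowed_args_caching'})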
@@ -921,16 +932,23 @@ def _invoke_process_per_window(self,
                                  additional_kwargs,
                                 ):
     # type: (...) -> Optional[SplitResultResidual]
-    if self.has_windowed_inputs:
-      assert len(windowed_value.windows) <= 1
-      window, = windowed_value.windows
+    if self.has_cached_window_args:
+      args_for_process, kwargs_for_process = (
+          self.args_for_process, self.kwargs_for_process)
     else:
-      window = GlobalWindow()
-    side_inputs = [si[window] for si in self.side_inputs]
-    side_inputs.extend(additional_args)
-    args_for_process, kwargs_for_process = util.insert_values_in_args(
-        self.args_for_process, self.kwargs_for_process,
-        side_inputs)
+      if self.has_windowed_inputs:
+        assert len(windowed_value.windows) <= 1
+        window, = windowed_value.windows
+      else:
+        window = GlobalWindow()
+      side_inputs = [si[window] for si in self.side_inputs]
+      side_inputs.extend(additional_args)
+      args_for_process, kwargs_for_process = util.insert_values_in_args(
+          self.args_for_process, self.kwargs_for_process, side_inputs)
+      if not self.recalculate_window_args:
+        self.args_for_process, self.kwargs_for_process = (
+            args_for_process, kwargs_for_process)
+        self.has_cached_window_args = True
 
     # Extract key in the case of a stateful DoFn. Note that in the case of a
     # stateful DoFn, we set during __init__ self.has_windowed_inputs to be
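
The branch added above is a compute-once pattern: the first element resolves side inputs into the process arguments and, unless recalculation is forced, stores the result so later elements skip the lookup entirely. A simplified standalone sketch of that pattern (names are illustrative, not Beam's):

class OnceCache:
  # Compute a value on first use and reuse it unless recalculation is forced.
  def __init__(self, recalculate):
    self.recalculate = recalculate
    self.has_cached = False
    self.cached = None

  def get(self, compute):
    if self.has_cached:
      return self.cached
    value = compute()
    if not self.recalculate:
      self.cached = value
      self.has_cached = True
    return value

cache = OnceCache(recalculate=False)
first = cache.get(lambda: ['resolved', 'args'])
second = cache.get(lambda: ['never', 'recomputed'])
assert first is second  # the second call reuses the cached result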
Expand Down Expand Up @@ -1012,20 +1030,29 @@ def _invoke_process_batch_per_window(
):
# type: (...) -> Optional[SplitResultResidual]

if self.has_windowed_inputs:
assert isinstance(windowed_batch, HomogeneousWindowedBatch)
assert len(windowed_batch.windows) <= 1
window, = windowed_batch.windows
if self.has_cached_window_batch_args:
args_for_process_batch, kwargs_for_process_batch = (
self.args_for_process_batch, self.kwargs_for_process_batch)
else:
window = GlobalWindow()
side_inputs = [si[window] for si in self.side_inputs]
side_inputs.extend(additional_args)
(args_for_process_batch, kwargs_for_process_batch) = (
util.insert_values_in_args(
self.args_for_process_batch,
self.kwargs_for_process_batch,
side_inputs,
))
if self.has_windowed_inputs:
assert isinstance(windowed_batch, HomogeneousWindowedBatch)
assert len(windowed_batch.windows) <= 1
window, = windowed_batch.windows
else:
window = GlobalWindow()
side_inputs = [si[window] for si in self.side_inputs]
side_inputs.extend(additional_args)
args_for_process_batch, kwargs_for_process_batch = (
util.insert_values_in_args(
self.args_for_process_batch,
self.kwargs_for_process_batch,
side_inputs,
)
)
if not self.recalculate_window_args:
self.args_for_process_batch, self.kwargs_for_process_batch = (
args_for_process_batch, kwargs_for_process_batch)
self.has_cached_window_batch_args = True

for i, p in self.placeholders_for_process_batch:
if core.DoFn.ElementParam == p:
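
As in the element path, the batch path only reuses arguments when there are no windowed inputs, because every batch then falls in the single global window and the side-input lookup resolves to the same values each time. A small sketch of that invariant (the dict stands in for a materialized side input; assumes apache_beam is importable):

from apache_beam.transforms.window import GlobalWindow

# Hypothetical side-input values keyed by window.
side_input_by_window = {GlobalWindow(): ['side', 'values']}

first = side_input_by_window[GlobalWindow()]
second = side_input_by_window[GlobalWindow()]
assert first is second  # one window, one resolved value: safe to compute args once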
@@ -1541,8 +1568,8 @@ def __init__(self,
                tagged_receivers,  # type: Mapping[Optional[str], Receiver]
                per_element_output_counter,
                output_batch_converter,  # type: Optional[BatchConverter]
-               process_yields_batches,  # type: bool,
-               process_batch_yields_elements,  # type: bool,
+               process_yields_batches,  # type: bool
+               process_batch_yields_elements,  # type: bool
   ):
     """Initializes ``_OutputHandler``.
