Skip to content

Commit

Permalink
[Data] support passing kwargs to map tasks. (ray-project#49208)
Browse files Browse the repository at this point in the history
## Why are these changes needed?

This PR enables passing kwargs to map tasks, which will be accessible
via `TaskContext.kwargs`.

This is a prerequisite to fixing
ray-project#49207. And optimization rules
can use this API to pass additional arguments to the map tasks.

---------

Signed-off-by: Hao Chen <[email protected]>
  • Loading branch information
raulchen authored and simonsays1980 committed Dec 12, 2024
1 parent a582045 commit 9c0c2ad
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional

from ray.data._internal.progress_bar import ProgressBar
Expand Down Expand Up @@ -39,3 +39,6 @@ class TaskContext:

# The target maximum number of bytes to include in the task's output block.
target_max_block_size: Optional[int] = None

# Additional keyword arguments passed to the task.
kwargs: Dict[str, Any] = field(default_factory=dict)
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,12 @@ def _dispatch_tasks(self):
num_returns="streaming",
name=self.name,
**self._ray_actor_task_remote_args,
).remote(self.data_context, ctx, *input_blocks)
).remote(
self.data_context,
ctx,
*input_blocks,
**self.get_map_task_kwargs(),
)

def _task_done_callback(actor_to_return):
# Return the actor that was running the task to the pool.
Expand Down Expand Up @@ -391,12 +396,14 @@ def submit(
data_context: DataContext,
ctx: TaskContext,
*blocks: Block,
**kwargs: Dict[str, Any],
) -> Iterator[Union[Block, List[BlockMetadata]]]:
yield from _map_task(
self._map_transformer,
data_context,
ctx,
*blocks,
**kwargs,
)

def __repr__(self):
Expand Down
21 changes: 21 additions & 0 deletions python/ray/data/_internal/execution/operators/map_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,25 @@ def __init__(
# too-large blocks, which may reduce parallelism for
# the subsequent operator.
self._additional_split_factor = None
# Callback functions that generate additional task kwargs
# for the map task.
self._map_task_kwargs_fns: List[Callable[[], Dict[str, Any]]] = []

def add_map_task_kwargs_fn(self, map_task_kwargs_fn: Callable[[], Dict[str, Any]]):
"""Add a callback function that generates additional kwargs for the map tasks.
In the map tasks, the kwargs can be accessible via `TaskContext.kwargs`.
"""
self._map_task_kwargs_fns.append(map_task_kwargs_fn)

def get_map_task_kwargs(self) -> Dict[str, Any]:
"""Get the kwargs for the map task.
Subclasses should pass the returned kwargs to the map tasks.
In the map tasks, the kwargs can be accessible via `TaskContext.kwargs`.
"""
kwargs = {}
for fn in self._map_task_kwargs_fns:
kwargs.update(fn())
return kwargs

def get_additional_split_factor(self) -> int:
if self._additional_split_factor is None:
Expand Down Expand Up @@ -468,6 +487,7 @@ def _map_task(
data_context: DataContext,
ctx: TaskContext,
*blocks: Block,
**kwargs: Dict[str, Any],
) -> Iterator[Union[Block, List[BlockMetadata]]]:
"""Remote function for a single operator task.
Expand All @@ -481,6 +501,7 @@ def _map_task(
as the last generator return.
"""
DataContext._set_current(data_context)
ctx.kwargs.update(kwargs)
stats = BlockExecStats.builder()
map_transformer.set_target_max_block_size(ctx.target_max_block_size)
for b_out in map_transformer.apply_transform(iter(blocks), ctx):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def _add_bundled_input(self, bundle: RefBundle):
data_context,
ctx,
*bundle.block_refs,
**self.get_map_task_kwargs(),
)
self._submit_data_task(gen, bundle)

Expand Down
5 changes: 5 additions & 0 deletions python/ray/data/_internal/logical/rules/operator_fusion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
from typing import List, Optional, Tuple

# TODO(Clark): Remove compute dependency once we delete the legacy compute.
Expand Down Expand Up @@ -330,6 +331,10 @@ def _get_fused_map_operator(
ray_remote_args_fn=ray_remote_args_fn,
)
op.set_logical_operators(*up_op._logical_operators, *down_op._logical_operators)
for map_task_kwargs_fn in itertools.chain(
up_op._map_task_kwargs_fns, down_op._map_task_kwargs_fns
):
op.add_map_task_kwargs_fn(map_task_kwargs_fn)

# Build a map logical operator to be used as a reference for further fusion.
# TODO(Scott): This is hacky, remove this once we push fusion to be purely based
Expand Down
42 changes: 42 additions & 0 deletions python/ray/data/tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
PhysicalOperator,
RefBundle,
)
from ray.data._internal.execution.interfaces.task_context import TaskContext
from ray.data._internal.execution.operators.actor_pool_map_operator import (
ActorPoolMapOperator,
)
Expand Down Expand Up @@ -1022,6 +1023,47 @@ def yield_five(block_iter: Iterable[Block], ctx) -> Iterable[Block]:
assert op._estimated_num_output_bundles == 100


@pytest.mark.parametrize("use_actors", [False, True])
def test_map_kwargs(ray_start_regular_shared, use_actors):
"""Test propagating additional kwargs to map tasks."""
foo = 1
bar = np.random.random(1024 * 1024)
kwargs = {
"foo": foo, # Pass by value
"bar": ray.put(bar), # Pass by ObjectRef
}

def map_fn(block_iter: Iterable[Block], ctx: TaskContext) -> Iterable[Block]:
nonlocal foo, bar
assert ctx.kwargs["foo"] == foo
# bar should be automatically deref'ed.
assert np.array_equal(ctx.kwargs["bar"], bar)

yield from block_iter

input_op = InputDataBuffer(
DataContext.get_current(),
make_ref_bundles([[i] for i in range(10)]),
)
compute_strategy = ActorPoolStrategy() if use_actors else TaskPoolStrategy()
op = MapOperator.create(
create_map_transformer_from_block_fn(map_fn),
input_op=input_op,
data_context=DataContext.get_current(),
name="TestMapper",
compute_strategy=compute_strategy,
)
op.add_map_task_kwargs_fn(lambda: kwargs)
op.start(ExecutionOptions())
while input_op.has_next():
op.add_input(input_op.get_next(), 0)
op.all_inputs_done()
run_op_tasks_sync(op)

_take_outputs(op)
assert op.completed()


def test_limit_estimated_num_output_bundles():
# Test limit operator estimation
input_op = InputDataBuffer(
Expand Down

0 comments on commit 9c0c2ad

Please sign in to comment.