[core][compiled graphs] Introduce with_tensor_transport API
Signed-off-by: Rui Qiao <[email protected]>
ruisearch42 committed Jan 21, 2025
1 parent 7a372c7 commit 99eb84b
Showing 11 changed files with 391 additions and 114 deletions.
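
For orientation, here is a minimal usage sketch of the API this commit introduces. The actor, method names, and 2-GPU setup are illustrative assumptions, not taken from the diff; `with_tensor_transport("auto")` (the default) defers the transport choice to compile time, while "nccl" or a custom Communicator can be passed explicitly.

import ray
import torch
from ray.dag import InputNode

@ray.remote(num_gpus=1)
class Worker:
    def produce(self, scale):
        # Hypothetical method returning a GPU tensor.
        return torch.ones(4, device="cuda") * scale

    def consume(self, t):
        return float(t.sum())

sender, receiver = Worker.remote(), Worker.remote()

with InputNode() as inp:
    t = sender.produce.bind(inp)
    # Replaces the older with_type_hint(TorchTensorType(...)) pattern;
    # "auto" lets the compiled graph pick NCCL or shared memory based on
    # where the writer and readers are placed.
    t = t.with_tensor_transport("auto")
    dag = receiver.consume.bind(t)

compiled_dag = dag.experimental_compile()
print(ray.get(compiled_dag.execute(2)))
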
76 changes: 72 additions & 4 deletions python/ray/dag/compiled_dag_node.py
@@ -20,6 +20,10 @@
import uuid
import traceback

from ray.experimental.channel.auto_transport_type import (
AutoTransportType,
TypeHintResolver,
)
import ray.exceptions
from ray.dag.dag_operation_future import GPUFuture, DAGOperationFuture, ResolvedFuture
from ray.experimental.channel.cached_channel import CachedChannel
@@ -375,7 +379,9 @@ def __init__(self, idx: int, dag_node: "ray.dag.DAGNode"):
# corresponding value from `args` or `kwargs` in the DAG's input.
self.output_channels: List[ChannelInterface] = []
self.output_idxs: List[Optional[Union[int, str]]] = []
self.arg_type_hints: List["ChannelOutputType"] = []
# The DAGNodes that are arguments to this task.
# This is used for lazy resolution of the arguments' type hints.
self.arg_nodes: List["ray.dag.DAGNode"] = []
# idxs of possible ClassMethodOutputNodes if they exist, used for visualization
self.output_node_idxs: List[int] = []

@@ -391,6 +397,10 @@ def kwargs(self) -> Dict[str, Any]:
def num_readers(self) -> int:
return len(self.downstream_task_idxs)

@property
def arg_type_hints(self) -> List["ChannelOutputType"]:
return [arg_node.type_hint for arg_node in self.arg_nodes]
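
Because the hints are recomputed from `arg_nodes` on every access, a later assignment to an upstream node's `type_hint` (for example when its AutoTransportType is resolved during preprocessing) is reflected automatically. A small illustrative sketch, where `task` stands in for a hypothetical CompiledTask built during compilation:

from ray.experimental.channel.torch_tensor_type import TorchTensorType

# Illustrative only: resolving the upstream hint is visible on the next access.
upstream = task.arg_nodes[0]
upstream.type_hint = TorchTensorType()
assert isinstance(task.arg_type_hints[0], TorchTensorType)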

def __str__(self) -> str:
return f"""
Node: {self.dag_node}
@@ -890,6 +900,9 @@ def __init__(
self.actor_to_tasks: Dict[
"ray.actor.ActorHandle", List["CompiledTask"]
] = defaultdict(list)
# Mapping from actor handle to its GPU IDs.
# This is used to resolve type hints for with_tensor_transport("auto").
self.actor_to_gpu_ids: Dict["ray.actor.ActorHandle", List[str]] = {}
self.actor_to_executable_tasks: Dict[
"ray.actor.ActorHandle", List["ExecutableTask"]
] = {}
@@ -1043,6 +1056,8 @@ def _preprocess(self) -> None:
# Collect the set of InputNode keys bound to DAG node args.
input_positional_args: Set[int] = set()
input_kwargs: Set[str] = set()
# Set of tasks annotated with with_tensor_transport("auto").
auto_transport_tasks: Set["CompiledTask"] = set()

# For each task node, set its upstream and downstream task nodes.
# Also collect the set of tasks that produce torch.tensors.
@@ -1070,6 +1085,14 @@ def _preprocess(self) -> None:
"that is already created with Actor.remote()"
)

if actor_handle not in self.actor_to_gpu_ids:
self.actor_to_gpu_ids[actor_handle] = self._get_gpu_ids(
actor_handle
)

if isinstance(dag_node.type_hint, AutoTransportType):
auto_transport_tasks.add(task)

# Collect actors for NCCL P2P methods.
if dag_node.type_hint.requires_nccl():
nccl_actors_p2p.add(actor_handle)
@@ -1140,8 +1163,9 @@ def _preprocess(self) -> None:
):
downstream_actor_handle = dag_node._get_actor_handle()

# Add the type hint of the upstream node to the task.
task.arg_type_hints.append(upstream_task.dag_node.type_hint)
# Add the upstream node as an argument node of this task; its
# type hint may be updated when it is lazily resolved.
task.arg_nodes.append(upstream_task.dag_node)

if isinstance(upstream_task.dag_node, InputAttributeNode):
# Record all of the keys used to index the InputNode.
@@ -1208,6 +1232,33 @@ def _preprocess(self) -> None:
f"the MultiOutputNode."
)

type_hint_resolver = TypeHintResolver(self.actor_to_gpu_ids)
# Resolve AutoTransportType type hints and track the actors that use NCCL,
# so that the NCCL group can be initialized for those actors.
for task in auto_transport_tasks:
writer = task.dag_node._get_actor_handle()
readers = task.downstream_task_idxs.values()
if any(reader is None for reader in readers):
# None means the reader is the driver. A driver on GPU is currently
# not supported, so we always use shared memory to transfer tensors.
task.dag_node.type_hint = TorchTensorType()
continue
writer_and_node = (writer, self._get_node_id(writer))
reader_and_node_list = [
(reader, self._get_node_id(reader)) for reader in readers
]
# Update the type hint to the resolved one. This is needed because
# the resolved type hint's `register_custom_serializer` will be called
# in preparation for channel I/O.
task.dag_node.type_hint = type_hint_resolver.resolve(
writer_and_node, reader_and_node_list
)
if task.dag_node.type_hint.requires_nccl():
nccl_actors_p2p.add(writer)
nccl_actors_p2p.update(readers)

nccl_actors_p2p = list(nccl_actors_p2p)
if None in nccl_actors_p2p:
raise ValueError("Driver cannot participate in the NCCL group.")
@@ -1292,6 +1343,17 @@ def _preprocess(self) -> None:
self._input_num_positional_args = max(input_positional_args) + 1
self._input_kwargs = tuple(input_kwargs)

def _get_gpu_ids(self, actor_handle: "ray.actor.ActorHandle") -> List[str]:
"""
Get the GPU IDs of an actor handle.
"""
accelerator_ids = ray.get(
actor_handle.__ray_call__.remote(
lambda self: ray.get_runtime_context().get_accelerator_ids()
)
)
return accelerator_ids.get("GPU", [])
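
The same `__ray_call__` pattern can be used ad hoc to inspect an actor's accelerators; a small sketch assuming an already-created actor handle named `actor`:

import ray

gpu_ids = ray.get(
    actor.__ray_call__.remote(
        lambda self: ray.get_runtime_context().get_accelerator_ids()
    )
).get("GPU", [])
# e.g. ["0"] for an actor created with num_gpus=1, [] for a CPU-only actor.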

def _get_node_id(self, actor_handle: "ray.actor.ActorHandle") -> str:
"""
Get the node ID of an actor handle and cache it.
@@ -1492,10 +1554,16 @@ def _get_or_compile(
reader_and_node_list = list(
input_node_to_reader_and_node_set[input_dag_node]
)

if isinstance(input_dag_node.type_hint, AutoTransportType):
# A driver on GPU is currently not supported, so we always
# use shared memory to transfer tensors.
input_dag_node.type_hint = TorchTensorType()

output_channel = do_allocate_channel(
self,
reader_and_node_list,
type_hint,
input_dag_node.type_hint,
None,
)
task.output_channels.append(output_channel)
45 changes: 41 additions & 4 deletions python/ray/dag/dag_node.py
@@ -1,8 +1,10 @@
import copy
from ray.experimental.channel.auto_transport_type import AutoTransportType
from ray.experimental.channel.torch_tensor_type import TorchTensorType
import ray
from ray.dag.base import DAGNodeBase
from ray.dag.py_obj_scanner import _PyObjScanner
from ray.util.annotations import DeveloperAPI
import copy

from itertools import chain

@@ -21,6 +23,7 @@

from ray.dag.compiled_dag_node import build_compiled_dag_from_ray_dag
from ray.experimental.channel import ChannelOutputType
from ray.experimental.channel.communicator import Communicator

T = TypeVar("T")

@@ -79,6 +82,12 @@ def __init__(
self.cache_from_last_execute = {}

self._type_hint: ChannelOutputType = ChannelOutputType()

# If the original type hint is an AutoTransportType, a copy of it is kept
# here when it is resolved to the actual type, as additional debugging
# information. Otherwise, this field is None.
self._original_type_hint: Optional[ChannelOutputType] = None

# Whether this node calls `experimental_compile`.
self.is_cgraph_output_node = False

@@ -129,14 +138,42 @@ def _collect_upstream_nodes(self) -> List["DAGNode"]:
upstream_node._downstream_nodes.append(self)
return upstream_nodes

def with_type_hint(self, typ: ChannelOutputType):
self._type_hint = copy.deepcopy(typ)
def with_tensor_transport(
self,
transport: Optional[Union[str, Communicator]] = "auto",
_static_shape: bool = False,
_direct_return: bool = False,
):
if transport == "auto":
self._type_hint = AutoTransportType()
elif transport == "nccl":
self._type_hint = TorchTensorType(
transport=transport,
_static_shape=_static_shape,
_direct_return=_direct_return,
)
else:
if not isinstance(transport, Communicator):
raise ValueError(
"transport must be 'auto', 'nccl' or a Communicator type"
)
self._type_hint = TorchTensorType(
transport=transport,
_static_shape=_static_shape,
_direct_return=_direct_return,
)
return self
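
A short sketch of the three accepted transport values, assuming `sender` is a DAG node returned by `.bind(...)` and `custom_comm` is a user-provided Communicator implementation (each call replaces the node's type hint):

sender.with_tensor_transport()                    # default, same as "auto"
sender.with_tensor_transport("nccl", _static_shape=True)
sender.with_tensor_transport(transport=custom_comm)
# Any other value raises:
# ValueError("transport must be 'auto', 'nccl' or a Communicator type")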

@property
def type_hint(self) -> ChannelOutputType:
return self._type_hint

@type_hint.setter
def type_hint(self, type_hint: ChannelOutputType) -> None:
if isinstance(self._type_hint, AutoTransportType):
self._original_type_hint = self._type_hint
self._type_hint = type_hint
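
A hedged illustration of how `_original_type_hint` could be inspected for debugging after compilation; it is an internal field, so this reflects an assumed workflow rather than a documented API:

node = worker.compute.bind(inp).with_tensor_transport("auto")  # hypothetical node
# ... after experimental_compile() resolves the hint via this setter ...
print(type(node._original_type_hint).__name__)  # "AutoTransportType"
print(type(node.type_hint).__name__)            # e.g. "TorchTensorType"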

def get_args(self) -> Tuple[Any]:
"""Return the tuple of arguments for this node."""

@@ -558,7 +595,7 @@ def _copy(
new_args, new_kwargs, new_options, new_other_args_to_resolve
)
instance._stable_uuid = self._stable_uuid
instance = instance.with_type_hint(self.type_hint)
instance._type_hint = copy.deepcopy(self._type_hint)
return instance

def __getstate__(self):
11 changes: 5 additions & 6 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
@@ -25,7 +25,6 @@
get_or_create_event_loop,
)
from ray.dag import DAGContext
from ray.experimental.channel.torch_tensor_type import TorchTensorType
from ray._private.test_utils import (
run_string_as_driver_nonblocking,
wait_for_pid_to_exit,
@@ -891,13 +890,13 @@ def test_multi_args_and_torch_type(self, ray_start_regular):
a2 = Actor.remote(0)
c = Collector.remote()
with InputNode() as i:
i.with_type_hint(TorchTensorType())
i.with_tensor_transport()
branch1 = a1.echo.bind(i[0])
branch1.with_type_hint(TorchTensorType())
branch1.with_tensor_transport()
branch2 = a2.echo.bind(i[1])
branch2.with_type_hint(TorchTensorType())
branch2.with_tensor_transport()
dag = c.collect_two.bind(branch2, branch1)
dag.with_type_hint(TorchTensorType())
dag.with_tensor_transport()

compiled_dag = dag.experimental_compile()

@@ -2850,7 +2849,7 @@ def __init__(self):
inp,
self._base.generate_torch_tensor.bind(
inp,
).with_type_hint(TorchTensorType()),
).with_tensor_transport(),
)
self._cdag = dag.experimental_compile()

19 changes: 7 additions & 12 deletions python/ray/dag/tests/experimental/test_collective_dag.py
@@ -12,7 +12,6 @@
check_nccl_group_teardown,
)
from ray.dag import InputNode, MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType
from ray.tests.conftest import * # noqa

logger = logging.getLogger(__name__)
@@ -148,10 +147,10 @@ def test_comm_deduplicate_p2p_and_collective(ray_start_regular, monkeypatch):
recvs = [
# Each of the 2 workers receives from the other.
workers[0].recv.bind(
collectives[1].with_type_hint(TorchTensorType(transport="nccl"))
collectives[1].with_tensor_transport(transport="nccl")
),
workers[1].recv.bind(
collectives[0].with_type_hint(TorchTensorType(transport="nccl"))
collectives[0].with_tensor_transport(transport="nccl")
),
]
dag = MultiOutputNode(recvs)
@@ -170,7 +169,7 @@ def test_comm_deduplicate_p2p_and_collective(ray_start_regular, monkeypatch):
collectives = collective.allreduce.bind(computes)
# Sender is workers[0] and receiver is workers[1].
dag = workers[1].recv.bind(
collectives[0].with_type_hint(TorchTensorType(transport="nccl"))
collectives[0].with_tensor_transport(transport="nccl")
)
dag = MultiOutputNode([dag, collectives[1]])

@@ -202,7 +201,7 @@ def test_custom_comm_deduplicate(ray_start_regular, monkeypatch):
collectives = collective.allreduce.bind(computes, transport=comm)
collectives = collective.allreduce.bind(collectives)
dag = workers[0].recv.bind(
collectives[1].with_type_hint(TorchTensorType(transport="nccl"))
collectives[1].with_tensor_transport(transport="nccl")
)
dag = MultiOutputNode([dag, collectives[0]])

@@ -220,9 +219,7 @@ def test_custom_comm_deduplicate(ray_start_regular, monkeypatch):
computes = [worker.return_tensor.bind(inp) for worker in workers]
collectives = collective.allreduce.bind(computes)
collectives = collective.allreduce.bind(collectives)
dag = workers[0].recv.bind(
collectives[1].with_type_hint(TorchTensorType(transport=comm))
)
dag = workers[0].recv.bind(collectives[1].with_tensor_transport(transport=comm))
dag = MultiOutputNode([dag, collectives[0]])

compiled_dag, mock_nccl_group_set = check_nccl_group_init(
@@ -255,9 +252,7 @@ def test_custom_comm_init_teardown(ray_start_regular, monkeypatch):
with InputNode() as inp:
tensors = [worker.return_tensor.bind(inp) for worker in workers]
allreduce = collective.allreduce.bind(tensors, transport=comm)
dag = workers[0].recv.bind(
allreduce[1].with_type_hint(TorchTensorType(transport=comm))
)
dag = workers[0].recv.bind(allreduce[1].with_tensor_transport(transport=comm))
dag = MultiOutputNode([dag, allreduce[0]])

compiled_dag, mock_nccl_group_set = check_nccl_group_init(
@@ -278,7 +273,7 @@ def test_custom_comm_init_teardown(ray_start_regular, monkeypatch):
allreduce1 = collective.allreduce.bind(tensors, transport=comm_1)
allreduce2 = collective.allreduce.bind(allreduce1, transport=comm_2)
dag = workers[0].recv.bind(
allreduce2[1].with_type_hint(TorchTensorType(transport=comm_3))
allreduce2[1].with_tensor_transport(transport=comm_3)
)
dag = MultiOutputNode([dag, allreduce2[0]])

@@ -7,7 +7,6 @@
import ray
import ray.cluster_utils
from ray.exceptions import RayChannelError
from ray.experimental.channel.torch_tensor_type import TorchTensorType
from ray.experimental.channel.cpu_communicator import CPUCommunicator
from ray.dag import InputNode
import ray.experimental.collective as collective
@@ -83,7 +82,7 @@ def test_p2p_basic(ray_start_cluster):

with InputNode() as inp:
dag = sender.send.bind(inp.shape, inp.dtype, inp[0])
dag = dag.with_type_hint(TorchTensorType(transport=cpu_group))
dag = dag.with_tensor_transport(transport=cpu_group)
dag = receiver.recv.bind(dag)

compiled_dag = dag.experimental_compile()
@@ -273,9 +272,9 @@ def test_allreduce_scheduling(ray_start_cluster):
x = workers[0].send.bind(shape, dtype, inp)
y = workers[1].send.bind(shape, dtype, inp)

# Tensor to be sent from workes[0] to workers[1].
# Tensor to be sent from workers[0] to workers[1].
t = workers[0].send.bind(shape, dtype, inp)
t.with_type_hint(TorchTensorType(transport=cpu_group))
t = t.with_tensor_transport(transport=cpu_group)

collectives = collective.allreduce.bind([x, y], transport=cpu_group)
recv = workers[1].recv.bind(t)