[core][compiled graphs] Introduce with_tensor_transport API (#49753)
Signed-off-by: Rui Qiao <[email protected]>
ruisearch42 authored Jan 22, 2025
1 parent dc57f42 commit d9898c3
Showing 11 changed files with 429 additions and 122 deletions.
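
For context, a minimal usage sketch of the new API, pieced together from the `dag_node.py` and test changes below; the `EchoActor` class is a hypothetical stand-in and not part of this commit.

```python
import ray
from ray.dag import InputNode

@ray.remote
class EchoActor:  # hypothetical actor, for illustration only
    def echo(self, x):
        return x

a = EchoActor.remote()
with InputNode() as inp:
    dag = a.echo.bind(inp)
    # New API; with no argument it defaults to transport="auto", which is
    # resolved to NCCL or shared memory when the DAG is compiled.
    dag = dag.with_tensor_transport()

compiled_dag = dag.experimental_compile()
```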
93 changes: 81 additions & 12 deletions python/ray/dag/compiled_dag_node.py
@@ -20,6 +20,10 @@
import uuid
import traceback

from ray.experimental.channel.auto_transport_type import (
AutoTransportType,
TypeHintResolver,
)
import ray.exceptions
from ray.dag.dag_operation_future import GPUFuture, DAGOperationFuture, ResolvedFuture
from ray.experimental.channel.cached_channel import CachedChannel
@@ -375,7 +379,9 @@ def __init__(self, idx: int, dag_node: "ray.dag.DAGNode"):
# corresponding value from `args` or `kwargs` in the DAG's input.
self.output_channels: List[ChannelInterface] = []
self.output_idxs: List[Optional[Union[int, str]]] = []
self.arg_type_hints: List["ChannelOutputType"] = []
# The DAGNodes that are arguments to this task.
# This is used for lazy resolution of the arguments' type hints.
self.arg_nodes: List["ray.dag.DAGNode"] = []
# idxs of possible ClassMethodOutputNodes if they exist, used for visualization
self.output_node_idxs: List[int] = []

@@ -391,6 +397,10 @@ def kwargs(self) -> Dict[str, Any]:
def num_readers(self) -> int:
return len(self.downstream_task_idxs)

@property
def arg_type_hints(self) -> List["ChannelOutputType"]:
return [arg_node.type_hint for arg_node in self.arg_nodes]

def __str__(self) -> str:
return f"""
Node: {self.dag_node}
@@ -890,6 +900,9 @@ def __init__(
self.actor_to_tasks: Dict[
"ray.actor.ActorHandle", List["CompiledTask"]
] = defaultdict(list)
# Mapping from actor handle to its GPU IDs.
# This is used to resolve type hints for with_tensor_transport("auto").
self.actor_to_gpu_ids: Dict["ray.actor.ActorHandle", List[str]] = {}
self.actor_to_executable_tasks: Dict[
"ray.actor.ActorHandle", List["ExecutableTask"]
] = {}
@@ -899,7 +912,8 @@ def __init__(
"ray.actor.ActorHandle", List[_DAGNodeOperation]
] = defaultdict(list)
# Mapping from the actor handle to the node ID that the actor is on.
self.actor_to_node_id: Dict["ray.actor.ActorHandle", str] = {}
# A None actor handle means the actor is the driver.
self.actor_to_node_id: Dict[Optional["ray.actor.ActorHandle"], str] = {}

# This is set to true when type hint of `transport="nccl"` is used.
self._use_default_nccl_group = False
@@ -1043,6 +1057,10 @@ def _preprocess(self) -> None:
# Collect the set of InputNode keys bound to DAG node args.
input_positional_args: Set[int] = set()
input_kwargs: Set[str] = set()
# Set of tasks annotated with with_tensor_transport("auto").
# These correspond only to ClassMethodNodes, not InputNodes or
# InputAttributeNodes.
auto_transport_tasks: Set["CompiledTask"] = set()

# For each task node, set its upstream and downstream task nodes.
# Also collect the set of tasks that produce torch.tensors.
@@ -1070,6 +1088,14 @@ def _preprocess(self) -> None:
"that is already created with Actor.remote()"
)

if actor_handle not in self.actor_to_gpu_ids:
self.actor_to_gpu_ids[actor_handle] = CompiledDAG._get_gpu_ids(
actor_handle
)

if isinstance(dag_node.type_hint, AutoTransportType):
auto_transport_tasks.add(task)

# Collect actors for NCCL P2P methods.
if dag_node.type_hint.requires_nccl():
nccl_actors_p2p.add(actor_handle)
@@ -1109,17 +1135,23 @@ def _preprocess(self) -> None:
"supported for NCCL collective operations. Please set "
"overlap_gpu_communication=False."
)
elif isinstance(dag_node, InputNode):
elif isinstance(dag_node, InputNode) or isinstance(
dag_node, InputAttributeNode
):
if dag_node.type_hint.requires_nccl():
raise ValueError(
"DAG inputs cannot be transferred via NCCL because "
"the driver cannot participate in the NCCL group"
)
if isinstance(dag_node.type_hint, AutoTransportType):
# Running the driver on a GPU is currently not supported, so we
# always use shared memory to transfer tensors.
dag_node.type_hint = TorchTensorType()

if type(dag_node.type_hint) is ChannelOutputType:
# No type hint specified by the user. Replace
# with the default type hint for this DAG.
dag_node.with_type_hint(self._default_type_hint)
dag_node.type_hint = self._default_type_hint

for _, val in task.kwargs.items():
if isinstance(val, DAGNode):
@@ -1140,8 +1172,9 @@ def _preprocess(self) -> None:
):
downstream_actor_handle = dag_node._get_actor_handle()

# Add the type hint of the upstream node to the task.
task.arg_type_hints.append(upstream_task.dag_node.type_hint)
# Add the upstream node to this task's argument nodes; its type
# hint may be updated later when it is resolved lazily.
task.arg_nodes.append(upstream_task.dag_node)

if isinstance(upstream_task.dag_node, InputAttributeNode):
# Record all of the keys used to index the InputNode.
@@ -1208,6 +1241,29 @@ def _preprocess(self) -> None:
f"the MultiOutputNode."
)

type_hint_resolver = TypeHintResolver(self.actor_to_gpu_ids)
# Resolve AutoTransportType type hints and track the actors that use
# NCCL, so that the NCCL group can be initialized for those actors.
for task in auto_transport_tasks:
writer = task.dag_node._get_actor_handle()
readers = task.downstream_task_idxs.values()
writer_and_node = (writer, self._get_node_id(writer))
reader_and_node_list = [
(reader, self._get_node_id(reader)) for reader in readers
]
# Update the type hint to the resolved one. This is needed because
# the resolved type hint's `register_custom_serializer` will be called
# in preparation for channel I/O.
task.dag_node.type_hint = type_hint_resolver.resolve(
task.dag_node.type_hint,
writer_and_node,
reader_and_node_list,
)
if task.dag_node.type_hint.requires_nccl():
nccl_actors_p2p.add(writer)
nccl_actors_p2p.update(readers)

nccl_actors_p2p = list(nccl_actors_p2p)
if None in nccl_actors_p2p:
raise ValueError("Driver cannot participate in the NCCL group.")
@@ -1292,27 +1348,39 @@ def _preprocess(self) -> None:
self._input_num_positional_args = max(input_positional_args) + 1
self._input_kwargs = tuple(input_kwargs)

def _get_node_id(self, actor_handle: "ray.actor.ActorHandle") -> str:
@staticmethod
def _get_gpu_ids(actor_handle: "ray.actor.ActorHandle") -> List[str]:
"""
Get the GPU IDs of an actor handle.
"""
accelerator_ids = ray.get(
actor_handle.__ray_call__.remote(
lambda self: ray.get_runtime_context().get_accelerator_ids()
)
)
return accelerator_ids.get("GPU", [])

def _get_node_id(self, actor_handle: Optional["ray.actor.ActorHandle"]) -> str:
"""
Get the node ID of an actor handle and cache it.
Args:
actor_handle: The actor handle.
actor_handle: The actor handle, or None if the actor handle is the
driver.
Returns:
The node ID of the actor handle.
The node ID of the actor handle or driver.
"""
if actor_handle in self.actor_to_node_id:
return self.actor_to_node_id[actor_handle]
node_id = None
if actor_handle == self._proxy_actor:
if actor_handle == self._proxy_actor or actor_handle is None:
node_id = ray.get_runtime_context().get_node_id()
else:
node_id = ray.get(
actor_handle.__ray_call__.remote(
lambda self: ray.get_runtime_context().get_node_id()
)
)
assert node_id is not None
self.actor_to_node_id[actor_handle] = node_id
return node_id

@@ -1492,10 +1560,11 @@ def _get_or_compile(
reader_and_node_list = list(
input_node_to_reader_and_node_set[input_dag_node]
)

output_channel = do_allocate_channel(
self,
reader_and_node_list,
type_hint,
input_dag_node.type_hint,
None,
)
task.output_channels.append(output_channel)
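
The `_preprocess` hunk above resolves `with_tensor_transport("auto")` hints through `TypeHintResolver`, whose implementation lives in `auto_transport_type.py` and is not part of the hunks shown here. Below is a rough sketch of the presumed decision rule, based only on the GPU-ID and node-ID inputs the resolver is given; this is an assumption for illustration, not the resolver's actual code.

```python
from typing import List

def resolve_auto_transport_sketch(
    writer_gpu_ids: List[str], reader_gpu_ids_list: List[List[str]]
) -> str:
    """Illustrative guess at the auto-transport decision; NOT the real
    TypeHintResolver logic (see auto_transport_type.py in this PR)."""
    # Assumption: NCCL is only viable when the writer and every reader
    # have at least one GPU assigned; otherwise fall back to shared memory.
    if writer_gpu_ids and all(reader_gpu_ids_list):
        return "nccl"
    return "shared_memory"

# A GPU writer sending to one GPU reader and one CPU-only reader would
# fall back to shared memory under this sketch.
print(resolve_auto_transport_sketch(["0"], [["1"], []]))  # shared_memory
```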
49 changes: 45 additions & 4 deletions python/ray/dag/dag_node.py
@@ -1,8 +1,10 @@
import copy
from ray.experimental.channel.auto_transport_type import AutoTransportType
from ray.experimental.channel.torch_tensor_type import TorchTensorType
import ray
from ray.dag.base import DAGNodeBase
from ray.dag.py_obj_scanner import _PyObjScanner
from ray.util.annotations import DeveloperAPI
import copy

from itertools import chain

@@ -21,6 +23,7 @@

from ray.dag.compiled_dag_node import build_compiled_dag_from_ray_dag
from ray.experimental.channel import ChannelOutputType
from ray.experimental.channel.communicator import Communicator

T = TypeVar("T")

@@ -79,6 +82,12 @@ def __init__(
self.cache_from_last_execute = {}

self._type_hint: ChannelOutputType = ChannelOutputType()

# If the original type hint is an AutoTransportType, keep a copy of it
# here once it is resolved to the actual type, as additional debugging
# information. Otherwise, this is None.
self._original_type_hint: Optional[ChannelOutputType] = None

# Whether this node calls `experimental_compile`.
self.is_cgraph_output_node = False

@@ -129,14 +138,45 @@ def _collect_upstream_nodes(self) -> List["DAGNode"]:
upstream_node._downstream_nodes.append(self)
return upstream_nodes

def with_type_hint(self, typ: ChannelOutputType):
self._type_hint = copy.deepcopy(typ)
def with_tensor_transport(
self,
transport: Optional[Union[str, Communicator]] = "auto",
_static_shape: bool = False,
_direct_return: bool = False,
):
if transport == "auto":
self._type_hint = AutoTransportType(
_static_shape=_static_shape,
_direct_return=_direct_return,
)
elif transport == "nccl":
self._type_hint = TorchTensorType(
transport=transport,
_static_shape=_static_shape,
_direct_return=_direct_return,
)
else:
if not isinstance(transport, Communicator):
raise ValueError(
"transport must be 'auto', 'nccl' or a Communicator type"
)
self._type_hint = TorchTensorType(
transport=transport,
_static_shape=_static_shape,
_direct_return=_direct_return,
)
return self

@property
def type_hint(self) -> ChannelOutputType:
return self._type_hint

@type_hint.setter
def type_hint(self, type_hint: ChannelOutputType) -> None:
if isinstance(self._type_hint, AutoTransportType):
self._original_type_hint = self._type_hint
self._type_hint = type_hint

def get_args(self) -> Tuple[Any]:
"""Return the tuple of arguments for this node."""

@@ -558,7 +598,8 @@ def _copy(
new_args, new_kwargs, new_options, new_other_args_to_resolve
)
instance._stable_uuid = self._stable_uuid
instance = instance.with_type_hint(self.type_hint)
instance._type_hint = copy.deepcopy(self._type_hint)
instance._original_type_hint = copy.deepcopy(self._original_type_hint)
return instance

def __getstate__(self):
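
A small sketch exercising the argument handling in the `with_tensor_transport` hunk above: the `_static_shape`/`_direct_return` flags are forwarded into the underlying type hint, and anything other than `"auto"`, `"nccl"`, or a `Communicator` instance is rejected. The `Worker` actor is hypothetical, and the snippet assumes a Ray installation that includes this commit.

```python
import ray
from ray.dag import InputNode

@ray.remote
class Worker:  # hypothetical actor, for illustration only
    def fwd(self, x):
        return x

w = Worker.remote()
with InputNode() as inp:
    node = w.fwd.bind(inp)
    # Flags are forwarded into the underlying TorchTensorType hint.
    node.with_tensor_transport("nccl", _static_shape=True, _direct_return=True)
    # Invalid transport values raise immediately, before compilation.
    try:
        node.with_tensor_transport("gloo")
    except ValueError as e:
        print(e)  # transport must be 'auto', 'nccl' or a Communicator type
```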
11 changes: 5 additions & 6 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
@@ -25,7 +25,6 @@
get_or_create_event_loop,
)
from ray.dag import DAGContext
from ray.experimental.channel.torch_tensor_type import TorchTensorType
from ray._private.test_utils import (
run_string_as_driver_nonblocking,
wait_for_pid_to_exit,
@@ -891,13 +890,13 @@ def test_multi_args_and_torch_type(self, ray_start_regular):
a2 = Actor.remote(0)
c = Collector.remote()
with InputNode() as i:
i.with_type_hint(TorchTensorType())
i.with_tensor_transport()
branch1 = a1.echo.bind(i[0])
branch1.with_type_hint(TorchTensorType())
branch1.with_tensor_transport()
branch2 = a2.echo.bind(i[1])
branch2.with_type_hint(TorchTensorType())
branch2.with_tensor_transport()
dag = c.collect_two.bind(branch2, branch1)
dag.with_type_hint(TorchTensorType())
dag.with_tensor_transport()

compiled_dag = dag.experimental_compile()

@@ -2850,7 +2849,7 @@ def __init__(self):
inp,
self._base.generate_torch_tensor.bind(
inp,
).with_type_hint(TorchTensorType()),
).with_tensor_transport(),
)
self._cdag = dag.experimental_compile()

Expand Down
(Diffs for the remaining 8 changed files are not shown.)
