[core][compiled graphs] Use new with_tensor_transport API
Signed-off-by: Rui Qiao <[email protected]>
ruisearch42 committed Jan 10, 2025
1 parent 29e14f7 commit 7e4aedd
Showing 10 changed files with 159 additions and 124 deletions.
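In short, this commit replaces per-edge `with_type_hint(TorchTensorType(...))` annotations with the node-level `with_tensor_transport(...)` API. A minimal before/after sketch, assuming a cluster with two GPUs and NCCL available; the `Worker` actor here is illustrative and not part of the commit:

```python
import torch

import ray
from ray.dag import InputNode


@ray.remote(num_gpus=1)
class Worker:
    def send(self, _):
        # Produce a GPU tensor to be transferred to the receiver.
        return torch.ones(8, device="cuda")

    def recv(self, t):
        return t.sum().item()


sender, receiver = Worker.remote(), Worker.remote()

with InputNode() as inp:
    dag = sender.send.bind(inp)
    # Before this commit:
    #   dag = dag.with_type_hint(TorchTensorType(transport="nccl"))
    # After this commit:
    dag = dag.with_tensor_transport(transport="nccl")
    dag = receiver.recv.bind(dag)

compiled_dag = dag.experimental_compile()
print(ray.get(compiled_dag.execute(0)))
```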
19 changes: 5 additions & 14 deletions python/ray/dag/compiled_dag_node.py
@@ -1041,9 +1041,9 @@ def _preprocess(self) -> None:
)

# Collect actors for NCCL P2P methods.
if dag_node.type_hint.requires_nccl():
if dag_node.requires_nccl():
nccl_actors_p2p.add(actor_handle)
custom_communicator = dag_node.type_hint.get_custom_communicator()
custom_communicator = dag_node.custom_communicator
mixed_nccl_group_error_message = (
"Compiled Graphs do not support mixed usage of "
"type hints of default NCCL group "
@@ -1080,17 +1080,12 @@ def _preprocess(self) -> None:
"overlap_gpu_communication=False."
)
elif isinstance(dag_node, InputNode):
if dag_node.type_hint.requires_nccl():
if dag_node.requires_nccl():
raise ValueError(
"DAG inputs cannot be transferred via NCCL because "
"the driver cannot participate in the NCCL group"
)

if type(dag_node.type_hint) == ChannelOutputType:
# No type hint specified by the user. Replace
# with the default type hint for this DAG.
dag_node.with_type_hint(self._default_type_hint)

for _, val in task.kwargs.items():
if isinstance(val, DAGNode):
raise ValueError(
@@ -1149,7 +1144,7 @@ def _preprocess(self) -> None:

upstream_task.downstream_task_idxs[task_idx] = downstream_actor_handle

if upstream_task.dag_node.type_hint.requires_nccl():
if upstream_task.dag_node.requires_nccl():
# Add all readers to the NCCL actors of P2P.
nccl_actors_p2p.add(downstream_actor_handle)

@@ -2714,11 +2709,7 @@ def visualize(

# Add the node to the graph with attributes
dot.node(str(idx), label, shape=shape, style=style, fillcolor=fillcolor)
channel_type_str = (
type(dag_node.type_hint).__name__
if dag_node.type_hint
else "UnknownType"
) + "\n"
channel_type_str = dag_node.channel_type_str + "\n"

# This logic is built on the assumption that there will only be multiple
# output channels if the task has multiple returns
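The net effect of the `_preprocess` changes above is that NCCL participation and any custom communicator are now read from the `DAGNode` itself rather than from its channel type hint. A standalone sketch of that collection step, assuming we are handed (node, actor handle) pairs; the real method walks its own task table and carries more state:

```python
from typing import Iterable, Optional, Set, Tuple

import ray
from ray.dag import DAGNode
from ray.experimental.channel.communicator import Communicator


def collect_nccl_p2p_state(
    nodes_and_actors: Iterable[Tuple[DAGNode, "ray.actor.ActorHandle"]],
) -> Tuple[Set["ray.actor.ActorHandle"], Optional[Communicator]]:
    """Illustrative helper mirroring the new node-level accessors."""
    nccl_actors_p2p = set()
    custom_communicator = None
    for node, actor_handle in nodes_and_actors:
        if node.requires_nccl():  # was: node.type_hint.requires_nccl()
            nccl_actors_p2p.add(actor_handle)
            # was: node.type_hint.get_custom_communicator()
            custom_communicator = node.custom_communicator
    return nccl_actors_p2p, custom_communicator
```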
78 changes: 73 additions & 5 deletions python/ray/dag/dag_node.py
@@ -1,8 +1,9 @@
from ray.experimental.channel.shared_memory_channel import SharedMemoryType
from ray.experimental.channel.torch_tensor_type import TorchTensorType
import ray
from ray.dag.base import DAGNodeBase
from ray.dag.py_obj_scanner import _PyObjScanner
from ray.util.annotations import DeveloperAPI
import copy

from itertools import chain

@@ -21,6 +22,7 @@

from ray.dag.compiled_dag_node import build_compiled_dag_from_ray_dag
from ray.experimental.channel import ChannelOutputType
from ray.experimental.channel.communicator import Communicator

T = TypeVar("T")

@@ -78,7 +80,11 @@ def __init__(
# Cached values from last call to execute()
self.cache_from_last_execute = {}

self._type_hint: ChannelOutputType = ChannelOutputType()
self._type_hint: ChannelOutputType = SharedMemoryType()
self._tensor_transport: Optional[Union[str, Communicator]] = None
self._tensor_static_shape: Optional[bool] = None
self._tensor_direct_return: Optional[bool] = None

# Whether this node calls `experimental_compile`.
self.is_cgraph_output_node = False

@@ -129,14 +135,71 @@ def _collect_upstream_nodes(self) -> List["DAGNode"]:
upstream_node._downstream_nodes.append(self)
return upstream_nodes

def with_type_hint(self, typ: ChannelOutputType):
self._type_hint = copy.deepcopy(typ)
def with_tensor_transport(
self,
transport: Optional[Union[str, Communicator]] = "auto",
_static_shape: bool = True,
_direct_return: bool = False,
):
if transport == "auto":
    # With "auto", fall back to the default shared-memory channel type.
    self._type_hint = SharedMemoryType()
else:
    # "nccl" or a custom Communicator: transfer torch.Tensors over NCCL.
    self._type_hint = TorchTensorType(
        transport=transport,
        _static_shape=_static_shape,
        _direct_return=_direct_return,
    )
self._tensor_transport = transport
self._tensor_static_shape = _static_shape
self._tensor_direct_return = _direct_return
return self

@property
def type_hint(self) -> ChannelOutputType:
return self._type_hint

def requires_nccl(self) -> bool:
return self._tensor_transport == "nccl" or isinstance(
self._tensor_transport, Communicator
)

@property
def custom_communicator(self) -> Optional[Communicator]:
return (
self._tensor_transport
if isinstance(self._tensor_transport, Communicator)
else None
)

@property
def tensor_transport(self) -> Optional[Union[str, Communicator]]:
return self._tensor_transport

@property
def static_tensor_shape(self) -> Optional[bool]:
return self._tensor_static_shape

@property
def tensor_direct_return(self) -> Optional[bool]:
    return self._tensor_direct_return

@property
def channel_type_str(self) -> str:
if self._tensor_transport is None or self._tensor_transport == "auto":
return "SharedMemoryChannel"
elif self._tensor_transport == "nccl" or isinstance(
self._tensor_transport, Communicator
):
return "NcclChannel"
else:
return "UnknownChannel"

def get_args(self) -> Tuple[Any]:
"""Return the tuple of arguments for this node."""

@@ -575,7 +638,12 @@ def _copy(
new_args, new_kwargs, new_options, new_other_args_to_resolve
)
instance._stable_uuid = self._stable_uuid
instance = instance.with_type_hint(self.type_hint)
# Re-apply the tensor transport only if one was configured on the original
# node; keyword names match with_tensor_transport's signature above.
if self.tensor_transport is not None:
    instance = instance.with_tensor_transport(
        transport=self.tensor_transport,
        _static_shape=self.static_tensor_shape,
        _direct_return=self.tensor_direct_return,
    )
return instance

def __getstate__(self):
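Taken together, the node-level state added to `DAGNode` above behaves roughly as follows. This is a sketch against the accessors in this diff; the `Echo` actor is illustrative and the assertions assume the default values shown in `__init__`:

```python
import ray
from ray.dag import InputNode


@ray.remote
class Echo:
    def echo(self, x):
        return x


a = Echo.remote()
with InputNode() as inp:
    node = a.echo.bind(inp)

# Default state: no tensor transport configured, shared-memory channel.
assert node.tensor_transport is None
assert not node.requires_nccl()
assert node.channel_type_str == "SharedMemoryChannel"

# Opting into NCCL flips the accessors that _preprocess() and visualize() read.
node = node.with_tensor_transport(transport="nccl")
assert node.requires_nccl()
assert node.custom_communicator is None  # "nccl" selects the default NCCL group
assert node.channel_type_str == "NcclChannel"
```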
11 changes: 5 additions & 6 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
@@ -25,7 +25,6 @@
get_or_create_event_loop,
)
from ray.dag import DAGContext
from ray.experimental.channel.torch_tensor_type import TorchTensorType
from ray._private.test_utils import (
run_string_as_driver_nonblocking,
wait_for_pid_to_exit,
@@ -830,13 +829,13 @@ def test_multi_args_and_torch_type(self, ray_start_regular):
a2 = Actor.remote(0)
c = Collector.remote()
with InputNode() as i:
i.with_type_hint(TorchTensorType())
i.with_tensor_transport()
branch1 = a1.echo.bind(i[0])
branch1.with_type_hint(TorchTensorType())
branch1.with_tensor_transport()
branch2 = a2.echo.bind(i[1])
branch2.with_type_hint(TorchTensorType())
branch2.with_tensor_transport()
dag = c.collect_two.bind(branch2, branch1)
dag.with_type_hint(TorchTensorType())
dag.with_tensor_transport()

compiled_dag = dag.experimental_compile()

@@ -2508,7 +2507,7 @@ def __init__(self):
inp,
self._base.generate_torch_tensor.bind(
inp,
).with_type_hint(TorchTensorType()),
).with_tensor_transport(),
)
self._cdag = dag.experimental_compile()

19 changes: 7 additions & 12 deletions python/ray/dag/tests/experimental/test_collective_dag.py
@@ -12,7 +12,6 @@
check_nccl_group_teardown,
)
from ray.dag import InputNode, MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType
from ray.tests.conftest import * # noqa

logger = logging.getLogger(__name__)
@@ -148,10 +147,10 @@ def test_comm_deduplicate_p2p_and_collective(ray_start_regular, monkeypatch):
recvs = [
# Each of the 2 workers receives from the other.
workers[0].recv.bind(
collectives[1].with_type_hint(TorchTensorType(transport="nccl"))
collectives[1].with_tensor_transport(transport="nccl")
),
workers[1].recv.bind(
collectives[0].with_type_hint(TorchTensorType(transport="nccl"))
collectives[0].with_tensor_transport(transport="nccl")
),
]
dag = MultiOutputNode(recvs)
@@ -170,7 +169,7 @@ def test_comm_deduplicate_p2p_and_collective(ray_start_regular, monkeypatch):
collectives = collective.allreduce.bind(computes)
# Sender is workers[0] and receiver is workers[1].
dag = workers[1].recv.bind(
collectives[0].with_type_hint(TorchTensorType(transport="nccl"))
collectives[0].with_tensor_transport(transport="nccl")
)
dag = MultiOutputNode([dag, collectives[1]])

@@ -202,7 +201,7 @@ def test_custom_comm_deduplicate(ray_start_regular, monkeypatch):
collectives = collective.allreduce.bind(computes, transport=comm)
collectives = collective.allreduce.bind(collectives)
dag = workers[0].recv.bind(
collectives[1].with_type_hint(TorchTensorType(transport="nccl"))
collectives[1].with_tensor_transport(transport="nccl")
)
dag = MultiOutputNode([dag, collectives[0]])

@@ -220,9 +219,7 @@ def test_custom_comm_deduplicate(ray_start_regular, monkeypatch):
computes = [worker.return_tensor.bind(inp) for worker in workers]
collectives = collective.allreduce.bind(computes)
collectives = collective.allreduce.bind(collectives)
dag = workers[0].recv.bind(
collectives[1].with_type_hint(TorchTensorType(transport=comm))
)
dag = workers[0].recv.bind(collectives[1].with_tensor_transport(transport=comm))
dag = MultiOutputNode([dag, collectives[0]])

compiled_dag, mock_nccl_group_set = check_nccl_group_init(
@@ -255,9 +252,7 @@ def test_custom_comm_init_teardown(ray_start_regular, monkeypatch):
with InputNode() as inp:
tensors = [worker.return_tensor.bind(inp) for worker in workers]
allreduce = collective.allreduce.bind(tensors, transport=comm)
dag = workers[0].recv.bind(
allreduce[1].with_type_hint(TorchTensorType(transport=comm))
)
dag = workers[0].recv.bind(allreduce[1].with_tensor_transport(transport=comm))
dag = MultiOutputNode([dag, allreduce[0]])

compiled_dag, mock_nccl_group_set = check_nccl_group_init(
@@ -278,7 +273,7 @@ def test_custom_comm_init_teardown(ray_start_regular, monkeypatch):
allreduce1 = collective.allreduce.bind(tensors, transport=comm_1)
allreduce2 = collective.allreduce.bind(allreduce1, transport=comm_2)
dag = workers[0].recv.bind(
allreduce2[1].with_type_hint(TorchTensorType(transport=comm_3))
allreduce2[1].with_tensor_transport(transport=comm_3)
)
dag = MultiOutputNode([dag, allreduce2[0]])

@@ -7,7 +7,6 @@
import ray
import ray.cluster_utils
from ray.exceptions import RayChannelError
from ray.experimental.channel.torch_tensor_type import TorchTensorType
from ray.experimental.channel.cpu_communicator import CPUCommunicator
from ray.dag import InputNode
import ray.experimental.collective as collective
@@ -83,7 +82,7 @@ def test_p2p_basic(ray_start_cluster):

with InputNode() as inp:
dag = sender.send.bind(inp.shape, inp.dtype, inp[0])
dag = dag.with_type_hint(TorchTensorType(transport=cpu_group))
dag = dag.with_tensor_transport(transport=cpu_group)
dag = receiver.recv.bind(dag)

compiled_dag = dag.experimental_compile()
@@ -273,9 +272,9 @@ def test_allreduce_scheduling(ray_start_cluster):
x = workers[0].send.bind(shape, dtype, inp)
y = workers[1].send.bind(shape, dtype, inp)

# Tensor to be sent from workes[0] to workers[1].
# Tensor to be sent from workers[0] to workers[1].
t = workers[0].send.bind(shape, dtype, inp)
t.with_type_hint(TorchTensorType(transport=cpu_group))
t = t.with_tensor_transport(transport=cpu_group)

collectives = collective.allreduce.bind([x, y], transport=cpu_group)
recv = workers[1].recv.bind(t)