Skip to content

Commit

Permalink
Enforce comm_abc.Comm into Communicator
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianDeconinck committed Dec 22, 2024
1 parent f8cc2ce commit 7ad271f
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 12 deletions.
30 changes: 21 additions & 9 deletions ndsl/comm/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import ndsl.constants as constants
from ndsl.buffer import array_buffer, device_synchronize, recv_buffer, send_buffer
from ndsl.comm.boundary import Boundary
from ndsl.comm.comm_abc import Comm as CommABC
from ndsl.comm.comm_abc import ReductionOperator
from ndsl.comm.partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner
from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater
Expand Down Expand Up @@ -45,7 +46,11 @@ def to_numpy(array, dtype=None) -> np.ndarray:

class Communicator(abc.ABC):
def __init__(
self, comm, partitioner, force_cpu: bool = False, timer: Optional[Timer] = None
self,
comm: CommABC,
partitioner,
force_cpu: bool = False,
timer: Optional[Timer] = None,
):
self.comm = comm
self.partitioner: Partitioner = partitioner
Expand All @@ -62,7 +67,7 @@ def tile(self) -> "TileCommunicator":
@abc.abstractmethod
def from_layout(
cls,
comm,
comm: CommABC,
layout: Tuple[int, int],
force_cpu: bool = False,
timer: Optional[Timer] = None,
Expand Down Expand Up @@ -138,15 +143,17 @@ def all_reduce_per_element(
self.comm.Allreduce(input_quantity.data, output_quantity.data, op)

def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs):
with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer(
numpy_module.zeros, recvbuf
) as recv:
with (
send_buffer(numpy_module.zeros, sendbuf) as send,
recv_buffer(numpy_module.zeros, recvbuf) as recv,
):
self.comm.Scatter(send, recv, **kwargs)

def _Gather(self, numpy_module, sendbuf, recvbuf, **kwargs):
with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer(
numpy_module.zeros, recvbuf
) as recv:
with (
send_buffer(numpy_module.zeros, sendbuf) as send,
recv_buffer(numpy_module.zeros, recvbuf) as recv,
):
self.comm.Gather(send, recv, **kwargs)

def scatter(
Expand Down Expand Up @@ -753,7 +760,7 @@ class CubedSphereCommunicator(Communicator):

def __init__(
self,
comm,
comm: CommABC,
partitioner: CubedSpherePartitioner,
force_cpu: bool = False,
timer: Optional[Timer] = None,
Expand All @@ -766,6 +773,11 @@ def __init__(
force_cpu: Force all communication to go through central memory.
timer: Time communication operations.
"""
if not issubclass(type(comm), CommABC):
raise TypeError(
                "Communicator needs to be instantiated with a communication subsystem"
f" derived from `comm_abc.Comm`, got {type(comm)}."
)
if comm.Get_size() != partitioner.total_ranks:
raise ValueError(
f"was given a partitioner for {partitioner.total_ranks} ranks but a "
Expand Down
3 changes: 2 additions & 1 deletion tests/mpi/test_mpi_all_reduce_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
TilePartitioner,
)
from ndsl.comm.comm_abc import ReductionOperator
from ndsl.comm.mpi import MPIComm
from ndsl.dsl.typing import Float
from tests.mpi.mpi_comm import MPI

Expand Down Expand Up @@ -41,7 +42,7 @@ def cube_partitioner(tile_partitioner):
@pytest.fixture()
def communicator(cube_partitioner):
return CubedSphereCommunicator(
comm=MPI.COMM_WORLD,
comm=MPIComm(),
partitioner=cube_partitioner,
)

Expand Down
5 changes: 3 additions & 2 deletions tests/mpi/test_mpi_halo_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Quantity,
TilePartitioner,
)
from ndsl.comm.mpi import MPIComm
from ndsl.comm._boundary_utils import get_boundary_slice
from ndsl.constants import (
BOUNDARY_TYPES,
Expand Down Expand Up @@ -39,7 +40,7 @@ def layout():
if MPI is not None:
size = MPI.COMM_WORLD.Get_size()
ranks_per_tile = size // 6
ranks_per_edge = int(ranks_per_tile ** 0.5)
ranks_per_edge = int(ranks_per_tile**0.5)
return (ranks_per_edge, ranks_per_edge)
else:
return (1, 1)
Expand Down Expand Up @@ -176,7 +177,7 @@ def extent(n_points, dims, nz, ny, nx):
@pytest.fixture()
def communicator(cube_partitioner):
return CubedSphereCommunicator(
comm=MPI.COMM_WORLD,
comm=MPIComm(),
partitioner=cube_partitioner,
)

Expand Down

0 comments on commit 7ad271f

Please sign in to comment.