diff --git a/ndsl/comm/caching_comm.py b/ndsl/comm/caching_comm.py index 36587d73..42f92ea2 100644 --- a/ndsl/comm/caching_comm.py +++ b/ndsl/comm/caching_comm.py @@ -5,7 +5,7 @@ import numpy as np -from ndsl.comm.comm_abc import Comm, Request +from ndsl.comm.comm_abc import Comm, ReductionOperator, Request T = TypeVar("T") @@ -147,9 +147,12 @@ def Split(self, color, key) -> "CachingCommReader": new_data = self._data.get_split() return CachingCommReader(data=new_data) - def allreduce(self, sendobj, op=None) -> Any: + def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any: return self._data.get_generic_obj() + def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any: + raise NotImplementedError("CachingCommReader.Allreduce") + @classmethod def load(cls, file: BinaryIO) -> "CachingCommReader": data = CachingCommData.load(file) @@ -229,7 +232,10 @@ def Split(self, color, key) -> "CachingCommWriter": def dump(self, file: BinaryIO): self._data.dump(file) - def allreduce(self, sendobj, op=None) -> Any: + def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any: result = self._comm.allreduce(sendobj, op) self._data.generic_obj_buffers.append(copy.deepcopy(result)) return result + + def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any: + raise NotImplementedError("CachingCommWriter.Allreduce") diff --git a/ndsl/comm/comm_abc.py b/ndsl/comm/comm_abc.py index 77f56586..45596f1e 100644 --- a/ndsl/comm/comm_abc.py +++ b/ndsl/comm/comm_abc.py @@ -1,10 +1,30 @@ import abc +import enum from typing import List, Optional, TypeVar T = TypeVar("T") +@enum.unique +class ReductionOperator(enum.Enum): + OP_NULL = enum.auto() + MAX = enum.auto() + MIN = enum.auto() + SUM = enum.auto() + PROD = enum.auto() + LAND = enum.auto() + BAND = enum.auto() + LOR = enum.auto() + BOR = enum.auto() + LXOR = enum.auto() + BXOR = enum.auto() + MAXLOC = enum.auto() + MINLOC = enum.auto() + REPLACE = enum.auto() + NO_OP = enum.auto() + + class Request(abc.ABC): @abc.abstractmethod def wait(self): @@ -69,5 +89,12 @@ def Split(self, color, key) -> "Comm": ... @abc.abstractmethod - def allreduce(self, sendobj: T, op=None) -> T: + def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T: + ... + + @abc.abstractmethod + def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T: + ... + + def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T: ... diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index ff270df5..ba980d19 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -6,6 +6,8 @@ import ndsl.constants as constants from ndsl.buffer import array_buffer, device_synchronize, recv_buffer, send_buffer from ndsl.comm.boundary import Boundary +from ndsl.comm.comm_abc import Comm as CommABC +from ndsl.comm.comm_abc import ReductionOperator from ndsl.comm.partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater from ndsl.performance.timer import NullTimer, Timer @@ -44,7 +46,11 @@ def to_numpy(array, dtype=None) -> np.ndarray: class Communicator(abc.ABC): def __init__( - self, comm, partitioner, force_cpu: bool = False, timer: Optional[Timer] = None + self, + comm: CommABC, + partitioner, + force_cpu: bool = False, + timer: Optional[Timer] = None, ): self.comm = comm self.partitioner: Partitioner = partitioner @@ -61,7 +67,7 @@ def tile(self) -> "TileCommunicator": @abc.abstractmethod def from_layout( cls, - comm, + comm: CommABC, layout: Tuple[int, int], force_cpu: bool = False, timer: Optional[Timer] = None, @@ -93,17 +99,63 @@ def _device_synchronize(): # this is a method so we can profile it separately from other device syncs device_synchronize() + def _create_all_reduce_quantity( + self, input_metadata: QuantityMetadata, input_data + ) -> Quantity: + """Create a Quantity for all_reduce data and metadata""" + all_reduce_quantity = Quantity( + input_data, + dims=input_metadata.dims, + units=input_metadata.units, + origin=input_metadata.origin, + extent=input_metadata.extent, + gt4py_backend=input_metadata.gt4py_backend, + allow_mismatch_float_precision=False, + ) + return all_reduce_quantity + + def all_reduce( + self, + input_quantity: Quantity, + op: ReductionOperator, + output_quantity: Quantity = None, + ): + reduced_quantity_data = self.comm.allreduce(input_quantity.data, op) + if output_quantity is None: + all_reduce_quantity = self._create_all_reduce_quantity( + input_quantity.metadata, reduced_quantity_data + ) + return all_reduce_quantity + else: + if output_quantity.data.shape != input_quantity.data.shape: + raise TypeError("Shapes not matching") + + input_quantity.metadata.duplicate_metadata(output_quantity.metadata) + + output_quantity.data = reduced_quantity_data + + def all_reduce_per_element( + self, + input_quantity: Quantity, + output_quantity: Quantity, + op: ReductionOperator, + ): + self.comm.Allreduce(input_quantity.data, output_quantity.data, op) + + def all_reduce_per_element_in_place( + self, quantity: Quantity, op: ReductionOperator + ): + self.comm.Allreduce_inplace(quantity.data, op) + def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): - with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( - numpy_module.zeros, recvbuf - ) as recv: - self.comm.Scatter(send, recv, **kwargs) + with send_buffer(numpy_module.zeros, sendbuf) as send: + with recv_buffer(numpy_module.zeros, recvbuf) as recv: + self.comm.Scatter(send, recv, **kwargs) def _Gather(self, numpy_module, sendbuf, recvbuf, **kwargs): - with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( - numpy_module.zeros, recvbuf - ) as recv: - self.comm.Gather(send, recv, **kwargs) + with send_buffer(numpy_module.zeros, sendbuf) as send: + with recv_buffer(numpy_module.zeros, recvbuf) as recv: + self.comm.Gather(send, recv, **kwargs) def scatter( self, @@ -709,7 +761,7 @@ class CubedSphereCommunicator(Communicator): def __init__( self, - comm, + comm: CommABC, partitioner: CubedSpherePartitioner, force_cpu: bool = False, timer: Optional[Timer] = None, @@ -722,6 +774,11 @@ def __init__( force_cpu: Force all communication to go through central memory. timer: Time communication operations. """ + if not issubclass(type(comm), CommABC): + raise TypeError( + "Communictor needs to be instantiated with communication subsytem" + f" derived from `comm_abc.Comm`, got {type(comm)}." + ) if comm.Get_size() != partitioner.total_ranks: raise ValueError( f"was given a partitioner for {partitioner.total_ranks} ranks but a " diff --git a/ndsl/comm/local_comm.py b/ndsl/comm/local_comm.py index 5ebfb47d..1ae10177 100644 --- a/ndsl/comm/local_comm.py +++ b/ndsl/comm/local_comm.py @@ -189,8 +189,14 @@ def Split(self, color, key): self._split_comms[color].append(new_comm) return new_comm - def allreduce(self, sendobj, op=None) -> Any: + def allreduce(self, sendobj, op=None, recvobj=None) -> Any: raise NotImplementedError( - "sendrecv fundamentally cannot be written for LocalComm, " + "allreduce fundamentally cannot be written for LocalComm, " + "as it requires synchronicity" + ) + + def Allreduce(self, sendobj, recvobj, op) -> Any: + raise NotImplementedError( + "Allreduce fundamentally cannot be written for LocalComm, " "as it requires synchronicity" ) diff --git a/ndsl/comm/mpi.py b/ndsl/comm/mpi.py index 6f47c791..6b3ff17f 100644 --- a/ndsl/comm/mpi.py +++ b/ndsl/comm/mpi.py @@ -1,10 +1,11 @@ try: + import mpi4py from mpi4py import MPI except ImportError: MPI = None -from typing import List, Optional, TypeVar, cast +from typing import Dict, List, Optional, TypeVar, cast -from ndsl.comm.comm_abc import Comm, Request +from ndsl.comm.comm_abc import Comm, ReductionOperator, Request from ndsl.logging import ndsl_log @@ -12,6 +13,24 @@ class MPIComm(Comm): + _op_mapping: Dict[ReductionOperator, mpi4py.MPI.Op] = { + ReductionOperator.OP_NULL: mpi4py.MPI.OP_NULL, + ReductionOperator.MAX: mpi4py.MPI.MAX, + ReductionOperator.MIN: mpi4py.MPI.MIN, + ReductionOperator.SUM: mpi4py.MPI.SUM, + ReductionOperator.PROD: mpi4py.MPI.PROD, + ReductionOperator.LAND: mpi4py.MPI.LAND, + ReductionOperator.BAND: mpi4py.MPI.BAND, + ReductionOperator.LOR: mpi4py.MPI.LOR, + ReductionOperator.BOR: mpi4py.MPI.BOR, + ReductionOperator.LXOR: mpi4py.MPI.LXOR, + ReductionOperator.BXOR: mpi4py.MPI.BXOR, + ReductionOperator.MAXLOC: mpi4py.MPI.MAXLOC, + ReductionOperator.MINLOC: mpi4py.MPI.MINLOC, + ReductionOperator.REPLACE: mpi4py.MPI.REPLACE, + ReductionOperator.NO_OP: mpi4py.MPI.NO_OP, + } + def __init__(self): if MPI is None: raise RuntimeError("MPI not available") @@ -72,8 +91,22 @@ def Split(self, color, key) -> "Comm": ) return self._comm.Split(color, key) - def allreduce(self, sendobj: T, op=None) -> T: + def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T: ndsl_log.debug( "allreduce on rank %s with operator %s", self._comm.Get_rank(), op ) - return self._comm.allreduce(sendobj, op) + return self._comm.allreduce(sendobj, self._op_mapping[op]) + + def Allreduce(self, sendobj_or_inplace: T, recvobj: T, op: ReductionOperator) -> T: + ndsl_log.debug( + "Allreduce on rank %s with operator %s", self._comm.Get_rank(), op + ) + return self._comm.Allreduce(sendobj_or_inplace, recvobj, self._op_mapping[op]) + + def Allreduce_inplace(self, recvobj: T, op: ReductionOperator) -> T: + ndsl_log.debug( + "Allreduce (in place) on rank %s with operator %s", + self._comm.Get_rank(), + op, + ) + return self._comm.Allreduce(mpi4py.MPI.IN_PLACE, recvobj, self._op_mapping[op]) diff --git a/ndsl/comm/null_comm.py b/ndsl/comm/null_comm.py index 7e0c07fa..5ca92359 100644 --- a/ndsl/comm/null_comm.py +++ b/ndsl/comm/null_comm.py @@ -1,7 +1,7 @@ import copy -from typing import Any, Mapping +from typing import Any, Mapping, Optional -from ndsl.comm.comm_abc import Comm, Request +from ndsl.comm.comm_abc import Comm, ReductionOperator, Request class NullAsyncResult(Request): @@ -91,5 +91,9 @@ def Split(self, color, key): self._split_comms[color].append(new_comm) return new_comm - def allreduce(self, sendobj, op=None) -> Any: + def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any: return self._fill_value + + def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any: + recvobj = sendobj + return recvobj diff --git a/ndsl/quantity.py b/ndsl/quantity.py index b95a9aad..a38a7a5d 100644 --- a/ndsl/quantity.py +++ b/ndsl/quantity.py @@ -53,6 +53,15 @@ def np(self) -> NumpyModule: f"quantity underlying data is of unexpected type {self.data_type}" ) + def duplicate_metadata(self, metadata_copy): + metadata_copy.origin = self.origin + metadata_copy.extent = self.extent + metadata_copy.dims = self.dims + metadata_copy.units = self.units + metadata_copy.data_type = self.data_type + metadata_copy.dtype = self.dtype + metadata_copy.gt4py_backend = self.gt4py_backend + @dataclasses.dataclass class QuantityHaloSpec: @@ -492,6 +501,11 @@ def data(self) -> Union[np.ndarray, cupy.ndarray]: """the underlying array of data""" return self._data + @data.setter + def data(self, inputData): + if type(inputData) in [np.ndarray, cupy.ndarray]: + self._data = inputData + @property def origin(self) -> Tuple[int, ...]: """the start of the computational domain""" diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index f01d17be..2ed22fee 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -13,7 +13,7 @@ CubedSphereCommunicator, TileCommunicator, ) -from ndsl.comm.mpi import MPI +from ndsl.comm.mpi import MPI, MPIComm from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner from ndsl.dsl.dace.dace_config import DaceConfig from ndsl.namelist import Namelist @@ -323,7 +323,7 @@ def compute_grid_data(grid, namelist, backend, layout, topology_mode): npx=namelist.npx, npy=namelist.npy, npz=namelist.npz, - communicator=get_communicator(MPI.COMM_WORLD, layout, topology_mode), + communicator=get_communicator(MPIComm(), layout, topology_mode), backend=backend, ) @@ -377,13 +377,12 @@ def generate_parallel_stencil_tests(metafunc, *, backend: str): metafunc.config ) # get MPI environment - comm = MPI.COMM_WORLD - mpi_rank = comm.Get_rank() + comm = MPIComm() savepoint_cases = parallel_savepoint_cases( metafunc, data_path, namelist_filename, - mpi_rank, + comm.Get_rank(), backend=backend, comm=comm, ) @@ -393,7 +392,7 @@ def generate_parallel_stencil_tests(metafunc, *, backend: str): def get_communicator(comm, layout, topology_mode): - if (MPI.COMM_WORLD.Get_size() > 1) and (topology_mode == "cubed-sphere"): + if (comm.Get_size() > 1) and (topology_mode == "cubed-sphere"): partitioner = CubedSpherePartitioner(TilePartitioner(layout)) communicator = CubedSphereCommunicator(comm, partitioner) else: diff --git a/ndsl/stencils/testing/test_translate.py b/ndsl/stencils/testing/test_translate.py index 0147c040..9f0278d8 100644 --- a/ndsl/stencils/testing/test_translate.py +++ b/ndsl/stencils/testing/test_translate.py @@ -8,7 +8,7 @@ import ndsl.dsl.gt4py_utils as gt_utils from ndsl.comm.communicator import CubedSphereCommunicator, TileCommunicator -from ndsl.comm.mpi import MPI +from ndsl.comm.mpi import MPI, MPIComm from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner from ndsl.dsl.dace.dace_config import DaceConfig from ndsl.dsl.stencil import CompilationConfig, StencilConfig @@ -304,18 +304,19 @@ def test_parallel_savepoint( multimodal_metric, xy_indices=True, ): - if MPI.COMM_WORLD.Get_size() % 6 != 0: + mpi_comm = MPIComm() + if mpi_comm.Get_size() % 6 != 0: layout = ( - int(MPI.COMM_WORLD.Get_size() ** 0.5), - int(MPI.COMM_WORLD.Get_size() ** 0.5), + int(mpi_comm.Get_size() ** 0.5), + int(mpi_comm.Get_size() ** 0.5), ) - communicator = get_tile_communicator(MPI.COMM_WORLD, layout) + communicator = get_tile_communicator(mpi_comm, layout) else: layout = ( - int((MPI.COMM_WORLD.Get_size() // 6) ** 0.5), - int((MPI.COMM_WORLD.Get_size() // 6) ** 0.5), + int((mpi_comm.Get_size() // 6) ** 0.5), + int((mpi_comm.Get_size() // 6) ** 0.5), ) - communicator = get_communicator(MPI.COMM_WORLD, layout) + communicator = get_communicator(mpi_comm, layout) if case.testobj is None: pytest.xfail( f"no translate object available for savepoint {case.savepoint_name}" diff --git a/tests/dsl/test_compilation_config.py b/tests/dsl/test_compilation_config.py index 62049d91..fa323b06 100644 --- a/tests/dsl/test_compilation_config.py +++ b/tests/dsl/test_compilation_config.py @@ -7,6 +7,7 @@ CompilationConfig, CubedSphereCommunicator, CubedSpherePartitioner, + NullComm, RunMode, TilePartitioner, ) @@ -33,8 +34,7 @@ def test_check_communicator_valid( partitioner = CubedSpherePartitioner( TilePartitioner((int(sqrt(size / 6)), int((sqrt(size / 6))))) ) - comm = unittest.mock.MagicMock() - comm.Get_size.return_value = size + comm = NullComm(rank=0, total_ranks=size) cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner) config = CompilationConfig( run_mode=run_mode, use_minimal_caching=use_minimal_caching @@ -52,8 +52,7 @@ def test_check_communicator_invalid( nx: int, ny: int, use_minimal_caching: bool, run_mode: RunMode ): partitioner = CubedSpherePartitioner(TilePartitioner((nx, ny))) - comm = unittest.mock.MagicMock() - comm.Get_size.return_value = nx * ny * 6 + comm = NullComm(rank=0, total_ranks=nx * ny * 6) cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner) config = CompilationConfig( run_mode=run_mode, use_minimal_caching=use_minimal_caching @@ -91,9 +90,7 @@ def test_get_decomposition_info_from_comm( partitioner = CubedSpherePartitioner( TilePartitioner((int(sqrt(size / 6)), int(sqrt(size / 6)))) ) - comm = unittest.mock.MagicMock() - comm.Get_rank.return_value = rank - comm.Get_size.return_value = size + comm = NullComm(rank=rank, total_ranks=size) cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner) config = CompilationConfig(use_minimal_caching=True, run_mode=RunMode.Run) ( @@ -133,8 +130,7 @@ def test_determine_compiling_equivalent( TilePartitioner((sqrt(size / 6), sqrt(size / 6))) ) comm = unittest.mock.MagicMock() - comm.Get_rank.return_value = rank - comm.Get_size.return_value = size + comm = NullComm(rank=rank, total_ranks=size) cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner) assert ( config.determine_compiling_equivalent(rank, cubed_sphere_comm.partitioner) diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py new file mode 100644 index 00000000..bec096dd --- /dev/null +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -0,0 +1,149 @@ +import numpy as np +import pytest + +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + Quantity, + TilePartitioner, +) +from ndsl.comm.comm_abc import ReductionOperator +from ndsl.comm.mpi import MPIComm +from ndsl.dsl.typing import Float +from tests.mpi.mpi_comm import MPI + + +@pytest.fixture +def layout(): + if MPI is not None: + size = MPI.COMM_WORLD.Get_size() + ranks_per_tile = size // 6 + ranks_per_edge = int(ranks_per_tile ** 0.5) + return (ranks_per_edge, ranks_per_edge) + else: + return (1, 1) + + +@pytest.fixture(params=[0.1, 1.0]) +def edge_interior_ratio(request): + return request.param + + +@pytest.fixture +def tile_partitioner(layout, edge_interior_ratio: float): + return TilePartitioner(layout, edge_interior_ratio=edge_interior_ratio) + + +@pytest.fixture +def cube_partitioner(tile_partitioner): + return CubedSpherePartitioner(tile_partitioner) + + +@pytest.fixture() +def communicator(cube_partitioner): + return CubedSphereCommunicator( + comm=MPIComm(), + partitioner=cube_partitioner, + ) + + +@pytest.mark.skipif( + MPI is None, reason="mpi4py is not available or pytest was not run in parallel" +) +def test_all_reduce(communicator): + backends = ["dace:cpu", "gt:cpu_kfirst", "numpy"] + + for backend in backends: + base_array = np.array([i for i in range(5)], dtype=Float) + + testQuantity_1D = Quantity( + data=base_array, + dims=["K"], + units="Some 1D unit", + gt4py_backend=backend, + ) + + base_array = np.array([i for i in range(5 * 5)], dtype=Float) + base_array = base_array.reshape(5, 5) + + testQuantity_2D = Quantity( + data=base_array, + dims=["I", "J"], + units="Some 2D unit", + gt4py_backend=backend, + ) + + base_array = np.array([i for i in range(5 * 5 * 5)], dtype=Float) + base_array = base_array.reshape(5, 5, 5) + + testQuantity_3D = Quantity( + data=base_array, + dims=["I", "J", "K"], + units="Some 3D unit", + gt4py_backend=backend, + ) + + global_sum_q = communicator.all_reduce(testQuantity_1D, ReductionOperator.SUM) + assert global_sum_q.metadata == testQuantity_1D.metadata + assert (global_sum_q.data == (testQuantity_1D.data * communicator.size)).all() + + global_sum_q = communicator.all_reduce(testQuantity_2D, ReductionOperator.SUM) + assert global_sum_q.metadata == testQuantity_2D.metadata + assert (global_sum_q.data == (testQuantity_2D.data * communicator.size)).all() + + global_sum_q = communicator.all_reduce(testQuantity_3D, ReductionOperator.SUM) + assert global_sum_q.metadata == testQuantity_3D.metadata + assert (global_sum_q.data == (testQuantity_3D.data * communicator.size)).all() + + base_array = np.array([i for i in range(5)], dtype=Float) + testQuantity_1D_out = Quantity( + data=base_array, + dims=["K"], + units="New 1D unit", + gt4py_backend=backend, + origin=(8,), + extent=(7,), + ) + + base_array = np.array([i for i in range(5 * 5)], dtype=Float) + base_array = base_array.reshape(5, 5) + + testQuantity_2D_out = Quantity( + data=base_array, + dims=["I", "J"], + units="Some 2D unit", + gt4py_backend=backend, + ) + + base_array = np.array([i for i in range(5 * 5 * 5)], dtype=Float) + base_array = base_array.reshape(5, 5, 5) + + testQuantity_3D_out = Quantity( + data=base_array, + dims=["I", "J", "K"], + units="Some 3D unit", + gt4py_backend=backend, + ) + communicator.all_reduce( + testQuantity_1D, ReductionOperator.SUM, testQuantity_1D_out + ) + assert testQuantity_1D_out.metadata == testQuantity_1D.metadata + assert ( + testQuantity_1D_out.data == (testQuantity_1D.data * communicator.size) + ).all() + + communicator.all_reduce( + testQuantity_2D, ReductionOperator.SUM, testQuantity_2D_out + ) + assert testQuantity_2D_out.metadata == testQuantity_2D.metadata + assert ( + testQuantity_2D_out.data == (testQuantity_2D.data * communicator.size) + ).all() + + communicator.all_reduce( + testQuantity_3D, ReductionOperator.SUM, testQuantity_3D_out + ) + assert testQuantity_3D_out.metadata == testQuantity_3D.metadata + assert ( + testQuantity_3D_out.data == (testQuantity_3D.data * communicator.size) + ).all() diff --git a/tests/mpi/test_mpi_halo_update.py b/tests/mpi/test_mpi_halo_update.py index ab11b16e..b6c38e95 100644 --- a/tests/mpi/test_mpi_halo_update.py +++ b/tests/mpi/test_mpi_halo_update.py @@ -9,6 +9,7 @@ TilePartitioner, ) from ndsl.comm._boundary_utils import get_boundary_slice +from ndsl.comm.mpi import MPIComm from ndsl.constants import ( BOUNDARY_TYPES, EDGE_BOUNDARY_TYPES, @@ -176,7 +177,7 @@ def extent(n_points, dims, nz, ny, nx): @pytest.fixture() def communicator(cube_partitioner): return CubedSphereCommunicator( - comm=MPI.COMM_WORLD, + comm=MPIComm(), partitioner=cube_partitioner, )