def swing(
    comm_size: int,
    datasize: int,
    tag: int,
    algorithm: str = "reduce-scatter",
    compute_time_dependency: int = 0,
    **kwargs,
) -> GoalComm:
    """
    Create a Swing communication pattern.

    The pattern runs ``log2(comm_size)`` steps. In every step each rank sends
    one message to a peer and posts the matching receive from that same peer.
    The peer of ``rank`` in Swing step ``s`` is ``rank + rho(s)`` for even
    ranks and ``rank - rho(s)`` for odd ranks (modulo ``comm_size``), where
    ``rho(s) = sum((-2) ** i for i in range(s + 1))``, i.e. 1, -1, 3, -5, 11...
    The message exchanged in Swing step ``s`` has size
    ``datasize // 2 ** (s + 1)``.

    ``reduce-scatter`` walks the Swing steps in increasing order (messages
    shrink); ``allgather`` walks them in decreasing order (messages grow).
    The i-th executed step uses ``tag + i`` for both its send and its receive.

    :param comm_size: number of ranks in the communicator; must be a power of 2
    :param datasize: size of data to send or receive
    :param tag: tag that is used for all send and receive operations
    :param algorithm: communication algorithm that uses this pattern; default is reduce-scatter
    :param compute_time_dependency: compute time dependency for each send operation; if 0 (default), no compute time is added
    :param kwargs: additional arguments that are ignored
    :return: GoalComm object that represents the communication pattern
    :raises AssertionError: if ``algorithm`` is not supported (assertions enabled)
    :raises ValueError: if ``comm_size`` is not a power of 2, or if
        ``algorithm`` is not supported and assertions are disabled
    """

    assert algorithm in [
        "reduce-scatter",
        "allgather",
    ], f"the pattern does not currently support the {algorithm} algorithm"

    # Exact integer power-of-two test: avoids float rounding in log2 for very
    # large values and rejects zero/negative sizes with the same error.
    if comm_size < 1 or comm_size & (comm_size - 1):
        raise ValueError("At the moment, Swing only support a number of ranks which is a power of 2")

    comm = GoalComm(comm_size)
    num_steps = comm_size.bit_length() - 1
    # Last operation each rank must wait for before issuing its next send.
    dependencies = [None] * comm_size

    # The algorithm only decides the order in which Swing steps are replayed,
    # so resolve it once instead of once per rank and step.
    if algorithm == "reduce-scatter":
        swing_steps = range(num_steps)
    elif algorithm == "allgather":
        swing_steps = range(num_steps - 1, -1, -1)
    else:
        # Only reachable when assertions are stripped (python -O).
        raise ValueError(
            f"the pattern does not currently support the {algorithm} algorithm"
        )

    for r, swing_step in enumerate(swing_steps):
        # Closed form of sum((-2) ** i for i in range(swing_step + 1)); the
        # division is exact because (-2) ** k is congruent to 1 modulo 3.
        rho = (1 - (-2) ** (swing_step + 1)) // 3
        message_size = datasize // (2 ** (swing_step + 1))
        step_tag = tag + r

        for rank in range(comm_size):
            # Flip the direction for odd ranks
            distance = -rho if rank % 2 else rho
            # Python's % always yields a value in [0, comm_size), so negative
            # offsets wrap around without a separate branch.
            dest = (rank + distance) % comm_size

            send = comm.Send(size=message_size, src=rank, dst=dest, tag=step_tag)
            if dependencies[rank] is not None:
                send.requires(dependencies[rank])
            dependencies[rank] = comm.Recv(
                size=message_size, src=dest, dst=rank, tag=step_tag
            )
            if compute_time_dependency > 0:
                # Model the local work done on the received data; the next
                # send of this rank then waits for the computation instead.
                calc = comm.Calc(host=rank, size=compute_time_dependency)
                calc.requires(dependencies[rank])
                dependencies[rank] = calc
    return comm