diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index db6ca75ad..cb98c7067 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -533,6 +533,8 @@ enable_single_controller: False custom_mesh: "" # Available options: ['hybrid_ring_64x4', 'hybrid_ring_32x8'] # Split physical axes for https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.mesh_utils.create_device_mesh.html allow_split_physical_axes: False +# Apply transformations to the mesh to optimize for TPU v6e +optimize_mesh_for_tpu_v6e: False use_ragged_attention: False ragged_block_size: 256 diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index fb41b8347..de12062ad 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -563,6 +563,30 @@ def is_valid_custom_mesh(ici_parallelism, strategy): raise ValueError(f"The strategy {strategy} to reshape the mesh is invalid.") +def optimize_mesh_for_tpu_v6e(mesh, devices): + """Apply transformations to the mesh to optimize for TPU v6e""" + if devices[0].device_kind != "TPU v6 lite": + return mesh + num_devices = len(devices) + mesh_is_1d_ring = num_devices in mesh.shape + if not mesh_is_1d_ring: + return mesh + # check that the physical topology is 2x4 + device_coords = [d.coords for d in devices] + coord_size = len(device_coords[0]) + max_coords = tuple(max(dc[i] for dc in device_coords) for i in range(coord_size)) + min_coords = tuple(min(dc[i] for dc in device_coords) for i in range(coord_size)) + dims = tuple(h - l + 1 for (h, l) in zip(max_coords, min_coords)) + if dims != (2, 4, 1): + return mesh + axis_idx = mesh.shape.index(num_devices) + new_mesh = np.moveaxis(mesh, axis_idx, 0) + new_mesh[4:] = new_mesh[-1:3:-1] + new_mesh = np.moveaxis(new_mesh, 0, axis_idx) + max_logging.log("Optimized the mesh for TPU v6e") + return new_mesh + + def create_device_mesh(config, devices=None): """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas""" if devices is None: @@ -612,6 +636,8 @@ def create_device_mesh(config, devices=None): ici_parallelism, devices, ) + if config.optimize_mesh_for_tpu_v6e: + mesh = optimize_mesh_for_tpu_v6e(mesh, devices) max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")