Skip to content

Commit

Permalink
Merge branch 'main' into parambole/jsts_gpu_pp
Browse files Browse the repository at this point in the history
  • Loading branch information
parambole authored Feb 7, 2025
2 parents 51ae39d + 3c10501 commit ea6fe73
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
2 changes: 2 additions & 0 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,8 @@ enable_single_controller: False
custom_mesh: "" # Available options: ['hybrid_ring_64x4', 'hybrid_ring_32x8']
# Split physical axes for https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.mesh_utils.create_device_mesh.html
allow_split_physical_axes: False
# Reorder the devices of a 1D-ring mesh to optimize for TPU v6e (only takes effect on a 2x4 physical topology)
optimize_mesh_for_tpu_v6e: False

use_ragged_attention: False
ragged_block_size: 256
Expand Down
26 changes: 26 additions & 0 deletions MaxText/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,30 @@ def is_valid_custom_mesh(ici_parallelism, strategy):
raise ValueError(f"The strategy {strategy} to reshape the mesh is invalid.")


def optimize_mesh_for_tpu_v6e(mesh, devices):
  """Reorder a 1D-ring device mesh for a 2x4 TPU v6e topology.

  The mesh is returned untouched unless all of the following hold:
    * the first device reports ``device_kind == "TPU v6 lite"``;
    * one mesh axis spans every device (the mesh is a single 1D ring);
    * the bounding box of the device coordinates is exactly ``(2, 4, 1)``.

  When they do, the second half of the ring axis (positions 4 and up) is
  reversed. NOTE(review): presumably this makes ring neighbours physically
  adjacent on the 2x4 torus -- confirm against the TPU v6e topology docs.

  Args:
    mesh: numpy array of devices, as produced by the mesh-creation helpers.
    devices: sequence of devices exposing ``device_kind`` and ``coords``.

  Returns:
    ``mesh`` itself when no optimization applies; otherwise a new array of
    the same shape with the reordered devices. The input ``mesh`` is never
    modified.
  """
  num_devices = len(devices)
  # Nothing to inspect (and devices[0] would raise) when no devices are given.
  if num_devices == 0 or devices[0].device_kind != "TPU v6 lite":
    return mesh
  # A 1D ring is a mesh in which a single axis holds all of the devices.
  if num_devices not in mesh.shape:
    return mesh
  # Check that the physical topology is 2x4 by measuring the extent of the
  # device coordinates along every dimension.
  device_coords = [d.coords for d in devices]
  dims = tuple(max(axis) - min(axis) + 1 for axis in zip(*device_coords))
  if dims != (2, 4, 1):
    return mesh
  axis_idx = mesh.shape.index(num_devices)
  # Work on a copy: np.moveaxis returns a view, so assigning through it would
  # otherwise reorder the caller's array in place.
  optimized = np.array(mesh)
  ring = np.moveaxis(optimized, axis_idx, 0)
  # Source and destination slices overlap, so materialize the reversed half
  # before writing it back.
  ring[4:] = ring[-1:3:-1].copy()
  max_logging.log("Optimized the mesh for TPU v6e")
  return optimized


def create_device_mesh(config, devices=None):
"""Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas"""
if devices is None:
Expand Down Expand Up @@ -612,6 +636,8 @@ def create_device_mesh(config, devices=None):
ici_parallelism,
devices,
)
if config.optimize_mesh_for_tpu_v6e:
mesh = optimize_mesh_for_tpu_v6e(mesh, devices)

max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")

Expand Down

0 comments on commit ea6fe73

Please sign in to comment.