[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Nov 27, 2024
1 parent aa2bfce commit 09e2316
Showing 2 changed files with 16 additions and 25 deletions.
35 changes: 13 additions & 22 deletions tests/jax/test_distributed_gemm.py
@@ -18,7 +18,7 @@
from utils import assert_allclose


-jax.config.update('jax_enable_compilation_cache', False)
+jax.config.update("jax_enable_compilation_cache", False)


# AG+GEMM: (4, 32/P, 128) ----(AG)----> (4, 32, 128) x (128, 256/P) ----------> (4, 32, 256/P)
@@ -48,20 +48,18 @@ def _get_mesh(parallel_dist):
batched = False
fsdp = False
mesh_shape = dict(tp=NUM_DEVICES)
-resources = dict(cp_resource='tp', tp_resource='tp')
+resources = dict(cp_resource="tp", tp_resource="tp")
if parallel_dist in ["DP_TP", "FSDP_TP"]:
batched = True
-mesh_shape.update(dict(tp=NUM_DEVICES//2, dp=NUM_DEVICES//2))
-resources.update(dict(dp_resource='dp'))
+mesh_shape.update(dict(tp=NUM_DEVICES // 2, dp=NUM_DEVICES // 2))
+resources.update(dict(dp_resource="dp"))
if parallel_dist == "FSDP_TP":
fsdp = True
-mesh_shape.update(dict(tp=NUM_DEVICES//2, dp=1, zp=NUM_DEVICES//2))
-resources.update(dict(fsdp_resource='zp'))
+mesh_shape.update(dict(tp=NUM_DEVICES // 2, dp=1, zp=NUM_DEVICES // 2))
+resources.update(dict(fsdp_resource="zp"))
mesh_resource = te.MeshResource(**resources)

-devices = mesh_utils.create_device_mesh(
-    (NUM_DEVICES, ), devices=jax.devices()[:NUM_DEVICES]
-)
+devices = mesh_utils.create_device_mesh((NUM_DEVICES,), devices=jax.devices()[:NUM_DEVICES])

mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys()))

@@ -73,9 +71,7 @@ def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bw

# Operand and output shapes
lhs_shape = (
-    [SEQ_LEN, HIDDEN_SIZE]
-    if fwd_comm_type == "ALL_GATHER"
-    else [SEQ_LEN, FFN_HIDDEN_SIZE]
+    [SEQ_LEN, HIDDEN_SIZE] if fwd_comm_type == "ALL_GATHER" else [SEQ_LEN, FFN_HIDDEN_SIZE]
)
rhs_shape = (
[HIDDEN_SIZE, FFN_HIDDEN_SIZE]
@@ -125,12 +121,12 @@ def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bw
sigma = 0.023
shapes = (lhs_shape, rhs_shape)
if fwd_bwd:
-shapes += (out_shape, )
+shapes += (out_shape,)
global_operands = list(
map(
lambda key, shape: jax.device_put(
mu + (sigma * jax.random.normal(key, shape, dtype=dtype)),
-    NamedSharding(mesh, PartitionSpec(None))
+    NamedSharding(mesh, PartitionSpec(None)),
),
split_keys,
shapes,
@@ -140,7 +136,7 @@ def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bw
# Allocate sharded operands on device
partition_axes = (lhs_spec, rhs_spec)
if fwd_bwd:
-partition_axes += (out_spec, )
+partition_axes += (out_spec,)
local_operands = list(
map(
lambda x, spec: jax.device_put(x, NamedSharding(mesh, PartitionSpec(*spec))),
@@ -245,9 +241,7 @@ def test_gemm_impl(comm_type, mesh_type):
global_operands,
output_info,
fsdp_gathered_rhs_spec,
-) = _get_inputs(
-    mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp
-)
+) = _get_inputs(mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp)

@jax.jit
def _test_fn(lhs, rhs):
@@ -272,9 +266,7 @@ def test_gemm_fwd_bwd(comm_type, mesh_type):
global_operands,
output_info,
fsdp_gathered_rhs_spec,
-) = _get_inputs(
-    mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp, fwd_bwd=True
-)
+) = _get_inputs(mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp, fwd_bwd=True)

@jax.jit
def _test_fn(lhs, rhs, grad):
@@ -308,4 +300,3 @@ def _test_fn(lhs, rhs, grad):
)

_check_output(mesh, *output_info, *global_operands, output, dgrad, wgrad, fwd_bwd=True)
-
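
For context on the lines reformatted above: _get_mesh builds a device mesh over a tensor-parallel ("tp") axis, and _get_inputs stages replicated global operands with NamedSharding(mesh, PartitionSpec(None)) before resharding them. Below is a minimal sketch of that pattern; NUM_DEVICES, SEQ_LEN, and HIDDEN_SIZE are illustrative values rather than the test's constants, and the example assumes the sharded dimension divides evenly across devices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

NUM_DEVICES = len(jax.devices())  # illustrative; substitute the device count under test
SEQ_LEN, HIDDEN_SIZE = 32, 128    # illustrative shapes

# 1D device mesh with a single tensor-parallel ("tp") axis, as in the pure-TP branch.
devices = mesh_utils.create_device_mesh((NUM_DEVICES,), devices=jax.devices()[:NUM_DEVICES])
mesh = Mesh(np.array(devices).reshape((NUM_DEVICES,)), ("tp",))

# PartitionSpec(None) replicates the array on every device in the mesh,
# which is how the global operands are staged before being resharded.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (SEQ_LEN, HIDDEN_SIZE), dtype=jnp.bfloat16)
x_replicated = jax.device_put(x, NamedSharding(mesh, PartitionSpec(None)))

# Sharding the sequence dimension over "tp" mirrors the (SEQ_LEN/P, HIDDEN_SIZE)
# LHS layout that the AG+GEMM path all-gathers before the matmul.
x_local = jax.device_put(x, NamedSharding(mesh, PartitionSpec("tp", None)))
```

The DP_TP and FSDP_TP branches follow the same recipe, only with the device array reshaped to a 2D mesh and the extra named axes (dp, zp) mapped through te.MeshResource.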
6 changes: 3 additions & 3 deletions transformer_engine/jax/cpp_extensions/gemm.py
@@ -525,9 +525,9 @@ def infer_sharding_from_operands(
rhs_spec_new = list(rhs_spec).copy()
lhs_spec_new[lhs_outer_dim] = None
if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is not None:
-assert lhs_spec_new[lhs_inner_dim] == rhs_spec_new[rhs_inner_dim], (
-    "Contracting dimensions of LHS and RHS operands must have the same sharding."
-)
+assert (
+    lhs_spec_new[lhs_inner_dim] == rhs_spec_new[rhs_inner_dim]
+), "Contracting dimensions of LHS and RHS operands must have the same sharding."
if lhs_spec_new[lhs_outer_dim] is not None:
warnings.warn(
"Outer dimension of the LHS operand must be all-gathered when both contracting "