Commit f8c40ad
fix bias[None, :] in tp's functional
xrsrke committed Jan 14, 2025
1 parent 21b2408 · commit f8c40ad
Showing 1 changed file with 8 additions and 3 deletions.
src/nanotron/parallel/tensor_parallel/functional.py (11 changes: 8 additions & 3 deletions)
@@ -204,7 +204,10 @@ def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather):
             )
         else:
             torch.addmm(
-                input=bias[None, :],
+                # NOTE(xrsrke): if we keep bias[None, :], then we get
+                # RuntimeError: Attempted to make a tensor into a differentiable view,
+                # but the tensor already had autograd metadata associated with it
+                input=bias.view(1, -1),
                 mat1=tensor.view(first_dims, hidden_size),
                 mat2=weight.t(),
                 out=same_device_shard.view(first_dims, output_size),

@@ -236,7 +239,8 @@ def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather):
             )
         else:
             torch.addmm(
-                input=bias[None, :],
+                # input=bias[None, :],
+                input=bias.view(1, -1),
                 mat1=gathered_tensor[: sharded_batch_size * current_rank].view(first_dims, hidden_size),
                 mat2=weight.t(),
                 out=before_shard.view(first_dims, output_size),

@@ -253,7 +257,8 @@ def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather):
             )
         else:
             torch.addmm(
-                input=bias[None, :],
+                # input=bias[None, :],
+                input=bias.view(1, -1),
                 mat1=gathered_tensor[sharded_batch_size * (current_rank + 1) :].view(
                     first_dims, hidden_size
                 ),
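
For context, here is a minimal standalone sketch of the addmm-into-a-preallocated-buffer pattern the patch touches. The shapes and the out_buffer name are made up for illustration; only the torch.addmm call mirrors the code above. bias.view(1, -1) yields a (1, output_size) row that addmm broadcasts over the batch dimension, so the buffer ends up holding tensor @ weight.t() + bias.

import torch

# Hypothetical shapes for illustration; the variable names follow the diff above.
first_dims, hidden_size, output_size = 8, 16, 32

bias = torch.randn(output_size)
weight = torch.randn(output_size, hidden_size)
tensor = torch.randn(first_dims, hidden_size)
out_buffer = torch.empty(first_dims, output_size)  # stand-in for the preallocated shard

# No grad tracking here, so writing into a preallocated out= buffer is allowed.
with torch.no_grad():
    torch.addmm(
        input=bias.view(1, -1),                     # (1, output_size), broadcast over rows
        mat1=tensor.view(first_dims, hidden_size),  # (first_dims, hidden_size)
        mat2=weight.t(),                            # (hidden_size, output_size)
        out=out_buffer,
    )

# The buffer now holds the biased matmul result.
assert torch.allclose(out_buffer, tensor @ weight.t() + bias)

In plain eager code both bias[None, :] and bias.view(1, -1) produce the same (1, output_size) view; the commit switches to .view(1, -1) because, per the NOTE above, the indexed form triggered the autograd-metadata RuntimeError inside this forward.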
