[PyTorch/C++] Comm+GEMM overlap compatibility with QuantizedTensor #1427
base: release_v2.0
Conversation
Commits (all signed off by Alp Dener <[email protected]>):
- …th cppqtensor
- CommOverlap objects can now return overlap buffers to PyTorch as QuantizedTensors
- updated comm+GEMM overlap test for pure GEMM, both BF16 and FP8 working with QuantizedTensor
- te.Linear and te.LayerNormMLP updated for TP overlap w/ QuantizedTensor. All overlaps work in BF16. All overlaps except bulk WGRAD work in FP8.
- completed TP overlap QuantizedTensor updates for LayerNormLinear, but issues with quantized normalization
- all overlaps working with bf16, all but bulk WGRAD working with FP8
- all overlaps work with Float8Tensor, except bulk wgrad in LayerNormMLP (works in other modules)
- all overlaps working with QuantizedTensor in BF16 and FP8
- cleaned up pytest formatting
Force-pushed from 9ba5009 to f1dcf35. Additional commits:
- for more information, see https://pre-commit.ci
- …and updated test sizing Signed-off-by: Alp Dener <[email protected]>
# Configure quantizer for normalization output
if fp8 and input_quantizer is None:
    raise ValueError("Missing quantizer for input tensor")
if fp8:
    if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and isinstance(
        FP8GlobalStateManager.get_fp8_recipe(), BlockScaling
    ):
        raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")

    if input_quantizer is None:
        raise ValueError("Missing quantizer for input tensor")
- Putting UB logic here makes the comment incorrect
- This won't generalize when we add more quantization schemes. Instead of assuming that all recipes except MXFP8 support UB, we should only assume FP8 delayed scaling supports UB.
Suggested change:

# Check if overlapped communication is supported
if (
    fp8
    and (ub_overlap_ag_fprop or ub_overlap_rs_fprop)
    and not FP8GlobalStateManager.get_fp8_recipe().delayed()
):
    raise NotImplementedError("Comm+GEMM overlap is only supported with FP8 delayed scaling")

# Configure quantizer for normalization output
if fp8:
    if input_quantizer is None:
        raise ValueError("Missing quantizer for input tensor")
    and isinstance(FP8GlobalStateManager.get_fp8_recipe(), BlockScaling)
):
    raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
Suggested change:

    and not FP8GlobalStateManager.get_fp8_recipe().delayed()
):
    raise NotImplementedError("Comm+GEMM overlap is only supported with FP8 delayed scaling")
ub_obj_fprop = None
ln_out = None
if ub_overlap_ag_fprop:
    ub_obj_fprop = get_ub(ub_name + "_fprop")
    ln_out = ub_obj_fprop.get_buffer(input_quantizer, True)
elif with_quantized_norm:
    ln_out = input_quantizer.make_empty(inputmat.shape, dtype=inputmat.dtype, device="cuda")
else:
    ln_out = torch.empty_like(
        inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format
    )

# Apply normalization
ln_out, mu, rsigma = apply_normalization(
_, mu, rsigma = apply_normalization(
    inputmat,
    None,
    ln_out,
We prefer constructing tensors in C++ to reduce CPU overheads:
Suggested change:

ub_obj_fprop = None
ln_out = None
if ub_overlap_ag_fprop:
    ub_obj_fprop = get_ub(ub_name + "_fprop")
    ln_out = ub_obj_fprop.get_buffer(input_quantizer, True)

# Apply normalization
ln_out, mu, rsigma = apply_normalization(
    inputmat,
    ln_out,
Constructing this in PyTorch is how it was before. Switching it to None so that it is constructed in C++ triggers an error:

RuntimeError: /home/adener/devroot/nvte-internal/transformer_engine/common/normalization/common.h:197 in function getKernel: Unavailable kernel for this normalization config.

This happens with TP overlap turned off as well. Retaining the same PyTorch initialization as before works correctly.
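For reference, a minimal sketch of the retained PyTorch-side allocation (mirroring the elif/else branches kept in the diff above; with_quantized_norm, input_quantizer, and inputmat are the surrounding module variables, so this is illustrative rather than standalone):

# Keep allocating ln_out on the PyTorch side; passing None and letting the C++ side
# construct it hits the "Unavailable kernel for this normalization config" error.
if with_quantized_norm:
    ln_out = input_quantizer.make_empty(inputmat.shape, dtype=inputmat.dtype, device="cuda")
else:
    ln_out = torch.empty_like(
        inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format
    )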
" MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ", | ||
], | ||
) | ||
def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out): |
Why are we removing these tests?
Atomic GEMM is deprecated in Blackwell and onward, and split GEMM + CUDA Graphs is more performant on Hopper. cuBlasMp does not support it either, so the functionality will disappear for good in TE when we eventually deprecate Userbuffers. So I removed the tests in this PR in the interest of having one less thing to maintain that no longer has any use cases.
@@ -261,6 +261,29 @@ def _create_transpose(self):
        self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose)
        self._transpose_invalid = False

    def _fix_gathered_transpose(self, tp_size=1, from_rowwise=False):
I don't like how unintuitive and UB-specific this function is. It would have been better to implement CommOverlap::get_buffer so that it internally transposes the UB buffer and constructs a Float8Tensor with only column-wise data (instead of how this moves the row-wise data to the column-wise data). We're not quite there yet, since Float8Tensors with only column-wise data experience problems.
For now, we should delete this function and reimplement the logic in the linear modules. We should make clear that this is a hacky workaround and not allow it to become an API we rely on.
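For illustration only, a rough sketch of what the inlined, clearly-marked workaround in a linear module could look like. The attribute names (_data, _transpose, _fp8_dtype, _transpose_invalid) are assumed to match the Float8Tensor fields visible in the diff above, and tex is the extension module already imported in that file; the real gather/slice handling is omitted:

# HACK: rebuild the column-wise (transposed) FP8 data from the row-wise data the
# UB all-gather produced, until get_buffer can return a Float8Tensor carrying only
# column-wise data. Do not let this become an API we rely on.
def _ub_gathered_transpose_workaround(ub_out):
    ub_out._transpose = tex.fp8_transpose(
        ub_out._data, ub_out._fp8_dtype, out=ub_out._transpose
    )
    ub_out._transpose_invalid = False
    return ub_out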
Overall looks reasonable, although I have some stylistic suggestions. This is fine in our last-minute scramble to restore UB support with FP8. Next we will need to think about extending it to support MXFP8 and other quantization schemes.
#define NOT_IMPLEMENTED_ERROR() NVTE_ERROR("Operation is not implemented.")

#define NOT_SUPPORTED_ERROR() NVTE_ERROR("Operation not supported.")
Nit: I don't think these macros are giving us much benefit. We save 20 characters but add another level of indirection.
Suggested change: delete both #define lines.
if ub_type == tex.CommOverlapType.AG and ub.is_p2p_overlap():
    out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
Why do we have a separate path for this UB case? The GEMM invocation seems identical and the swizzling is only relevant for the MXFP8 case.
Not critical right now, but I think we could clean up this entire section:
# Swizzle scaling factors if needed
is_swizzle_disabled = ub_type == tex.CommOverlapType.AG and ub.is_p2p_overlap()
original_scale_inverses = None
if not is_swizzle_disabled:
    original_scale_inverses = swizzle_inputs(A, B, layout)

# Perform GEMM
out, bias_grad, gelu_input, extra_output = tex.generic_gemm(
    A,
    transa,
    B,
    transb,
    ...
)

# Reset swizzled scaling factors
if not is_swizzle_disabled:
    reset_swizzled_inputs(A, B, original_scale_inverses)

return out, bias_grad, gelu_input, extra_output
# Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
inputmat = inp
inputmat_total = None
with_input_all_gather = parallel_mode == "column" and sequence_parallel
with_input_all_gather_nccl = (
    parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
)
own_quantized_input = False
if fp8:
    if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and isinstance(
        FP8GlobalStateManager.get_fp8_recipe(), BlockScaling
    ):
        raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
Same as #1427 (comment):
Suggested change:

# Check if overlapped communication is supported
if (
    fp8
    and (ub_overlap_ag_fprop or ub_overlap_rs_fprop)
    and not FP8GlobalStateManager.get_fp8_recipe().delayed()
):
    raise NotImplementedError("Comm+GEMM overlap is only supported with FP8 delayed scaling")

# Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
inputmat = inp
inputmat_total = None
with_input_all_gather_nccl = (
    parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
)
own_quantized_input = False
if fp8:
    and isinstance(FP8GlobalStateManager.get_fp8_recipe(), BlockScaling)
):
    recipe = FP8GlobalStateManager.get_fp8_recipe()
    print(f"FP8 Recipe: {type(recipe)} -> {recipe}")
    raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
Suggested change:

    and not FP8GlobalStateManager.get_fp8_recipe().delayed()
):
    raise NotImplementedError("Comm+GEMM overlap is only supported with FP8 delayed scaling")
# if opts.fp8:
#     dist_print("GEMM1 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True)
#     fp8_meta_info = (
#         f"amax_reference = {fp8_meta.amax_history[1][:3].tolist()}\n"
#         + f"amax_history = {fp8_meta.amax_history[0][:3].tolist()}\n"
#         + f"scale = {fp8_meta.scale[:3].tolist()}\n"
#         + f"scale_inv = {fp8_meta.scale_inv[:3].tolist()}"
#     )
#     dist_print(fp8_meta_info, src=0, group=tp_group)
#     if ub_obj2 is not None:
#         dist_print("GEMM2 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True)
#         fp8_meta_info = (
#             f"amax_reference = {fp8_meta.amax_history[1][3:].tolist()}\n"
#             + f"amax_history = {fp8_meta.amax_history[0][3:].tolist()}\n"
#             + f"scale = {fp8_meta.scale[3:].tolist()}\n"
#             + f"scale_inv = {fp8_meta.scale_inv[3:].tolist()}"
#         )
#         dist_print(fp8_meta_info, src=0, group=tp_group)
Debugging code:
Suggested change: delete this commented-out debugging code.
if any([ub_overlap_ag, ub_overlap_rs]) and isinstance(
    FP8GlobalStateManager.get_fp8_recipe(), BlockScaling
):
    raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
Suggested change:

if (
    (ub_overlap_ag or ub_overlap_rs)
    and not FP8GlobalStateManager.get_fp8_recipe().delayed()
):
    raise NotImplementedError("Comm+GEMM overlap is only supported with FP8 delayed scaling")
ub_obj_lnout = None
ln_out = None
if ub_overlap_ag:
    ub_obj_lnout = get_ub("fc1_fprop")
    ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, True)
elif with_quantized_norm:
    ln_out = fc1_input_quantizer.make_empty(
        inputmat.shape, dtype=inputmat.dtype, device="cuda"
    )
else:
    ln_out = torch.empty_like(
        inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format
    )

# Apply normalization
ln_out, mu, rsigma = apply_normalization(
_, mu, rsigma = apply_normalization(
    inputmat,
    None,
    ln_out,
Same as #1427 (comment):
Suggested change:

ub_obj_lnout = None
ln_out = None
if ub_overlap_ag:
    ub_obj_lnout = get_ub("fc1_fprop")
    ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, True)

# Apply normalization
ln_out, mu, rsigma = apply_normalization(
    inputmat,
    ln_out,
    and isinstance(FP8GlobalStateManager.get_fp8_recipe(), BlockScaling)
):
    raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
Suggested change:

    and not FP8GlobalStateManager.get_fp8_recipe().delayed()
):
    raise NotImplementedError("Comm+GEMM overlap is only supported with FP8 delayed scaling")
Description
This PR updates TE/common and TE/PyTorch API for comm+GEMM overlap to support the new QuantizedTensor abstraction.
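For context, the core pattern this enables in the PyTorch modules is visible in the diffs above: the Userbuffers overlap object hands its communication buffer back as a QuantizedTensor built with the module's quantizer. A minimal sketch, where get_ub, ub_name, and input_quantizer are the surrounding module helpers and variables rather than standalone definitions:

# Fetch the comm+GEMM overlap object for this GEMM and get its overlap buffer
# back as a QuantizedTensor constructed with the layer's input quantizer.
ub_obj_fprop = get_ub(ub_name + "_fprop")
ln_out = ub_obj_fprop.get_buffer(input_quantizer, True)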