[PyTorch/C++] Comm+GEMM overlap compatibility with QuantizedTensor #1427

Open
denera wants to merge 3 commits into base: release_v2.0 from blackwell-cppqtensor-tp-overlap-v2.0

Conversation

@denera denera (Collaborator) commented Jan 28, 2025

Description

This PR updates the TE/common and TE/PyTorch APIs for comm+GEMM overlap to support the new QuantizedTensor abstraction.
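
For context (added illustration, not part of the diff itself): the sketch below shows, in rough form, how the updated modules pick the buffer that the normalization output is written into. Only get_ub, CommOverlap.get_buffer, and Quantizer.make_empty are taken from the changes in this PR; the helper name and surrounding glue are illustrative.

import torch

def _choose_ln_out(inputmat, input_quantizer, ub_overlap_ag_fprop, with_quantized_norm, ub_name):
    # Illustrative sketch, not verbatim module code. get_ub is TE/PyTorch's
    # accessor for the comm+GEMM overlap objects, as used in the diffs below.
    if ub_overlap_ag_fprop:
        # Overlap path: the Userbuffers workspace comes back as a QuantizedTensor
        # built from the module's input quantizer.
        ub_obj = get_ub(ub_name + "_fprop")
        return ub_obj.get_buffer(input_quantizer, True)
    if with_quantized_norm:
        # Quantized normalization output without comm overlap.
        return input_quantizer.make_empty(inputmat.shape, dtype=inputmat.dtype, device="cuda")
    # High-precision fallback.
    return torch.empty_like(inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format)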

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • [x] I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@denera denera added the 2.0.0 label Jan 28, 2025
@denera denera requested review from timmoon10 and ptrendx January 28, 2025 01:56
@denera denera self-assigned this Jan 28, 2025
…th cppqtensor

Signed-off-by: Alp Dener <[email protected]>

CommOverlap objects can now return overlap buffers to PyTorch as QuantizedTensors

Signed-off-by: Alp Dener <[email protected]>

updated comm+GEMM overlap test for pure GEMM, both BF16 and FP8 working with QuantizedTensor

Signed-off-by: Alp Dener <[email protected]>

te.Linear and te.LayerNormMLP updated for TP overlap w/ QuantizedTensor. All overlaps work in BF16. All overlaps except bulk WGRAD work in FP8.

Signed-off-by: Alp Dener <[email protected]>

completed TP overlap QuantizedTensor updates for LayerNormLinear, but issues with quantized normalization

Signed-off-by: Alp Dener <[email protected]>

all overlaps working with bf16, all but bulk WGRAD working with FP8

Signed-off-by: Alp Dener <[email protected]>

all overlaps work with Float8Tensor, except bulk wgrad in LayerNormMLP (works in other modules)

Signed-off-by: Alp Dener <[email protected]>

all overlaps working with QuantizedTensor in BF16 and FP8

Signed-off-by: Alp Dener <[email protected]>

cleaned up pytest formatting

Signed-off-by: Alp Dener <[email protected]>
@denera denera force-pushed the blackwell-cppqtensor-tp-overlap-v2.0 branch from 9ba5009 to f1dcf35 on January 28, 2025 02:38
@ksivaman ksivaman self-requested a review January 28, 2025 20:42
Comment on lines 142 to +150
# Configure quantizer for normalization output
if fp8 and input_quantizer is None:
    raise ValueError("Missing quantizer for input tensor")
if fp8:
    if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and isinstance(
        FP8GlobalStateManager.get_fp8_recipe(), BlockScaling
    ):
        raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")

    if input_quantizer is None:
        raise ValueError("Missing quantizer for input tensor")
Collaborator:

  • Putting UB logic here makes the comment incorrect
  • This won't generalize when we add more quantization schemes. Instead of assuming that all recipes except MXFP8 support UB, we should only assume FP8 delayed scaling supports UB.
Suggested change
- # Configure quantizer for normalization output
- if fp8 and input_quantizer is None:
-     raise ValueError("Missing quantizer for input tensor")
- if fp8:
-     if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and isinstance(
-         FP8GlobalStateManager.get_fp8_recipe(), BlockScaling
-     ):
-         raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
-     if input_quantizer is None:
-         raise ValueError("Missing quantizer for input tensor")
+ # Check if overlapped communication is supported
+ if (
+     fp8
+     and (ub_overlap_ag_fprop or ub_overlap_rs_fprop)
+     and not FP8GlobalStateManager.get_fp8_recipe().delayed()
+ ):
+     raise NotImplementedError("Comm+GEMM overlap is only supported with FP8 delayed scaling")
+ # Configure quantizer for normalization output
+ if fp8:
+     if input_quantizer is None:
+         raise ValueError("Missing quantizer for input tensor")

Comment on lines +404 to +406
    and isinstance(FP8GlobalStateManager.get_fp8_recipe(), BlockScaling)
):
    raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
Collaborator:

Suggested change
-     and isinstance(FP8GlobalStateManager.get_fp8_recipe(), BlockScaling)
- ):
-     raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
+     and not FP8GlobalStateManager.get_fp8_recipe().delayed()
+ ):
+     raise NotImplementedError("Comm+GEMM overlap is only supported with FP8 delayed scaling")

Comment on lines +164 to +179
ub_obj_fprop = None
ln_out = None
if ub_overlap_ag_fprop:
    ub_obj_fprop = get_ub(ub_name + "_fprop")
    ln_out = ub_obj_fprop.get_buffer(input_quantizer, True)
elif with_quantized_norm:
    ln_out = input_quantizer.make_empty(inputmat.shape, dtype=inputmat.dtype, device="cuda")
else:
    ln_out = torch.empty_like(
        inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format
    )

# Apply normalization
-ln_out, mu, rsigma = apply_normalization(
+_, mu, rsigma = apply_normalization(
    inputmat,
-    None,
+    ln_out,
Collaborator:

We prefer constructing tensors in C++ to reduce CPU overheads:

Suggested change
- ub_obj_fprop = None
- ln_out = None
- if ub_overlap_ag_fprop:
-     ub_obj_fprop = get_ub(ub_name + "_fprop")
-     ln_out = ub_obj_fprop.get_buffer(input_quantizer, True)
- elif with_quantized_norm:
-     ln_out = input_quantizer.make_empty(inputmat.shape, dtype=inputmat.dtype, device="cuda")
- else:
-     ln_out = torch.empty_like(
-         inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format
-     )
- # Apply normalization
- ln_out, mu, rsigma = apply_normalization(
- _, mu, rsigma = apply_normalization(
-     inputmat,
-     None,
-     ln_out,
+ ub_obj_fprop = None
+ ln_out = None
+ if ub_overlap_ag_fprop:
+     ub_obj_fprop = get_ub(ub_name + "_fprop")
+     ln_out = ub_obj_fprop.get_buffer(input_quantizer, True)
+ # Apply normalization
+ ln_out, mu, rsigma = apply_normalization(
+     inputmat,
+     ln_out,

Collaborator (Author):

Constructing this in PyTorch is how it was before. Switching it to None to construct it in C++ is triggering an error:

RuntimeError: /home/adener/devroot/nvte-internal/transformer_engine/common/normalization/common.h:197 in function getKernel: Unavailable kernel for this normalization config.

This is happening with TP overlap turned off as well. Retaining the same PyTorch initialization as before works correctly.

@timmoon10 timmoon10 self-requested a review January 28, 2025 22:16
" MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ",
],
)
def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out):
Collaborator:

Why are we removing these tests?

Collaborator (Author):

Atomic GEMM is deprecated in Blackwell and onward, and split GEMM + CUDA Graphs is more performant on Hopper. cuBlasMp does not support it either, so the functionality will disappear for good in TE when we eventually deprecate Userbuffers. So I removed the tests in this PR in the interest of having one less thing to maintain that no longer has any use cases.

@@ -261,6 +261,29 @@ def _create_transpose(self):
        self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose)
        self._transpose_invalid = False

    def _fix_gathered_transpose(self, tp_size=1, from_rowwise=False):
Collaborator:

I don't like how unintuitive and UB-specific this function is. It would have been better to implement CommOverlap::get_buffer so that it internally transposes the UB buffer and constructs a Float8Tensor with only column-wise data (instead of how this moves the row-wise data to the column-wise data). We're not quite there yet since Float8Tensors with only column-wise data experience problems.

For now, we should delete this function and reimplement the logic in the linear modules. We should make clear that this is a hacky workaround and not allow it to become an API we rely on.
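
(Added illustration, not TE code: a plain-PyTorch analogy of the row-wise vs. column-wise distinction above. The gathered Userbuffers workspace holds row-major [tokens, hidden] data; the idea is for get_buffer itself to hand back a tensor whose only payload is the transposed [hidden, tokens] copy, instead of the module shuffling row-wise storage into the transpose slot afterwards.)

import torch

# Purely illustrative stand-in for the all-gathered UB workspace.
gathered_rowwise = torch.randn(4096, 1024)            # [tokens, hidden], row-major
# "Column-wise only" data: the transposed, contiguous copy consumed by the wgrad GEMM.
columnwise_only = gathered_rowwise.t().contiguous()   # [hidden, tokens]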

@timmoon10 timmoon10 self-requested a review January 28, 2025 23:44
@timmoon10 timmoon10 (Collaborator) left a comment:

Overall looks reasonable, although I have some stylistic suggestions. This is fine in our last-minute scramble to restore UB support with FP8. Next we will need to think about extending it to support MXFP8 and other quantization schemes.

Comment on lines +20 to +23
#define NOT_IMPLEMENTED_ERROR() NVTE_ERROR("Operation is not implemented.")

#define NOT_SUPPORTED_ERROR() NVTE_ERROR("Operation not supported.")

Collaborator:

Nit: I don't think these macros are giving us much benefit. We save 20 characters but add another level of indirection.

Suggested change
- #define NOT_IMPLEMENTED_ERROR() NVTE_ERROR("Operation is not implemented.")
- #define NOT_SUPPORTED_ERROR() NVTE_ERROR("Operation not supported.")

Comment on lines +140 to +141
if ub_type == tex.CommOverlapType.AG and ub.is_p2p_overlap():
    out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
Collaborator:

Why do we have a separate path for this UB case? The GEMM invocation seems identical and the swizzling is only relevant for the MXFP8 case.

Not critical right now, but I think we could clean up this entire section:

# Swizzle scaling factors if needed
is_swizzle_disabled = ub_type == tex.CommOverlapType.AG and ub.is_p2p_overlap()
original_scale_inverses = None
if not is_swizzle_disabled:
    original_scale_inverses = swizzle_inputs(A, B, layout)

# Perform GEMM
out, bias_grad, gelu_input, extra_output = tex.generic_gemm(
    A,
    transa,
    B,
    transb,
    ...
)

# Reset swizzled scaling factors
if not is_swizzle_disabled:
    reset_swizzled_inputs(A, B, original_scale_inverses)

return out, bias_grad, gelu_input, extra_output

Comment on lines 112 to +125
# Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
inputmat = inp
inputmat_total = None
-with_input_all_gather = parallel_mode == "column" and sequence_parallel
+with_input_all_gather_nccl = (
+    parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
+)
own_quantized_input = False
if fp8:
    if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and isinstance(
        FP8GlobalStateManager.get_fp8_recipe(), BlockScaling
    ):
        raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")

Collaborator:

Same as #1427 (comment):

Suggested change
- # Prepare input tensor
- # Note: Cast to expected dtype and perform tensor-parallel communication
- inputmat = inp
- inputmat_total = None
- with_input_all_gather = parallel_mode == "column" and sequence_parallel
- with_input_all_gather_nccl = (
-     parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
- )
- own_quantized_input = False
- if fp8:
-     if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and isinstance(
-         FP8GlobalStateManager.get_fp8_recipe(), BlockScaling
-     ):
-         raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
+ # Check if overlapped communication is supported
+ if (
+     fp8
+     and (ub_overlap_ag_fprop or ub_overlap_rs_fprop)
+     and not FP8GlobalStateManager.get_fp8_recipe().delayed()
+ ):
+     raise NotImplementedError("Comm+GEMM overlap is only supported with FP8 delayed scaling")
+ # Prepare input tensor
+ # Note: Cast to expected dtype and perform tensor-parallel communication
+ inputmat = inp
+ inputmat_total = None
+ with_input_all_gather_nccl = (
+     parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
+ )
+ own_quantized_input = False
+ if fp8:

Comment on lines +323 to +327
    and isinstance(FP8GlobalStateManager.get_fp8_recipe(), BlockScaling)
):
    recipe = FP8GlobalStateManager.get_fp8_recipe()
    print(f"FP8 Recipe: {type(recipe)} -> {recipe}")
    raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
Collaborator:

Suggested change
-     and isinstance(FP8GlobalStateManager.get_fp8_recipe(), BlockScaling)
- ):
-     recipe = FP8GlobalStateManager.get_fp8_recipe()
-     print(f"FP8 Recipe: {type(recipe)} -> {recipe}")
-     raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
+     and not FP8GlobalStateManager.get_fp8_recipe().delayed()
+ ):
+     raise NotImplementedError("Comm+GEMM overlap is only supported with FP8 delayed scaling")

Comment on lines +726 to +743
# if opts.fp8:
# dist_print("GEMM1 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True)
# fp8_meta_info = (
# f"amax_reference = {fp8_meta.amax_history[1][:3].tolist()}\n"
# + f"amax_history = {fp8_meta.amax_history[0][:3].tolist()}\n"
# + f"scale = {fp8_meta.scale[:3].tolist()}\n"
# + f"scale_inv = {fp8_meta.scale_inv[:3].tolist()}"
# )
# dist_print(fp8_meta_info, src=0, group=tp_group)
# if ub_obj2 is not None:
# dist_print("GEMM2 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True)
# fp8_meta_info = (
# f"amax_reference = {fp8_meta.amax_history[1][3:].tolist()}\n"
# + f"amax_history = {fp8_meta.amax_history[0][3:].tolist()}\n"
# + f"scale = {fp8_meta.scale[3:].tolist()}\n"
# + f"scale_inv = {fp8_meta.scale_inv[3:].tolist()}"
# )
# dist_print(fp8_meta_info, src=0, group=tp_group)
Collaborator:

Debugging code:

Suggested change
- # if opts.fp8:
- # dist_print("GEMM1 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True)
- # fp8_meta_info = (
- # f"amax_reference = {fp8_meta.amax_history[1][:3].tolist()}\n"
- # + f"amax_history = {fp8_meta.amax_history[0][:3].tolist()}\n"
- # + f"scale = {fp8_meta.scale[:3].tolist()}\n"
- # + f"scale_inv = {fp8_meta.scale_inv[:3].tolist()}"
- # )
- # dist_print(fp8_meta_info, src=0, group=tp_group)
- # if ub_obj2 is not None:
- # dist_print("GEMM2 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True)
- # fp8_meta_info = (
- # f"amax_reference = {fp8_meta.amax_history[1][3:].tolist()}\n"
- # + f"amax_history = {fp8_meta.amax_history[0][3:].tolist()}\n"
- # + f"scale = {fp8_meta.scale[3:].tolist()}\n"
- # + f"scale_inv = {fp8_meta.scale_inv[3:].tolist()}"
- # )
- # dist_print(fp8_meta_info, src=0, group=tp_group)

Comment on lines +159 to +162
if any([ub_overlap_ag, ub_overlap_rs]) and isinstance(
    FP8GlobalStateManager.get_fp8_recipe(), BlockScaling
):
    raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
Collaborator:

Suggested change
- if any([ub_overlap_ag, ub_overlap_rs]) and isinstance(
-     FP8GlobalStateManager.get_fp8_recipe(), BlockScaling
- ):
-     raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
+ if (
+     (ub_overlap_ag or ub_overlap_rs)
+     and not FP8GlobalStateManager.get_fp8_recipe().delayed()
+ ):
+     raise NotImplementedError("Comm+GEMM overlap is only supported with FP8 delayed scaling")

Comment on lines +198 to +215
ub_obj_lnout = None
ln_out = None
if ub_overlap_ag:
    ub_obj_lnout = get_ub("fc1_fprop")
    ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, True)
elif with_quantized_norm:
    ln_out = fc1_input_quantizer.make_empty(
        inputmat.shape, dtype=inputmat.dtype, device="cuda"
    )
else:
    ln_out = torch.empty_like(
        inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format
    )

# Apply normalization
-ln_out, mu, rsigma = apply_normalization(
+_, mu, rsigma = apply_normalization(
    inputmat,
-    None,
+    ln_out,
Collaborator:

Same as #1427 (comment):

Suggested change
- ub_obj_lnout = None
- ln_out = None
- if ub_overlap_ag:
-     ub_obj_lnout = get_ub("fc1_fprop")
-     ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, True)
- elif with_quantized_norm:
-     ln_out = fc1_input_quantizer.make_empty(
-         inputmat.shape, dtype=inputmat.dtype, device="cuda"
-     )
- else:
-     ln_out = torch.empty_like(
-         inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format
-     )
- # Apply normalization
- ln_out, mu, rsigma = apply_normalization(
- _, mu, rsigma = apply_normalization(
-     inputmat,
-     None,
-     ln_out,
+ ub_obj_lnout = None
+ ln_out = None
+ if ub_overlap_ag:
+     ub_obj_lnout = get_ub("fc1_fprop")
+     ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, True)
+ # Apply normalization
+ ln_out, mu, rsigma = apply_normalization(
+     inputmat,
+     ln_out,

Comment on lines +551 to +553
    and isinstance(FP8GlobalStateManager.get_fp8_recipe(), BlockScaling)
):
    raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
Collaborator:

Suggested change
-     and isinstance(FP8GlobalStateManager.get_fp8_recipe(), BlockScaling)
- ):
-     raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
+     and not FP8GlobalStateManager.get_fp8_recipe().delayed()
+ ):
+     raise NotImplementedError("Comm+GEMM overlap is only supported with FP8 delayed scaling")
