From de1ffad350f0e49f2571229e541c5786a76cd8f8 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 15 Jan 2025 13:42:31 +0000 Subject: [PATCH] temp --- src/brevitas/core/zero_point.py | 30 ++++++++++++++ src/brevitas/graph/base.py | 6 +++ src/brevitas/graph/equalize.py | 6 +++ src/brevitas/graph/quantize.py | 2 +- src/brevitas/nn/equalized_layer.py | 4 +- src/brevitas/nn/quant_sdpa.py | 26 +++++++++--- .../common/generative/quantize.py | 23 ++++++++--- src/brevitas_examples/llm/main.py | 40 ++++++++++++------- 8 files changed, 110 insertions(+), 27 deletions(-) diff --git a/src/brevitas/core/zero_point.py b/src/brevitas/core/zero_point.py index f74fffae8..7038fe1a9 100644 --- a/src/brevitas/core/zero_point.py +++ b/src/brevitas/core/zero_point.py @@ -344,3 +344,33 @@ def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor: # pre-zero centering before rounding and clipping z = self.get_zero_center(x) / scale # need to scale the norm by s return z + + +class RuntimeDynamicGroupZeroScaling(brevitas.jit.ScriptModule): + + def __init__( + self, + group_size: int, + group_dim: int, + input_view_impl: Module, + zero_point_stats_impl: Module, + int_quant, + quantize_zero_point) -> None: + super(RuntimeDynamicGroupZeroScaling, self).__init__() + + self.group_size = group_size + self.group_dim = group_dim + self.zero_point_stats_impl = zero_point_stats_impl + self.input_view_impl = input_view_impl + self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point) + + @brevitas.jit.script_method + def forward( + self, + stats_input: torch.Tensor, + scale, + bit_width) -> torch.Tensor: + + stats_input_reshaped = self.input_view_impl(stats_input) + out = self.zero_point_stats_impl(stats_input_reshaped) + return self.scale_shift_zero_point(-out, scale, bit_width) diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index def3f7070..956983cb3 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -6,6 +6,7 @@ import inspect from inspect import getcallargs +from brevitas.nn import ScaledDotProductAttention import torch from torch.nn import Module from torch.overrides import get_testing_overrides @@ -116,6 +117,9 @@ def _map_origin_vars(self, vars: dict): def _module_attributes(self, module): attrs = vars(module) + if isinstance(module, ScaledDotProductAttention): + print(attrs) + # workaround since bias doesn't show up on vars of Linear if hasattr(module, 'bias'): attrs['bias'] = module.bias @@ -142,6 +146,8 @@ def _init_new_module(self, old_module: Module, name=None): new_kwargs = self._module_attributes(old_module) # transforms attribute of original module, e.g. 
bias Parameter -> bool new_kwargs = self._map_origin_vars(new_kwargs) + if isinstance(old_module, ScaledDotProductAttention): + print(new_kwargs) # restrict to only values that are in the init of the new module new_module_signature_keys = signature_keys(self.new_module_class) new_kwargs = {k: v for k, v in new_kwargs.items() if k in new_module_signature_keys} diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 7cbe38c6a..82f45e44c 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -10,6 +10,7 @@ from typing import Callable, Dict, List, Optional, Set, Tuple, Union import warnings +from brevitas.nn import ScaledDotProductAttention import packaging import packaging.version import torch @@ -1504,6 +1505,11 @@ def find_sink(node): name_to_module={ 'src0': src_module, 'sink0': sink_module}) regions.append(region) + for m in graph_module.modules(): + if isinstance(m, ScaledDotProductAttention): + m.pre_process_q = functional_rotate_input + m.pre_process_k = functional_rotate_input + # m.pre_process_v = partial(functional_rotate_input, transpose=True) return regions def apply(self, diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index 7724e8f9d..6dbdae4cb 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -50,7 +50,7 @@ def __init__(self, model: torch.nn.Module, quant_map: Dict, enabled: bool = True self.enabled = enabled for stateless_function, stateless_module in quant_map.items(): if not hasattr(model, str(stateless_function)): - setattr(model, str(stateless_function), stateless_module()) + model.add_module(str(stateless_function), stateless_module()) def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index 8413a8208..e3d930a50 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -81,7 +81,7 @@ def forward(self, inp, **kwargs): def functional_rotate_input(inp, transpose=False): is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None if transpose: - inp = inp.t() + inp = inp.transpose(-2, -1) if is_cuda and fast_hadamard_transform is not None: had_K, K = get_hadK(inp.shape[-1]) inp = matmul_hadU_cuda(inp, had_K, K) @@ -89,5 +89,5 @@ def functional_rotate_input(inp, transpose=False): inp = matmul_hadU(inp) if transpose: - inp = inp.t() + inp = inp.transpose(-2, -1) return inp diff --git a/src/brevitas/nn/quant_sdpa.py b/src/brevitas/nn/quant_sdpa.py index 43f99e827..728d924e9 100644 --- a/src/brevitas/nn/quant_sdpa.py +++ b/src/brevitas/nn/quant_sdpa.py @@ -43,6 +43,8 @@ import math from typing import Optional, Tuple, Union +from brevitas.core.function_wrapper.misc import Identity +from brevitas.function import identity import torch from torch import Tensor from torch.nn import Module @@ -57,6 +59,12 @@ class ScaledDotProductAttention(Module): + def __init__(self, pre_process_q = identity, pre_process_k = identity, pre_process_v = identity): + super().__init__() + self.pre_process_q = pre_process_q + self.pre_process_k = pre_process_k + self.pre_process_v = pre_process_v + def forward( self, query: Tensor, @@ -103,9 +111,9 @@ def forward( if enable_gqa: kwargs["enable_gqa"] = enable_gqa return F.scaled_dot_product_attention( - query=query, - key=key, - value=value, + query=self.pre_process_q(query), + key=self.pre_process_k(key), + value=value,#self.pre_process_v(value), attn_mask=attn_mask, 
dropout_p=dropout_p, is_causal=is_causal, @@ -116,6 +124,7 @@ class QuantScaledDotProductAttention(Module): def __init__( self, + pre_process_q = Identity(), pre_process_k = Identity(), pre_process_v = Identity(), softmax_input_quant=None, attn_output_weights_quant=Uint8ActPerTensorFloat, q_scaled_quant=Int8ActPerTensorFloat, @@ -125,6 +134,11 @@ def __init__( **kwargs) -> None: super(QuantScaledDotProductAttention, self).__init__() + self.pre_process_q = pre_process_q + self.pre_process_k = pre_process_k + self.pre_process_v = pre_process_v + print(self.pre_process_q) + def filter_kwargs(prefix): return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)} @@ -196,14 +210,16 @@ def forward( attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) else: attn_bias += attn_mask - q_scaled = self.q_scaled_quant(query * scale_factor) + query, key, value = self.pre_process_q(query), self.pre_process_k(key), self.pre_process_v(value) + q_scaled = query * scale_factor#self.q_scaled_quant(query * scale_factor) k_transpose = self.k_transposed_quant(key.transpose(-2, -1)) attn_weight = q_scaled @ k_transpose attn_weight += attn_bias attn_weight = self.softmax_input_quant(attn_weight) attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) - attn_weight = self.attn_output_weights_quant(attn_weight) + # attn_weight = self.pre_process_q(attn_weight) + # attn_weight = self.attn_output_weights_quant(attn_weight) attn_output = attn_weight @ self.v_quant(value) attn_output = self.sdpa_output_quant(attn_output) return attn_output diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 778955285..40e6063f4 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -4,6 +4,9 @@ """ import re +from brevitas.core.stats import NegativeMinOrZero +from brevitas.quant.base import ParameterFromRuntimeZeroPoint +from dependencies import this import torch from torch import nn @@ -11,7 +14,7 @@ from brevitas.core.function_wrapper import CeilSte from brevitas.core.function_wrapper import FloorSte from brevitas.core.restrict_val import RoundSte -from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint +from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint, RuntimeDynamicGroupZeroScaling from brevitas.graph.quantize import layerwise_quantize from brevitas.quant.experimental.float import Fp8e4m3Act from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat @@ -57,7 +60,7 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear -from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat +from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat, RuntimeDynamicStatsZeroPoint from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFixedPoint from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFloat from brevitas_examples.common.generative.quantizers import Fp8e4m3OCPWeightPerChannelFixedPointMSE @@ -149,6 +152,15 @@ 'per_channel': { 'sym': Fp8e4m3FNUZWeightPerChannelFloat}}}}} +class Test(Int8DynamicActPerGroupFloat): + # zero_point_impl = 
RuntimeDynamicStatsZeroPoint + zero_point_impl = RuntimeDynamicGroupZeroScaling + zero_point_stats_impl = NegativeMinOrZero + scaling_stats_op = 'min_max' + signed = False + # zero_point_shape = this.scaling_shape + # zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl + INPUT_QUANT_MAP = { 'int': { 'static': { @@ -177,7 +189,8 @@ 'sym': Int8DynamicActPerRowFloat, 'asym': ShiftedUint8DynamicActPerRowFloat}, 'per_group': { - 'sym': Int8DynamicActPerGroupFloat}}}, + 'sym': Int8DynamicActPerGroupFloat, + 'asym': Test}}}, 'po2_scale': { 'stats': { 'per_row': { @@ -388,10 +401,10 @@ def generate_quantizers( elif input_quant_granularity == 'per_group': q_scaled_quant = sym_input_quant.let( **{ - 'group_dim': 2, 'group_size': input_group_size}) + 'group_dim': -1, 'group_size': input_group_size}) k_transposed_quant = sym_input_quant.let( **{ - 'group_dim': 1, 'group_size': input_group_size}) + 'group_dim': -1, 'group_size': input_group_size}) v_quant = q_scaled_quant attn_output_weights_quant = q_scaled_quant else: diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 4f03ba087..97f27596f 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -81,8 +81,17 @@ def set_seed(seed): def fused_rotation_no_fx(model, calibration_loader, args): + print("Here") with torch.no_grad(): new_model, guards = torch._dynamo.export(model)(**calibration_loader[0]) + print(getattr(model, str(torch.nn.functional.scaled_dot_product_attention))) + if hasattr(model, str(torch.nn.functional.scaled_dot_product_attention)): + m_to_add = getattr(model, str(torch.nn.functional.scaled_dot_product_attention)) + new_model.add_module(str(torch.nn.functional.scaled_dot_product_attention), m_to_add) + # for m in new_model.modules(): + # print(type(m)) + # if hasattr(m, 'pre_process_q'): + # raise apply_layernorm_affine_merge(new_model) new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True) rewriters = fix_rewriter(rewriters, model, 'weight') @@ -300,19 +309,7 @@ def quantize_llm(args): apply_layernorm_to_rmsnorm(model) print("Layernorm To RMSNorm applied.") - if args.rotation == 'fx': - model = offload_model(model) - eq = GraphRotationEqualization( - orphan_sink=args.rotation_orphan_sink, - full_rotation_method=args.rotation_mode, - sdpa_regions=args.rotation_sdpa_regions) - model = eq.apply(model) - remove_hooks(model) - elif args.rotation == 'layerwise': - eq = LayerwiseActivationRotation() - model = eq.apply(model) - elif args.rotation == 'fused_no_fx': - fused_rotation_no_fx(model, calibration_loader, args) + # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing # with all the variability in HF implementations @@ -330,6 +327,21 @@ def quantize_llm(args): with torch.no_grad(), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}): model(**calibration_loader[0]) remove_hooks(model) + + if args.rotation == 'fx': + model = offload_model(model) + eq = GraphRotationEqualization( + orphan_sink=args.rotation_orphan_sink, + full_rotation_method=args.rotation_mode, + sdpa_regions=args.rotation_sdpa_regions) + model = eq.apply(model) + remove_hooks(model) + elif args.rotation == 'layerwise': + eq = LayerwiseActivationRotation() + model = eq.apply(model) + elif args.rotation == 'fused_no_fx': + fused_rotation_no_fx(model, calibration_loader, args) + if args.weight_equalization: print("Apply weight 
equalization...") # In case of float16 model, we need to offload to account for missing ops @@ -518,7 +530,7 @@ def quantize_llm(args): print(f"Saving checkpoint to {args.checkpoint_name}") torch.save(model.state_dict(), args.checkpoint_name) - if args.eval and not args.no_quantize: + if args.eval:# and not args.no_quantize: print("Model eval...") with torch.no_grad(), quant_inference_mode(model): model(**calibration_loader[0])