
Commit

temp
Giuseppe5 committed Jan 15, 2025
1 parent c44566d commit de1ffad
Showing 8 changed files with 110 additions and 27 deletions.
30 changes: 30 additions & 0 deletions src/brevitas/core/zero_point.py
@@ -344,3 +344,33 @@ def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
# pre-zero centering before rounding and clipping
z = self.get_zero_center(x) / scale # need to scale the norm by s
return z


class RuntimeDynamicGroupZeroScaling(brevitas.jit.ScriptModule):

def __init__(
self,
group_size: int,
group_dim: int,
input_view_impl: Module,
zero_point_stats_impl: Module,
int_quant,
quantize_zero_point) -> None:
super(RuntimeDynamicGroupZeroScaling, self).__init__()

self.group_size = group_size
self.group_dim = group_dim
self.zero_point_stats_impl = zero_point_stats_impl
self.input_view_impl = input_view_impl
self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)

@brevitas.jit.script_method
def forward(
self,
stats_input: torch.Tensor,
scale,
bit_width) -> torch.Tensor:

stats_input_reshaped = self.input_view_impl(stats_input)
out = self.zero_point_stats_impl(stats_input_reshaped)
return self.scale_shift_zero_point(-out, scale, bit_width)
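Editor's note (a sketch, not part of this commit): the new RuntimeDynamicGroupZeroScaling composes a group-wise input view, a zero-point statistic, and _ScaleShiftZeroPoint. The plain-PyTorch snippet below only illustrates the intended computation for an unsigned, asymmetric per-group zero point, assuming a NegativeMinOrZero-style statistic; group_zero_point is a hypothetical helper, not the Brevitas implementation.

import torch

def group_zero_point(x: torch.Tensor, scale: torch.Tensor, group_size: int, bit_width: int = 8) -> torch.Tensor:
    # View the last dimension as (num_groups, group_size), mimicking a group-wise input view.
    xg = x.view(*x.shape[:-1], -1, group_size)
    # Min-or-zero statistic per group: the minimum, clamped so it is never positive.
    stat = xg.min(dim=-1, keepdim=True).values.clamp(max=0.0)
    # Negate and shift the statistic onto the integer grid defined by the per-group scale.
    zero_point = torch.round(-stat / scale)
    return zero_point.clamp(0.0, 2.0 ** bit_width - 1)

x = torch.randn(2, 8, 64)
scale = x.view(2, 8, -1, 16).abs().amax(dim=-1, keepdim=True) / (2 ** 8 - 1)  # illustrative positive scale
print(group_zero_point(x, scale, group_size=16).shape)  # torch.Size([2, 8, 4, 1])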
6 changes: 6 additions & 0 deletions src/brevitas/graph/base.py
@@ -6,6 +6,7 @@
import inspect
from inspect import getcallargs

from brevitas.nn import ScaledDotProductAttention
import torch
from torch.nn import Module
from torch.overrides import get_testing_overrides
@@ -116,6 +117,9 @@ def _map_origin_vars(self, vars: dict):

def _module_attributes(self, module):
attrs = vars(module)
if isinstance(module, ScaledDotProductAttention):
print(attrs)

# workaround since bias doesn't show up on vars of Linear
if hasattr(module, 'bias'):
attrs['bias'] = module.bias
@@ -142,6 +146,8 @@ def _init_new_module(self, old_module: Module, name=None):
new_kwargs = self._module_attributes(old_module)
# transforms attribute of original module, e.g. bias Parameter -> bool
new_kwargs = self._map_origin_vars(new_kwargs)
if isinstance(old_module, ScaledDotProductAttention):
print(new_kwargs)
# restrict to only values that are in the init of the new module
new_module_signature_keys = signature_keys(self.new_module_class)
new_kwargs = {k: v for k, v in new_kwargs.items() if k in new_module_signature_keys}
6 changes: 6 additions & 0 deletions src/brevitas/graph/equalize.py
@@ -10,6 +10,7 @@
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import warnings

from brevitas.nn import ScaledDotProductAttention
import packaging
import packaging.version
import torch
@@ -1504,6 +1505,11 @@ def find_sink(node):
name_to_module={
'src0': src_module, 'sink0': sink_module})
regions.append(region)
for m in graph_module.modules():
if isinstance(m, ScaledDotProductAttention):
m.pre_process_q = functional_rotate_input
m.pre_process_k = functional_rotate_input
# m.pre_process_v = partial(functional_rotate_input, transpose=True)
return regions

def apply(self,
2 changes: 1 addition & 1 deletion src/brevitas/graph/quantize.py
@@ -50,7 +50,7 @@ def __init__(self, model: torch.nn.Module, quant_map: Dict, enabled: bool = True
self.enabled = enabled
for stateless_function, stateless_module in quant_map.items():
if not hasattr(model, str(stateless_function)):
- setattr(model, str(stateless_function), stateless_module())
+ model.add_module(str(stateless_function), stateless_module())

def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
4 changes: 2 additions & 2 deletions src/brevitas/nn/equalized_layer.py
@@ -81,13 +81,13 @@ def forward(self, inp, **kwargs):
def functional_rotate_input(inp, transpose=False):
is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None
if transpose:
- inp = inp.t()
+ inp = inp.transpose(-2, -1)
if is_cuda and fast_hadamard_transform is not None:
had_K, K = get_hadK(inp.shape[-1])
inp = matmul_hadU_cuda(inp, had_K, K)
else:
inp = matmul_hadU(inp)

if transpose:
- inp = inp.t()
+ inp = inp.transpose(-2, -1)
return inp
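Side note on the change above (not part of the commit): torch.Tensor.t() only accepts tensors with at most two dimensions, while the activations reaching this rotation in SDPA are batched 4-D tensors, so transpose(-2, -1) is the general form. A quick standalone check:

import torch

x = torch.randn(2, 4, 8, 16)  # e.g. (batch, heads, seq_len, head_dim)
try:
    x.t()  # t() is restricted to tensors with <= 2 dimensions and raises here
except RuntimeError as err:
    print("t() failed:", err)

y = x.transpose(-2, -1)  # swaps only the last two dimensions, valid for any rank
print(y.shape)  # torch.Size([2, 4, 16, 8])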
26 changes: 21 additions & 5 deletions src/brevitas/nn/quant_sdpa.py
@@ -43,6 +43,8 @@
import math
from typing import Optional, Tuple, Union

from brevitas.core.function_wrapper.misc import Identity
from brevitas.function import identity
import torch
from torch import Tensor
from torch.nn import Module
@@ -57,6 +59,12 @@

class ScaledDotProductAttention(Module):

def __init__(self, pre_process_q = identity, pre_process_k = identity, pre_process_v = identity):
super().__init__()
self.pre_process_q = pre_process_q
self.pre_process_k = pre_process_k
self.pre_process_v = pre_process_v

def forward(
self,
query: Tensor,
@@ -103,9 +111,9 @@ def forward(
if enable_gqa:
kwargs["enable_gqa"] = enable_gqa
return F.scaled_dot_product_attention(
- query=query,
- key=key,
- value=value,
+ query=self.pre_process_q(query),
+ key=self.pre_process_k(key),
+ value=value,  # self.pre_process_v(value),
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
@@ -116,6 +124,7 @@ class QuantScaledDotProductAttention(Module):

def __init__(
self,
pre_process_q = Identity(), pre_process_k = Identity(), pre_process_v = Identity(),
softmax_input_quant=None,
attn_output_weights_quant=Uint8ActPerTensorFloat,
q_scaled_quant=Int8ActPerTensorFloat,
@@ -125,6 +134,11 @@ def __init__(
**kwargs) -> None:
super(QuantScaledDotProductAttention, self).__init__()

self.pre_process_q = pre_process_q
self.pre_process_k = pre_process_k
self.pre_process_v = pre_process_v
print(self.pre_process_q)

def filter_kwargs(prefix):
return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}

@@ -196,14 +210,16 @@ def forward(
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
- q_scaled = self.q_scaled_quant(query * scale_factor)
+ query, key, value = self.pre_process_q(query), self.pre_process_k(key), self.pre_process_v(value)
+ q_scaled = query * scale_factor  # self.q_scaled_quant(query * scale_factor)
k_transpose = self.k_transposed_quant(key.transpose(-2, -1))
attn_weight = q_scaled @ k_transpose
attn_weight += attn_bias
attn_weight = self.softmax_input_quant(attn_weight)
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
- attn_weight = self.attn_output_weights_quant(attn_weight)
+ # attn_weight = self.pre_process_q(attn_weight)
+ # attn_weight = self.attn_output_weights_quant(attn_weight)
attn_output = attn_weight @ self.v_quant(value)
attn_output = self.sdpa_output_quant(attn_output)
return attn_output
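Editor's usage sketch (not part of the diff): how the new pre_process_q / pre_process_k hooks on ScaledDotProductAttention can be wired up. It assumes the remaining forward arguments keep the same defaults as F.scaled_dot_product_attention; the no-op rotate below stands in for functional_rotate_input.

import torch
from brevitas.nn import ScaledDotProductAttention  # import path used elsewhere in this diff

def rotate(t: torch.Tensor) -> torch.Tensor:
    return t  # placeholder; the LLM flow assigns functional_rotate_input here

sdpa = ScaledDotProductAttention(pre_process_q=rotate, pre_process_k=rotate)

q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

out = sdpa(q, k, v)
print(out.shape)  # torch.Size([1, 8, 16, 64])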
23 changes: 18 additions & 5 deletions src/brevitas_examples/common/generative/quantize.py
@@ -4,14 +4,17 @@
"""
import re

from brevitas.core.stats import NegativeMinOrZero
from brevitas.quant.base import ParameterFromRuntimeZeroPoint
from dependencies import this
import torch
from torch import nn

from brevitas import nn as qnn
from brevitas.core.function_wrapper import CeilSte
from brevitas.core.function_wrapper import FloorSte
from brevitas.core.restrict_val import RoundSte
- from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
+ from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint, RuntimeDynamicGroupZeroScaling
from brevitas.graph.quantize import layerwise_quantize
from brevitas.quant.experimental.float import Fp8e4m3Act
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
@@ -57,7 +60,7 @@
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE
from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d
from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear
- from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat
+ from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat, RuntimeDynamicStatsZeroPoint
from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFixedPoint
from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import Fp8e4m3OCPWeightPerChannelFixedPointMSE
@@ -149,6 +152,15 @@
'per_channel': {
'sym': Fp8e4m3FNUZWeightPerChannelFloat}}}}}

class Test(Int8DynamicActPerGroupFloat):
# zero_point_impl = RuntimeDynamicStatsZeroPoint
zero_point_impl = RuntimeDynamicGroupZeroScaling
zero_point_stats_impl = NegativeMinOrZero
scaling_stats_op = 'min_max'
signed = False
# zero_point_shape = this.scaling_shape
# zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl

INPUT_QUANT_MAP = {
'int': {
'static': {
@@ -177,7 +189,8 @@
'sym': Int8DynamicActPerRowFloat,
'asym': ShiftedUint8DynamicActPerRowFloat},
'per_group': {
- 'sym': Int8DynamicActPerGroupFloat}}},
+ 'sym': Int8DynamicActPerGroupFloat,
+ 'asym': Test}}},
'po2_scale': {
'stats': {
'per_row': {
@@ -388,10 +401,10 @@ def generate_quantizers(
elif input_quant_granularity == 'per_group':
q_scaled_quant = sym_input_quant.let(
**{
- 'group_dim': 2, 'group_size': input_group_size})
+ 'group_dim': -1, 'group_size': input_group_size})
k_transposed_quant = sym_input_quant.let(
**{
- 'group_dim': 1, 'group_size': input_group_size})
+ 'group_dim': -1, 'group_size': input_group_size})
v_quant = q_scaled_quant
attn_output_weights_quant = q_scaled_quant
else:
40 changes: 26 additions & 14 deletions src/brevitas_examples/llm/main.py
@@ -81,8 +81,17 @@ def set_seed(seed):


def fused_rotation_no_fx(model, calibration_loader, args):
print("Here")
with torch.no_grad():
new_model, guards = torch._dynamo.export(model)(**calibration_loader[0])
print(getattr(model, str(torch.nn.functional.scaled_dot_product_attention)))
if hasattr(model, str(torch.nn.functional.scaled_dot_product_attention)):
m_to_add = getattr(model, str(torch.nn.functional.scaled_dot_product_attention))
new_model.add_module(str(torch.nn.functional.scaled_dot_product_attention), m_to_add)
# for m in new_model.modules():
# print(type(m))
# if hasattr(m, 'pre_process_q'):
# raise
apply_layernorm_affine_merge(new_model)
new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True)
rewriters = fix_rewriter(rewriters, model, 'weight')
@@ -300,19 +309,7 @@ def quantize_llm(args):
apply_layernorm_to_rmsnorm(model)
print("Layernorm To RMSNorm applied.")

- if args.rotation == 'fx':
- model = offload_model(model)
- eq = GraphRotationEqualization(
- orphan_sink=args.rotation_orphan_sink,
- full_rotation_method=args.rotation_mode,
- sdpa_regions=args.rotation_sdpa_regions)
- model = eq.apply(model)
- remove_hooks(model)
- elif args.rotation == 'layerwise':
- eq = LayerwiseActivationRotation()
- model = eq.apply(model)
- elif args.rotation == 'fused_no_fx':
- fused_rotation_no_fx(model, calibration_loader, args)


# Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing
# with all the variability in HF implementations
@@ -330,6 +327,21 @@
with torch.no_grad(), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}):
model(**calibration_loader[0])
remove_hooks(model)

+ if args.rotation == 'fx':
+ model = offload_model(model)
+ eq = GraphRotationEqualization(
+ orphan_sink=args.rotation_orphan_sink,
+ full_rotation_method=args.rotation_mode,
+ sdpa_regions=args.rotation_sdpa_regions)
+ model = eq.apply(model)
+ remove_hooks(model)
+ elif args.rotation == 'layerwise':
+ eq = LayerwiseActivationRotation()
+ model = eq.apply(model)
+ elif args.rotation == 'fused_no_fx':
+ fused_rotation_no_fx(model, calibration_loader, args)

if args.weight_equalization:
print("Apply weight equalization...")
# In case of float16 model, we need to offload to account for missing ops
@@ -518,7 +530,7 @@ def quantize_llm(args):
print(f"Saving checkpoint to {args.checkpoint_name}")
torch.save(model.state_dict(), args.checkpoint_name)

- if args.eval and not args.no_quantize:
+ if args.eval:  # and not args.no_quantize:
print("Model eval...")
with torch.no_grad(), quant_inference_mode(model):
model(**calibration_loader[0])
