
Commit

fix
Giuseppe5 committed Jan 15, 2025
1 parent a21b771 commit 925c3a5
Showing 2 changed files with 3 additions and 14 deletions.
6 changes: 1 addition & 5 deletions src/brevitas/graph/base.py
@@ -7,7 +7,6 @@
 from inspect import getcallargs
 from typing import Any, Callable, Dict, Optional, Type, Union

-from brevitas.nn import ScaledDotProductAttention
 import torch
 from torch import Tensor
 from torch.nn import Module
@@ -19,6 +18,7 @@
 from brevitas.fx import immutable_dict
 from brevitas.fx import Node
 from brevitas.graph.utils import *
+from brevitas.nn import ScaledDotProductAttention
 from brevitas.utils.python_utils import islambda
 from brevitas.utils.rotation_utils import RotationWeightParametrization

@@ -122,8 +122,6 @@ def _map_origin_vars(self, vars: dict):

     def _module_attributes(self, module):
         attrs = vars(module)
-        if isinstance(module, ScaledDotProductAttention):
-            print(attrs)

         # workaround since bias doesn't show up on vars of Linear
         if hasattr(module, 'bias'):
@@ -151,8 +149,6 @@ def _init_new_module(self, old_module: Module, name=None):
         new_kwargs = self._module_attributes(old_module)
         # transforms attribute of original module, e.g. bias Parameter -> bool
         new_kwargs = self._map_origin_vars(new_kwargs)
-        if isinstance(old_module, ScaledDotProductAttention):
-            print(new_kwargs)
         # restrict to only values that are in the init of the new module
         new_module_signature_keys = signature_keys(self.new_module_class)
         new_kwargs = {k: v for k, v in new_kwargs.items() if k in new_module_signature_keys}
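For context on the base.py hunks: `_module_attributes` collects an existing module's attributes so that `_init_new_module` can rebuild it as an instance of `self.new_module_class`, keeping only the keyword arguments that the new class's `__init__` accepts. Below is a minimal, self-contained sketch of that collect-filter-instantiate pattern; the helper name `swap_module` and the use of `inspect.signature` are illustrative assumptions, not the Brevitas implementation.

import inspect

import torch.nn as nn


def swap_module(old_module: nn.Module, new_cls: type) -> nn.Module:
    # Gather the old module's plain attributes (e.g. in_features/out_features on Linear).
    kwargs = dict(vars(old_module))
    # Workaround mirroring the comment in the diff: bias is a registered Parameter,
    # so it does not show up in vars(); expose it as the bool most __init__s expect.
    if hasattr(old_module, 'bias'):
        kwargs['bias'] = old_module.bias is not None
    # Restrict to arguments actually accepted by the new class's __init__.
    accepted = set(inspect.signature(new_cls.__init__).parameters)
    kwargs = {k: v for k, v in kwargs.items() if k in accepted}
    return new_cls(**kwargs)


# Example: rebuild a Linear layer as another Linear with the same shape and bias setting.
print(swap_module(nn.Linear(8, 4, bias=False), nn.Linear))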
11 changes: 2 additions & 9 deletions src/brevitas_examples/llm/main.py
@@ -82,17 +82,12 @@ def set_seed(seed):


 def fused_rotation_no_fx(model, calibration_loader, args):
-    print("Here")
     with torch.no_grad():
         new_model, guards = torch._dynamo.export(model)(**calibration_loader[0])
-    print(getattr(model, str(torch.nn.functional.scaled_dot_product_attention)))
     if hasattr(model, str(torch.nn.functional.scaled_dot_product_attention)):
         m_to_add = getattr(model, str(torch.nn.functional.scaled_dot_product_attention))
         new_model.add_module(str(torch.nn.functional.scaled_dot_product_attention), m_to_add)
-    # for m in new_model.modules():
-    #     print(type(m))
-    #     if hasattr(m, 'pre_process_q'):
-    #         raise
+
     apply_layernorm_affine_merge(new_model)
     new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True)
     rewriters = fix_rewriter(rewriters, model, 'weight')
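The surviving logic in `fused_rotation_no_fx` copies the submodule registered under `str(torch.nn.functional.scaled_dot_product_attention)` onto the model returned by `torch._dynamo.export`, presumably because the export does not preserve it. A hedged generalisation of that idea, as a sketch only; the helper `copy_missing_submodules` is hypothetical and not part of the repository:

import torch.nn as nn


def copy_missing_submodules(src: nn.Module, dst: nn.Module) -> None:
    # Re-attach any direct child of the source model that the exported/rewritten
    # model lost, so later rewriters can still find it by name.
    for name, child in src.named_children():
        if not hasattr(dst, name):
            dst.add_module(name, child)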
@@ -312,8 +307,6 @@ def quantize_llm(args):
         apply_layernorm_to_rmsnorm(model)
         print("Layernorm To RMSNorm applied.")

-
-
     # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing
     # with all the variability in HF implementations
     if args.replace_mha:
@@ -533,7 +526,7 @@ def quantize_llm(args):
         print(f"Saving checkpoint to {args.checkpoint_name}")
         torch.save(model.state_dict(), args.checkpoint_name)

-    if args.eval:# and not args.no_quantize:
+    if args.eval and not args.no_quantize:
         print("Model eval...")
         with torch.no_grad(), quant_inference_mode(model):
             model(**calibration_loader[0])
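The last hunk is the functional part of the commit: in the old line the stray `#` turned the second half of the condition into a comment, so the guard degenerated to `if args.eval:` and evaluation ran even when quantization was disabled. A minimal demonstration of the difference, using a `SimpleNamespace` stand-in for the parsed CLI arguments (flag spellings assumed, illustrative only):

from types import SimpleNamespace

# Pretend both eval and no_quantize were requested on the command line.
args = SimpleNamespace(eval=True, no_quantize=True)

old_condition = args.eval  # '# and not args.no_quantize' was dead text after the colon
new_condition = args.eval and not args.no_quantize

print(old_condition)  # True  -> eval would have run on the unquantized model
print(new_condition)  # False -> eval is now skipped, as intended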
