From 925c3a553c2df6d4504174eeb79da4b87fe461cf Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Wed, 15 Jan 2025 18:02:57 +0000
Subject: [PATCH] fix

---
 src/brevitas/graph/base.py        |  6 +-----
 src/brevitas_examples/llm/main.py | 11 ++---------
 2 files changed, 3 insertions(+), 14 deletions(-)

diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py
index f9307e46e..13cd19d2e 100644
--- a/src/brevitas/graph/base.py
+++ b/src/brevitas/graph/base.py
@@ -7,7 +7,6 @@
 from inspect import getcallargs
 from typing import Any, Callable, Dict, Optional, Type, Union
 
-from brevitas.nn import ScaledDotProductAttention
 import torch
 from torch import Tensor
 from torch.nn import Module
@@ -19,6 +18,7 @@
 from brevitas.fx import immutable_dict
 from brevitas.fx import Node
 from brevitas.graph.utils import *
+from brevitas.nn import ScaledDotProductAttention
 from brevitas.utils.python_utils import islambda
 from brevitas.utils.rotation_utils import RotationWeightParametrization
 
@@ -122,8 +122,6 @@ def _map_origin_vars(self, vars: dict):
 
     def _module_attributes(self, module):
         attrs = vars(module)
-        if isinstance(module, ScaledDotProductAttention):
-            print(attrs)
         # workaround since bias doesn't show up on vars of Linear
         if hasattr(module, 'bias'):
             attrs['bias'] = module.bias
@@ -151,8 +149,6 @@ def _init_new_module(self, old_module: Module, name=None):
         new_kwargs = self._module_attributes(old_module)
         # transforms attribute of original module, e.g. bias Parameter -> bool
         new_kwargs = self._map_origin_vars(new_kwargs)
-        if isinstance(old_module, ScaledDotProductAttention):
-            print(new_kwargs)
         # restrict to only values that are in the init of the new module
         new_module_signature_keys = signature_keys(self.new_module_class)
         new_kwargs = {k: v for k, v in new_kwargs.items() if k in new_module_signature_keys}
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 1ebcc9180..5e7e3db93 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -82,17 +82,12 @@ def set_seed(seed):
 
 
 def fused_rotation_no_fx(model, calibration_loader, args):
-    print("Here")
     with torch.no_grad():
         new_model, guards = torch._dynamo.export(model)(**calibration_loader[0])
-    print(getattr(model, str(torch.nn.functional.scaled_dot_product_attention)))
     if hasattr(model, str(torch.nn.functional.scaled_dot_product_attention)):
         m_to_add = getattr(model, str(torch.nn.functional.scaled_dot_product_attention))
         new_model.add_module(str(torch.nn.functional.scaled_dot_product_attention), m_to_add)
-    # for m in new_model.modules():
-    #     print(type(m))
-    #     if hasattr(m, 'pre_process_q'):
-    #         raise
+
     apply_layernorm_affine_merge(new_model)
     new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True)
     rewriters = fix_rewriter(rewriters, model, 'weight')
@@ -312,8 +307,6 @@ def quantize_llm(args):
         apply_layernorm_to_rmsnorm(model)
         print("Layernorm To RMSNorm applied.")
-
-
     # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing
     # with all the variability in HF implementations
     if args.replace_mha:
@@ -533,7 +526,7 @@ def quantize_llm(args):
         print(f"Saving checkpoint to {args.checkpoint_name}")
         torch.save(model.state_dict(), args.checkpoint_name)
 
-    if args.eval:# and not args.no_quantize:
+    if args.eval and not args.no_quantize:
         print("Model eval...")
         with torch.no_grad(), quant_inference_mode(model):
             model(**calibration_loader[0])