Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
pablomlago committed Jan 16, 2025
1 parent 406c807 commit 15cb1bb
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions src/brevitas/graph/quantize_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def find_module(
Specifically, it allows to map nn.MultiheadAttetion to its quantized counterpart and not its
Linear submodules.
"""
if _module_class_name(parametrize.type_before_parametrizations(model)) in layer_map.keys():
if _module_class_name(type_before_parametrizations(model)) in layer_map.keys():
module_to_replace.append(model)
else:
for name, module in model.named_children():
Expand All @@ -537,9 +537,8 @@ def layerwise_layer_handler(
find_module(model, layer_map, module_to_replace, name_blacklist)
rewriters = []
for module in module_to_replace:
if layer_map[_module_class_name(
parametrize.type_before_parametrizations(module))] is not None:
quant_module_class, quant_module_kwargs = layer_map[_module_class_name(parametrize.type_before_parametrizations(module))]
if layer_map[_module_class_name(type_before_parametrizations(module))] is not None:
quant_module_class, quant_module_kwargs = layer_map[_module_class_name(type_before_parametrizations(module))]
rewriter = ModuleToModuleByInstance(module, quant_module_class, **quant_module_kwargs)
rewriters.append(rewriter)
for rewriter in rewriters:
Expand Down

0 comments on commit 15cb1bb

Please sign in to comment.