Commit: Update main.py
Giuseppe5 committed Dec 17, 2024
1 parent 4d042b7 commit 7b68194
Showing 1 changed file with 8 additions and 4 deletions: src/brevitas_examples/llm/main.py
@@ -383,19 +383,23 @@ def main(args):

     dict_hooks = dict()

-    def update_params_post_init(module):
-        update_internal_dict(module)
-
     # When offloading to CPU + GPU, the CPU scale factors must be updated
     # before we move them back to the meta device.
     # If we don't, we lose the new value but the internal flag "init_done" is True, thus we will use the wrong scale.
     # To do this, we attach a "hook" to the post_forward function, called before the post_forward
     # The function will update the dict with the initialized scales
     for m in model.modules():
         if hasattr(m, '_hf_hook'):
             if m._hf_hook.weights_map is not None:
                 # We store the original function to be restored later
                 dict_hooks[m] = m._hf_hook.post_forward
-                new_funct = functools.partial(update_params_post_init, m)
+                new_funct = functools.partial(update_internal_dict, m)
                 m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct)

     with torch.no_grad():
         model(**calibration_loader[0])

     # We restore the original behaviour of the post-forward.
     for k, v in dict_hooks.items():
         k._hf_hook.post_forward = v
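
The change above hinges on two helpers defined elsewhere in main.py: update_internal_dict, which records a module's freshly initialized scale factors, and hooked_on_a_function, which chains an extra callable onto an existing function. Their implementations are not part of this diff; the sketch below is only an illustration of how such a wrapper could behave, assuming the extra callable takes no arguments and must run before the original accelerate post_forward moves the offloaded weights back to the meta device.

    # Illustrative sketch only: the real hooked_on_a_function in main.py may differ.
    import functools

    def hooked_on_a_function(original_post_forward, extra_function):
        # Run the extra callback first, while the offloaded weights are still
        # materialized, then fall through to accelerate's original post_forward.
        @functools.wraps(original_post_forward)
        def wrapper(*args, **kwargs):
            extra_function()  # e.g. functools.partial(update_internal_dict, m)
            return original_post_forward(*args, **kwargs)
        return wrapper

With a wrapper of this shape, m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct) records the scales of module m during the single calibration forward, and restoring the saved originals afterwards keeps later forwards free of the extra bookkeeping.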

