Ensure that tied parameters are children of their module (#3327)
Ensure that tied parameters are assigned to their parent module in
get_module_size_with_ties

Fixes: #3308
pablomlago authored Jan 9, 2025
1 parent 54370d4 commit 58f1436
Showing 2 changed files with 46 additions and 1 deletion.
src/accelerate/utils/modeling.py: 2 changes (1 addition, 1 deletion)
@@ -1183,7 +1183,7 @@ def get_module_size_with_ties(
     tied_modules = []
 
     for tied_param in tied_params:
-        tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0]
+        tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if tied_param.startswith(n + ".")][0]
         tied_module_names.append(modules_to_treat[tied_module_index][0])
         tied_modules.append(modules_to_treat[tied_module_index][1])
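To make the change concrete, here is a minimal standalone sketch of the failure mode reported in #3308 (the module names and tied parameter below are illustrative, not taken from the library): with the old substring check, a tied parameter named "10.weight" is matched to module "1", because "1" is a substring of "10.weight" and module "1" appears first in modules_to_treat; the prefix check only accepts the real parent module.

# Illustration of the lookup this commit fixes; names and values are made up for the example.
modules_to_treat = [(str(i), None) for i in range(1, 15)]  # module names "1" .. "14"
tied_param = "10.weight"

# Old check (substring): "1" is contained in "10.weight" and comes first, so the wrong module wins.
old_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0]

# New check (prefix): only "10." prefixes "10.weight", so the actual parent module is found.
new_index = [i for i, (n, _) in enumerate(modules_to_treat) if tied_param.startswith(n + ".")][0]

print(modules_to_treat[old_index][0])  # "1"  -> wrong parent
print(modules_to_treat[new_index][0])  # "10" -> correct parent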

tests/test_modeling_utils.py: 45 changes (45 additions, 0 deletions)
@@ -43,6 +43,7 @@
     convert_file_size_to_int,
     find_tied_parameters,
     get_balanced_memory,
+    get_module_size_with_ties,
     get_state_dict_offloaded_model,
     infer_auto_device_map,
     load_checkpoint_in_model,
@@ -882,6 +883,50 @@ def test_get_balanced_memory(self):
         max_memory = get_balanced_memory(model, max_memory={0: 0, "cpu": 100})
         assert {0: 0, "cpu": 100} == max_memory
 
+    # Tests that get_module_size_with_ties returns the correct tied modules in
+    # models with tied parameters whose parent modules share the same name prefix
+    # See issue #3308: https://github.com/huggingface/accelerate/issues/3308
+    def test_get_module_size_with_ties(self):
+        # Create a model with a ModuleList containing more than 10 elements
+        # so the names of some layers share the same prefix, e.g. "1" and "10"
+        num_layers = 15
+        model = nn.ModuleList([nn.Linear(10, 10) for _ in range(num_layers)])
+        # Tie .weight for all the layers
+        for i in range(1, num_layers):
+            model[i].weight = model[i - 1].weight
+        # Each tied parameter group is sorted alphabetically,
+        # mimicking the output of find_tied_parameters
+        tied_parameters = [sorted([f"{i}.weight" for i in range(num_layers)])]
+        # Compute module sizes
+        weight_size, bias_size = (
+            model[0].weight.element_size() * model[0].weight.numel(),
+            model[0].bias.element_size() * model[0].bias.numel(),
+        )
+        module_sizes = dict(
+            **{"": num_layers * (weight_size + bias_size)},
+            **{f"{i}": (weight_size + bias_size) for i in range(num_layers)},
+            **{f"{i}.weight": weight_size for i in range(num_layers)},
+            **{f"{i}.bias": bias_size for i in range(num_layers)},
+        )
+        # Simulate the input for get_module_size_with_ties when invoked from infer_auto_device_map
+        # while the first module in the model is being processed
+        modules_to_treat = list(model.named_children())[1:]
+        tied_params = tied_parameters[0][1:]
+        module_size = weight_size + bias_size
+
+        module_size_with_ties, tied_module_names, tied_modules = get_module_size_with_ties(
+            tied_params, module_size, module_sizes, modules_to_treat
+        )
+        # The expected lists are ordered by module name, following the same
+        # order as the tied_parameters returned by find_tied_parameters
+        expected_tied_module_names, expected_tied_modules = map(
+            list, zip(*sorted(modules_to_treat, key=lambda x: x[0]))
+        )
+
+        assert module_size_with_ties == module_size + (num_layers - 1) * bias_size
+        assert tied_module_names == expected_tied_module_names
+        assert tied_modules == expected_tied_modules

     @require_non_cpu
     def test_load_state_dict(self):
         state_dict = {k: torch.randn(4, 5) for k in ["a", "b", "c"]}
