fix: Enable non-strict loading of state dicts
Resolves #278

PyTorch allows loading state dicts with the strict=False argument to
ignore missing keys. This is now also supported in optimum-quanto.
Before this fix, a KeyError would be raised.

One context where this matters is parameter-efficient fine-tuning
adapters such as LoRA. There, we want to load only a small subset of
parameters and leave the other model weights untouched, which requires
non-strict loading.
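
To make the new behavior concrete, here is a minimal sketch using a plain torch.nn.Linear as a stand-in for a quantized model (the module and the deleted key are placeholders; only the strict=False semantics come from PyTorch):

    import torch

    model = torch.nn.Linear(4, 4)

    # Simulate a partial checkpoint, e.g. one that only carries adapter weights.
    sd = model.state_dict()
    del sd["bias"]

    # Strict loading raises a RuntimeError naming the missing key.
    try:
        model.load_state_dict(sd)
    except RuntimeError as err:
        print(err)  # Missing key(s) in state_dict: "bias"

    # Non-strict loading succeeds and reports what was skipped.
    result = model.load_state_dict(sd, strict=False)
    assert result.missing_keys == ["bias"]
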
BenjaminBossan authored and dacorvo committed Aug 27, 2024
1 parent f3b39ce commit f9b71f4
Showing 5 changed files with 56 additions and 9 deletions.
optimum/quanto/nn/qmodule.py: 7 additions, 3 deletions
@@ -157,6 +157,7 @@ def _load_from_state_dict(
         if self.weight_qtype is not None and weight_name not in state_dict:
             # The weight Tensor is not present because it is a flattened QTensor
             weight_prefix = weight_name + "."
+            # note: deserialized_weight can be None if a key is missing in the state_dict
             if self.weight_qtype.bits == 8:
                 deserialized_weight = WeightQBytesTensor.load_from_state_dict(
                     state_dict,
@@ -165,6 +166,7 @@ def _load_from_state_dict(
                     axis=0,
                     size=self.weight.size(),
                     stride=self.weight.stride(),
+                    missing_keys=missing_keys,
                 )
             else:
                 deserialized_weight = QBitsTensor.load_from_state_dict(
@@ -175,13 +177,15 @@ def _load_from_state_dict(
                     group_size=self.weight_group_size,
                     size=self.weight.size(),
                     stride=self.weight.stride(),
+                    missing_keys=missing_keys,
                 )
-            deserialized_weight = deserialized_weight.optimize()
+            if deserialized_weight is not None:
+                deserialized_weight = deserialized_weight.optimize()
 
         assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
-        if assign_to_params_buffers:
+        if assign_to_params_buffers and (deserialized_weight is not None):
             self.weight = torch.nn.Parameter(deserialized_weight)
-        else:
+        elif deserialized_weight is not None:
             if type(self.weight.data) is not type(deserialized_weight):
                 # Reloading frozen weights into unfrozen module: move to the correct device and force assignment
                 self.weight = torch.nn.Parameter(deserialized_weight.to(self.weight.device))
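
For context on the change above: PyTorch passes a missing_keys list into Module._load_from_state_dict, and load_state_dict later either raises a RuntimeError (strict=True) or returns the collected names (strict=False). This commit forwards that list into the quanto tensor loaders so they can record absent keys instead of raising a KeyError. A toy illustration of the contract, not optimum-quanto code:

    import torch

    class ScaleOnly(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("scale", torch.ones(1))

        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                                  missing_keys, unexpected_keys, error_msgs):
            key = prefix + "scale"
            if key not in state_dict:
                # Recording the key here lets load_state_dict() decide whether to
                # raise (strict=True) or merely report it (strict=False).
                missing_keys.append(key)
                return
            self.scale = state_dict[key]

    module = ScaleOnly()
    result = module.load_state_dict({}, strict=False)
    assert result.missing_keys == ["scale"]
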
optimum/quanto/tensor/qbits/packed.py: 5 additions, 1 deletion
@@ -111,7 +111,11 @@ def dtype(self):
         return torch.uint8
 
     @staticmethod
-    def load_from_state_dict(state_dict, prefix, bits, size, stride):
+    def load_from_state_dict(state_dict, prefix, bits, size, stride, missing_keys):
+        if prefix + "_data" not in state_dict:
+            missing_keys.append(prefix + "_data")
+            return
+
         inner_tensors_dict = {"_data": state_dict.pop(prefix + "_data")}
         meta = [name.replace(prefix, "") for name in state_dict.keys() if name.startswith(prefix)]
         meta = {"bits": str(bits), "size": str(list(size)), "stride": str(stride)}
optimum/quanto/tensor/qbits/qbits.py: 12 additions, 3 deletions
@@ -165,7 +165,7 @@ def dequantize(self):
         return QBitsDequantizer.apply(self)
 
     @staticmethod
-    def load_from_state_dict(state_dict, prefix, qtype, axis, group_size, size, stride):
+    def load_from_state_dict(state_dict, prefix, qtype, axis, group_size, size, stride, missing_keys):
         if group_size is None:
             data_size = size
             data_stride = stride
@@ -176,11 +176,20 @@ def load_from_state_dict(state_dict, prefix, qtype, axis, group_size, size, stride):
             data_stride = (data_size[1], 1)
         inner_tensors_dict = {
             "_data": PackedTensor.load_from_state_dict(
-                state_dict, prefix + "_data.", qtype.bits, data_size, data_stride
+                state_dict, prefix + "_data.", qtype.bits, data_size, data_stride, missing_keys=missing_keys
             )
         }
+        missing = inner_tensors_dict["_data"] is None
         for name in ["_scale", "_shift"]:
-            inner_tensors_dict[name] = state_dict.pop(prefix + name)
+            if prefix + name not in state_dict:
+                missing_keys.append(prefix + name)
+                missing = True
+            else:
+                inner_tensors_dict[name] = state_dict.pop(prefix + name)
+
+        if missing:  # could not deserialize because of missing keys
+            return None
+
         meta = {
             "qtype": qtype.name,
             "axis": str(axis),
optimum/quanto/tensor/weights/qbytes.py: 11 additions, 2 deletions
@@ -71,10 +71,19 @@ def quantize(cls, base: torch.Tensor, qtype: qtype, axis: int, scale: torch.Tensor
         return WeightQBytesQuantizer.apply(base, qtype, axis, scale)
 
     @staticmethod
-    def load_from_state_dict(state_dict, prefix, qtype, axis, size, stride):
+    def load_from_state_dict(state_dict, prefix, qtype, axis, size, stride, missing_keys):
         inner_tensors_dict = {}
+        missing = False
         for name in ["_data", "_scale"]:
-            inner_tensors_dict[name] = state_dict.pop(prefix + name)
+            if prefix + name not in state_dict:
+                missing_keys.append(prefix + name)
+                missing = True
+            else:
+                inner_tensors_dict[name] = state_dict.pop(prefix + name)
+
+        if missing:  # could not deserialize because of missing keys
+            return None
+
         meta = {
             "qtype": qtype.name,
             "axis": str(axis),
test/models/test_quantized_model_for_causal_lm.py: 21 additions, 0 deletions
@@ -126,3 +126,24 @@ def test_causal_lm_base_push_to_hub(staging, in_org):
     compare_models(quantized, requantized)
 
     delete_repo(hub_repo_id, token=staging["token"])
+
+
+@pytest.mark.skipif(not is_transformers_available(), reason="requires transformers")
+@pytest.mark.parametrize("model_id", ["facebook/opt-125m"])
+@pytest.mark.parametrize("qtype", [qint4, qint8], ids=["qint4", "qint8"])
+def test_quantized_model_load_state_dict_non_strict(model_id, qtype):
+    # see issue #278
+    quantized = quantized_model_for_causal_lm(model_id, qtype, exclude=None)
+    sd = quantized.state_dict()
+
+    # delete a key used by both qint4 and qint8 from the state dict
+    key = "model.decoder.layers.0.self_attn.k_proj.weight._scale"
+    del sd[key]
+
+    # strict loading should raise a RuntimeError, which is what PyTorch does in this case
+    with pytest.raises(RuntimeError, match=key):
+        quantized.load_state_dict(sd)
+
+    # non-strict loading should not raise an error
+    result = quantized.load_state_dict(sd, strict=False)
+    assert result.missing_keys == [key]
