Add skip_missing_weights_context #1017

Draft: wants to merge 2 commits into base: main
@@ -1,3 +1,5 @@
import contextlib
import logging
import os
import re
import weakref
@@ -148,15 +150,6 @@ def save_pretrained_wrapper(
            # https://github.com/huggingface/transformers/pull/30488
            transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size

            def skip(*args, **kwargs):
                pass

            # Skip the initializer step. This accelerates the loading
            # of the models, especially for the quantized models
            torch.nn.init.kaiming_uniform_ = skip
            torch.nn.init.uniform_ = skip
            torch.nn.init.normal_ = skip

            # state_dict gets passed in as a kwarg for FSDP models
            state_dict = kwargs.pop("state_dict", None)
            if state_dict is None:
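Note on the removal above: previously the global `torch.nn.init` functions were replaced inside `save_pretrained_wrapper` and never restored, so any module constructed later in the same process silently skipped weight initialization. A minimal sketch of that leak (illustrative only, not part of this diff):

```python
import torch


def skip(*args, **kwargs):
    pass


# emulate the old behavior: patch a global initializer and never restore it
original_kaiming = torch.nn.init.kaiming_uniform_
torch.nn.init.kaiming_uniform_ = skip

# any Linear created afterwards keeps the raw torch.empty() memory in its
# weight, because nn.Linear.reset_parameters() calls init.kaiming_uniform_
layer = torch.nn.Linear(4, 4)

# the new skip_missing_weights_context instead restores the original on exit
torch.nn.init.kaiming_uniform_ = original_kaiming
```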
@@ -305,3 +298,40 @@ def get_model_compressor(
        sparsity_config=sparsity_config,
        quantization_format=quantization_format,
    )


@contextlib.contextmanager
def skip_missing_weights_context():
    """
    Used when loading a quantized model whose state dict does not align with
    the model definition's weights
    """
    kaiming_uniform_ = torch.nn.init.kaiming_uniform_
    uniform_ = torch.nn.init.uniform_
    normal_ = torch.nn.init.normal_

    transformers_logger = logging.getLogger("transformers.modeling_utils")
    restore_log_level = transformers_logger.getEffectiveLevel()

    # skip init functions
    def skip(*args, **kwargs):
        pass

    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip
    # TODO: consider skipping other default init functions

    # temporarily set the log level to error, to ignore printing out long missing
    # and unexpected key error messages (these are EXPECTED for quantized models)
    transformers_logger.setLevel(level=logging.ERROR)

    yield

    # restore original functions
    torch.nn.init.kaiming_uniform_ = kaiming_uniform_
    torch.nn.init.uniform_ = uniform_
    torch.nn.init.normal_ = normal_

    # restore transformers logging level now that model shell is loaded
    transformers_logger.setLevel(level=restore_log_level)
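One possible hardening of the helper above, sketched as a suggestion rather than as part of this diff: if the `with` body raises (for example, `from_pretrained` failing on a corrupt checkpoint), the no-op initializers and the suppressed log level currently stay in place; wrapping the restore step in `try`/`finally` avoids that.

```python
import contextlib
import logging

import torch


@contextlib.contextmanager
def skip_missing_weights_context():
    """
    Exception-safe variant (sketch): skip weight init and silence missing /
    unexpected key warnings while loading a quantized model whose state dict
    does not align with the model definition's weights.
    """
    # save what will be patched so it can always be restored
    kaiming_uniform_ = torch.nn.init.kaiming_uniform_
    uniform_ = torch.nn.init.uniform_
    normal_ = torch.nn.init.normal_
    transformers_logger = logging.getLogger("transformers.modeling_utils")
    restore_log_level = transformers_logger.getEffectiveLevel()

    def skip(*args, **kwargs):
        pass

    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip
    transformers_logger.setLevel(level=logging.ERROR)

    try:
        yield
    finally:
        # restore even if the body raised, so later models initialize normally
        torch.nn.init.kaiming_uniform_ = kaiming_uniform_
        torch.nn.init.uniform_ = uniform_
        torch.nn.init.normal_ = normal_
        transformers_logger.setLevel(level=restore_log_level)
```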
@@ -1,4 +1,3 @@
import logging
import math
import shutil

@@ -22,6 +21,7 @@
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    modify_save_pretrained,
    patch_tied_tensors_bug,
    skip_missing_weights_context,
)


@@ -63,18 +63,10 @@ def test_sparse_model_reload(compressed, config, dtype, tmp_path):
        clear_sparse_session=False,
    )

    # temporarily set the log level to error, to ignore printing out long missing
    # and unexpected key error messages (these are EXPECTED for quantized models)
    transformers_logger = logging.getLogger("transformers.modeling_utils")
    restore_log_level = transformers_logger.getEffectiveLevel()
    transformers_logger.setLevel(level=logging.ERROR)

    model = AutoModelForCausalLM.from_pretrained(
        tmp_path / "oneshot_out", torch_dtype=dtype
    )

    # restore transformers logging level now that model shell is loaded
    transformers_logger.setLevel(level=restore_log_level)
    with skip_missing_weights_context():
        model = AutoModelForCausalLM.from_pretrained(
            tmp_path / "oneshot_out", torch_dtype=dtype
        )

    # assert that sample layer has the intended sparsity
    assert math.isclose(
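As a usage illustration (hypothetical test, not part of this PR), one could also assert that the patched initializers are restored once the context exits:

```python
import torch

from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    skip_missing_weights_context,
)


def test_skip_missing_weights_context_restores_init():
    original_kaiming = torch.nn.init.kaiming_uniform_

    with skip_missing_weights_context():
        # inside the context the initializer is replaced by a no-op
        assert torch.nn.init.kaiming_uniform_ is not original_kaiming

    # after exiting, the original initializer is back in place
    assert torch.nn.init.kaiming_uniform_ is original_kaiming
```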