From aabdb0b3ab331ad8441428d02688a52f959ccb81 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Sat, 28 Dec 2024 13:16:06 -0500
Subject: [PATCH 1/3] add skip_missing_weights_context

Signed-off-by: Kyle Sayers
---
 .../compressed_tensors_utils.py              | 44 +++++++++++++++----
 .../test_compress_tensor_utils.py            | 18 +++-----
 2 files changed, 40 insertions(+), 22 deletions(-)

diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
index eba5c5882..f4c14f07f 100644
--- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
+++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
@@ -1,3 +1,5 @@
+import contextlib
+import logging
 import os
 import re
 import weakref
@@ -148,15 +150,6 @@ def save_pretrained_wrapper(
         # https://github.com/huggingface/transformers/pull/30488
         transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
 
-        def skip(*args, **kwargs):
-            pass
-
-        # Skip the initializer step. This accelerates the loading
-        # of the models, especially for the quantized models
-        torch.nn.init.kaiming_uniform_ = skip
-        torch.nn.init.uniform_ = skip
-        torch.nn.init.normal_ = skip
-
         # state_dict gets passed in as a kwarg for FSDP models
         state_dict = kwargs.pop("state_dict", None)
         if state_dict is None:
@@ -305,3 +298,36 @@ def get_model_compressor(
         sparsity_config=sparsity_config,
         quantization_format=quantization_format,
     )
+
+
+@contextlib.contextmanager
+def skip_missing_weights_context():
+    kaiming_uniform_ = torch.nn.init.kaiming_uniform_
+    uniform_ = torch.nn.init.uniform_
+    normal_ = torch.nn.init.normal_
+
+    transformers_logger = logging.getLogger("transformers.modeling_utils")
+    restore_log_level = transformers_logger.getEffectiveLevel()
+
+    # skip init functions
+    def skip(*args, **kwargs):
+        pass
+
+    torch.nn.init.kaiming_uniform_ = skip
+    torch.nn.init.uniform_ = skip
+    torch.nn.init.normal_ = skip
+    # TODO: consider skipping other default init functions
+
+    # temporarily set the log level to error, to ignore printing out long missing
+    # and unexpected key error messages (these are EXPECTED for quantized models)
+    transformers_logger.setLevel(level=logging.ERROR)
+
+    yield
+
+    # restore original functions
+    torch.nn.init.kaiming_uniform_ = kaiming_uniform_
+    torch.nn.init.uniform_ = uniform_
+    torch.nn.init.normal_ = normal_
+
+    # restore transformers logging level now that model shell is loaded
+    transformers_logger.setLevel(level=restore_log_level)
diff --git a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py
index df9726647..dfea8544b 100644
--- a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py
+++ b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py
@@ -1,4 +1,3 @@
-import logging
 import math
 import shutil
 
@@ -22,6 +21,7 @@
 from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
     modify_save_pretrained,
     patch_tied_tensors_bug,
+    skip_missing_weights_context,
 )
 
 
@@ -63,18 +63,10 @@ def test_sparse_model_reload(compressed, config, dtype, tmp_path):
         clear_sparse_session=False,
     )
 
-    # temporarily set the log level to error, to ignore printing out long missing
-    # and unexpected key error messages (these are EXPECTED for quantized models)
-    transformers_logger = logging.getLogger("transformers.modeling_utils")
-    restore_log_level = transformers_logger.getEffectiveLevel()
-    transformers_logger.setLevel(level=logging.ERROR)
-
-    model = AutoModelForCausalLM.from_pretrained(
-        tmp_path / "oneshot_out", torch_dtype=dtype
-    )
-
-    # restore transformers logging level now that model shell is loaded
-    transformers_logger.setLevel(level=restore_log_level)
+    with skip_missing_weights_context():
+        model = AutoModelForCausalLM.from_pretrained(
+            tmp_path / "oneshot_out", torch_dtype=dtype
+        )
 
     # assert that sample layer has the intended sparsity
     assert math.isclose(

From d83b00da7cba63a7d79e75add25b86f040949375 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Sat, 28 Dec 2024 13:18:02 -0500
Subject: [PATCH 2/3] add docstring

Signed-off-by: Kyle Sayers
---
 .../transformers/sparsification/compressed_tensors_utils.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
index f4c14f07f..026250dcd 100644
--- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
+++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
@@ -302,6 +302,10 @@ def get_model_compressor(
 
 @contextlib.contextmanager
 def skip_missing_weights_context():
+    """
+    Used when loading a quantized model whose state dict does not align with
+    the weights in the model definition
+    """
     kaiming_uniform_ = torch.nn.init.kaiming_uniform_
     uniform_ = torch.nn.init.uniform_
     normal_ = torch.nn.init.normal_

From d30e8e725eefb7d68c541e4fbf68d9ecc1c8778a Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Fri, 31 Jan 2025 16:09:38 +0000
Subject: [PATCH 3/3] remove init context in favor of future hfquantizer
 solution

Signed-off-by: Kyle Sayers
---
 .../compressed_tensors_utils.py              | 39 -------------------
 .../test_compress_tensor_utils.py            |  8 ++--
 2 files changed, 3 insertions(+), 44 deletions(-)

diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
index 6541b9ce9..30a85f28c 100644
--- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
+++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
@@ -1,5 +1,3 @@
-import contextlib
-import logging
 import os
 import re
 import weakref
@@ -313,40 +311,3 @@ def get_model_compressor(
         sparsity_config=sparsity_config,
         quantization_format=quantization_format,
     )
-
-
-@contextlib.contextmanager
-def skip_missing_weights_context():
-    """
-    Used when loading a quantized model whose state dict does not align with
-    the weights in the model definition
-    """
-    kaiming_uniform_ = torch.nn.init.kaiming_uniform_
-    uniform_ = torch.nn.init.uniform_
-    normal_ = torch.nn.init.normal_
-
-    transformers_logger = logging.getLogger("transformers.modeling_utils")
-    restore_log_level = transformers_logger.getEffectiveLevel()
-
-    # skip init functions
-    def skip(*args, **kwargs):
-        pass
-
-    torch.nn.init.kaiming_uniform_ = skip
-    torch.nn.init.uniform_ = skip
-    torch.nn.init.normal_ = skip
-    # TODO: consider skipping other default init functions
-
-    # temporarily set the log level to error, to ignore printing out long missing
-    # and unexpected key error messages (these are EXPECTED for quantized models)
-    transformers_logger.setLevel(level=logging.ERROR)
-
-    yield
-
-    # restore original functions
-    torch.nn.init.kaiming_uniform_ = kaiming_uniform_
-    torch.nn.init.uniform_ = uniform_
-    torch.nn.init.normal_ = normal_
-
-    # restore transformers logging level now that model shell is loaded
-    transformers_logger.setLevel(level=restore_log_level)
diff --git a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py
index 75db65fc6..e6c3fe319 100644
--- a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py
+++ b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py
@@ -28,7 +28,6 @@
     get_model_compressor,
     modify_save_pretrained,
     patch_tied_tensors_bug,
-    skip_missing_weights_context,
 )
 
 
@@ -70,10 +69,9 @@ def test_sparse_model_reload(compressed, config, dtype, tmp_path):
         clear_sparse_session=False,
     )
 
-    with skip_missing_weights_context():
-        model = AutoModelForCausalLM.from_pretrained(
-            tmp_path / "oneshot_out", torch_dtype=dtype
-        )
+    model = AutoModelForCausalLM.from_pretrained(
+        tmp_path / "oneshot_out", torch_dtype=dtype
+    )
 
     # assert that sample layer has the intended sparsity
     assert math.isclose(
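
Note: with skip_missing_weights_context removed in 3/3 (pending the future
hfquantizer-based loading path named in the commit message), a caller that
still wants to silence the long missing/unexpected-key messages while loading
a quantized checkpoint can scope the transformers log level itself. Below is
a minimal sketch of the same pattern the removed context manager wrapped, not
part of the patches; MODEL_PATH is a placeholder:

    import logging

    from transformers import AutoModelForCausalLM

    # placeholder path to a compressed/quantized checkpoint
    MODEL_PATH = "path/to/quantized-model"

    # temporarily raise the log level so the long missing/unexpected key
    # messages (expected for quantized models) are not printed during load
    transformers_logger = logging.getLogger("transformers.modeling_utils")
    restore_log_level = transformers_logger.getEffectiveLevel()
    transformers_logger.setLevel(logging.ERROR)
    try:
        model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
    finally:
        # restore the original log level even if loading raises
        transformers_logger.setLevel(restore_log_level)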