diff --git a/examples/trl_mixin/ex_trl_distillation.py b/examples/trl_mixin/ex_trl_distillation.py index ff3ddf000..d1e392e75 100644 --- a/examples/trl_mixin/ex_trl_distillation.py +++ b/examples/trl_mixin/ex_trl_distillation.py @@ -1,9 +1,9 @@ from sft_trainer import SFTTrainer from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator -from llmcompressor.transformers import ( - DataTrainingArguments, - TextGenerationDataset, +from llmcompressor.transformers import TextGenerationDataset +from llmcompressor.transformers.utils.arg_parser import ( + DatasetArguments, TrainingArguments, ) @@ -21,7 +21,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_path) # Load gsm8k using SparseML dataset tools -data_args = DataTrainingArguments( +data_args = DatasetArguments( dataset="gsm8k", dataset_config_name="main", max_seq_length=512 ) dataset_manager = TextGenerationDataset.load_from_registry( diff --git a/src/llmcompressor/__init__.py b/src/llmcompressor/__init__.py index 264d434f0..3f9f14ac3 100644 --- a/src/llmcompressor/__init__.py +++ b/src/llmcompressor/__init__.py @@ -36,7 +36,6 @@ from llmcompressor.core.session_functions import ( active_session, - apply, callbacks, create_session, finalize, diff --git a/src/llmcompressor/core/__init__.py b/src/llmcompressor/core/__init__.py index 171c95395..75335164d 100644 --- a/src/llmcompressor/core/__init__.py +++ b/src/llmcompressor/core/__init__.py @@ -11,7 +11,6 @@ from llmcompressor.core.session_functions import ( LifecycleCallbacks, active_session, - apply, callbacks, create_session, finalize, diff --git a/src/llmcompressor/core/lifecycle.py b/src/llmcompressor/core/lifecycle.py index 232d76b83..30654cf8c 100644 --- a/src/llmcompressor/core/lifecycle.py +++ b/src/llmcompressor/core/lifecycle.py @@ -20,7 +20,9 @@ from llmcompressor.modifiers import StageModifiers from llmcompressor.recipe import RecipeContainer -__all__ = ["CompressionLifecycle"] +__all__ = [ + "CompressionLifecycle", +] @dataclass diff --git a/src/llmcompressor/core/session.py b/src/llmcompressor/core/session.py index 7c489f36f..888db3f1e 100644 --- a/src/llmcompressor/core/session.py +++ b/src/llmcompressor/core/session.py @@ -200,19 +200,6 @@ def finalize(self, **kwargs) -> ModifiedState: modifier_data=mod_data, ) - def apply(self, **kwargs): - """ - Apply the recipe in one-shot manner. This will invoke the initialize - and then finalize methods for each modifier in the session's lifecycle. - This will also set the session's state to the finalized state. - - :param kwargs: additional kwargs to pass to the lifecycle's initialize and - finalize methods - """ - self.initialize(**kwargs) - - return self.finalize(**kwargs) - def event( self, event_type: EventType, diff --git a/src/llmcompressor/core/session_functions.py b/src/llmcompressor/core/session_functions.py index 9a123a030..da54872c4 100644 --- a/src/llmcompressor/core/session_functions.py +++ b/src/llmcompressor/core/session_functions.py @@ -14,7 +14,6 @@ "pre_initialize_structure", "initialize", "finalize", - "apply", "callbacks", "LifecycleCallbacks", ] @@ -143,62 +142,6 @@ def finalize(**kwargs) -> ModifiedState: return active_session().finalize(**kwargs) -def apply( - recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None, - recipe_stage: Union[str, List[str], None] = None, - recipe_args: Optional[Dict[str, Any]] = None, - model: Optional[Any] = None, - teacher_model: Optional[Any] = None, - train_data: Optional[Any] = None, - val_data: Optional[Any] = None, - test_data: Optional[Any] = None, - calib_data: Optional[Any] = None, - copy_data: bool = True, - start: Optional[float] = None, - steps_per_epoch: Optional[int] = None, - batches_per_step: Optional[int] = None, - **kwargs, -) -> ModifiedState: - """ - A method to apply the recipe in one-shot manner. This will invoke the initialize - and then finalize methods for each modifier in the active session's lifecycle. - - :param recipe: the recipe to use for the sparsification, can be a path to a - recipe file, a raw recipe string, a recipe object, or a list of recipe objects. - :param recipe_stage: the stage to target for the sparsification - :param recipe_args: the args to use for overriding the recipe defaults - :param model: the model to sparsify - :param teacher_model: the teacher model to use for knowledge distillation - :param train_data: the training data to use for the sparsification - :param val_data: the validation data to use for the sparsification - :param test_data: the testing data to use for the sparsification - :param calib_data: the calibration data to use for the sparsification - :param copy_data: True to copy the data, False otherwise - :param start: the start epoch to use for the sparsification - :param steps_per_epoch: the number of steps per epoch to use for the - sparsification - :param batches_per_step: the number of batches per step to use for - :param kwargs: additional kwargs to pass to the current session's apply method - :return: the modified state of the active session after applying the recipe - """ - return active_session().apply( - recipe=recipe, - recipe_stage=recipe_stage, - recipe_args=recipe_args, - model=model, - teacher_model=teacher_model, - train_data=train_data, - val_data=val_data, - test_data=test_data, - calib_data=calib_data, - copy_data=copy_data, - start=start, - steps_per_epoch=steps_per_epoch, - batches_per_step=batches_per_step, - **kwargs, - ) - - class LifecycleCallbacks: """ A class for invoking lifecycle events for the active session diff --git a/src/llmcompressor/transformers/calibration/__init__.py b/src/llmcompressor/transformers/calibration/__init__.py new file mode 100644 index 000000000..65fc2575f --- /dev/null +++ b/src/llmcompressor/transformers/calibration/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .oneshot import Oneshot diff --git a/src/llmcompressor/transformers/calibration/oneshot.py b/src/llmcompressor/transformers/calibration/oneshot.py new file mode 100644 index 000000000..4601a02b1 --- /dev/null +++ b/src/llmcompressor/transformers/calibration/oneshot.py @@ -0,0 +1,263 @@ +from pathlib import PosixPath +from typing import Optional + +from loguru import logger +from torch.utils.data import DataLoader + +from llmcompressor.core.session_functions import active_session +from llmcompressor.transformers.finetune.data.data_helpers import ( + get_calibration_dataloader, +) +from llmcompressor.transformers.finetune.text_generation import ( + initialize_model_from_path, + initialize_processor_from_path, + parse_args, +) +from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( + modify_save_pretrained, + patch_tied_tensors_bug, +) +from llmcompressor.transformers.utils.arg_parser import DEFAULT_OUTPUT_DIR + +__all__ = ["Oneshot"] + + +class Oneshot: + """ + Class responsible for carrying out one-shot calibration on a pretrained model. + + This class handles the entire lifecycle of one-shot calibration, including + preprocessing (model and tokenizer/processor initialization), model optimization + (quantization or sparsification), and postprocessing (saving outputs). The + intructions for model optimization can be specified by using a recipe (fine-grain + details) or by using a scheme (ex. W4A16, W8A8, W4A8). + + - **Input Keyword Arguments:** + `kwargs` are parsed into: + - `model_args`: Arguments for loading and configuring a pretrained model + (e.g., `AutoModelForCausalLM`). + - `data_args`: Arguments for dataset-related configurations, such as + calibration dataloaders. + - `recipe_args`: Arguments for defining and configuring recipes that specify + optimization actions. + + Parsers are defined in `src/llmcompressor/transformers/utils/arg_parser`. + + - **Lifecycle Overview:** + The calibration lifecycle consists of three steps: + 1. **Preprocessing**: + - Instantiates a pretrained model and tokenizer/processor. + - Ensures input and output embedding layers are untied if they share + tensors. + - Patches the model to include additional functionality for saving with + quantization configurations. + 2. **Oneshot Calibration**: + - Optimizes the model using a global `CompressionSession` and applies + recipe-defined modifiers (e.g., `GPTQModifier`, `SparseGPTModifier`) + 3. **Postprocessing**: + - Saves the model, tokenizer/processor, and configuration to the specified + `output_dir`. + + - **Usage:** + ```python + oneshot = Oneshot(model=model, recipe=recipe, dataset=dataset) + oneshot.run() + + # Access the processed components + model = oneshot.model + tokenizer_or_processor = oneshot.tokenizer_or_processor + recipe = oneshot.recipe + ``` + + Methods: + __init__(**kwargs): + Initializes the `Oneshot` object by parsing input arguments, performing + preprocessing, and setting instance attributes. + + run(**kwargs): + Performs the one-shot calibration process by preparing a calibration + dataloader, applying recipe modifiers to the model, and executing + postprocessing steps. + + save(): + Saves the calibrated model and tokenizer/processor to the specified + `output_dir`. Supports saving in compressed formats based on model + arguments. + + _apply_recipe_modifiers(calibration_dataloader, **kwargs): + Applies lifecycle actions (e.g., `initialize`, `finalize`) using modifiers + defined in the recipe. Each action is executed via the global + `CompressionSession`. + + _pre_process(): + Handles preprocessing steps, including model initialization, + tokenizer/processor setup, and resolving tied embedding issues. + + _warn_tied_embeddings(): + Logs a warning if `tie_word_embeddings=True`, which may interfere with + saving in the one-shot workflow. + + _post_process(): + Executes postprocessing steps such as saving the model and resetting + lifecycle actions, especially when a custom `output_dir` is specified. + """ + + MODIFIER_LIFECYCLE_ACTIONS = ( + "initialize", + "finalize", + ) + + def __init__(self, **kwargs): + """ + Initializes the `Oneshot` class with provided arguments. + + Parses the input keyword arguments into `model_args`, `data_args`, and + `recipe_args`. Performs preprocessing to initialize the model and + tokenizer/processor. + + Args: + kwargs: Arbitrary keyword arguments for model, data, and recipe + configurations. + """ + self.model_args, self.data_args, self.recipe_args, _, self.output_dir = ( + parse_args(**kwargs) + ) + + # Preprocess the model and tokenizer/processor + self._pre_process() + + # Set instance attributes + self.model = self.model_args.model + self.tokenizer_or_processor = self.model_args.processor + self.recipe = self.recipe_args.recipe + + def run(self, **kwargs): + """ + Performs one-shot calibration. + + This method prepares a calibration dataloader using dataset arguments and + applies recipe-based modifiers to optimize the model. The lifecycle actions + are executed sequentially, and the modified model is saved during + postprocessing. + + Args: + kwargs: Additional keyword arguments for the recipe modifiers. + """ + calibration_dataloader = get_calibration_dataloader( + self.data_args, self.tokenizer_or_processor + ) + self._apply_recipe_modifiers( + calibration_dataloader=calibration_dataloader, **kwargs + ) + self._post_process() + + def save(self): + """ + Saves the model and tokenizer/processor to the output directory. + + The model is saved in a compressed format if specified in `model_args`. + The tokenizer or processor, if available, is also saved. + + Raises: + ValueError: If saving fails due to an invalid `output_dir` or other issues. + """ + self.model.save_pretrained( + self.output_dir, + save_compressed=self.model_args.save_compressed, + ) + if self.tokenizer_or_processor: + self.tokenizer_or_processor.save_pretrained(self.output_dir) + + def _apply_recipe_modifiers( + self, calibration_dataloader: Optional[DataLoader], **kwargs + ): + """ + Applies recipe modifiers to the model during the lifecycle. + + The modifiers are defined in the recipe and executed via lifecycle actions + (`initialize`, `finalize`) through the global `CompressionSession`. + + Args: + calibration_dataloader (Optional[DataLoader]): Dataloader for calibration + data. + kwargs: Additional arguments for lifecycle actions. + + Raises: + RuntimeError: If any modifier fails during execution. + """ + for action in self.MODIFIER_LIFECYCLE_ACTIONS: + session = active_session() + session_action = getattr(session, action) + session_action( + model=self.model, + recipe=self.recipe, + recipe_args=self.recipe_args.recipe_args, + calib_data=calibration_dataloader, + start=-1, # oneshot-specific argument + copy_data=False, + min_tokens_per_module=getattr(self, "min_tokens_per_module", None), + **kwargs, + ) + + def _pre_process(self): + """ + Prepares the model and tokenizer/processor for calibration. + + - Initializes the model if it's specified as a path or string. + - Applies patches to fix tied tensor issues and modifies `save_pretrained` + behavior. + - Initializes the processor if specified as a path or `None`. + - Sets the minimum tokens per module if `data_args` are provided. + + Raises: + FileNotFoundError: If the model or processor path is invalid. + """ + self._warn_tied_embeddings() + + # Initialize model + if isinstance(self.model_args.model, (str, PosixPath)): + self.model_args.model, _ = initialize_model_from_path(self.model_args) + + patch_tied_tensors_bug(self.model_args.model) + modify_save_pretrained(self.model_args.model) + + # Initialize processor + if isinstance(self.model_args.processor, (str, type(None))): + self.model_args.processor = initialize_processor_from_path( + self.model_args, self.model_args.model + ) + + # Set minimum tokens per module if data arguments are provided + if self.data_args: + self.min_tokens_per_module = self.data_args.min_tokens_per_module + + def _warn_tied_embeddings(self): + """ + Logs a warning if the model has tied word embeddings. + + The `tie_word_embeddings` flag may cause issues during saving in the one-shot + calibration workflow due to shared tensor addresses. + """ + if self.model_args.tie_word_embeddings: + logger.debug( + "The tie_word_embeddings flag is by default set to False. " + "This guarantees that the one-shot algorithm saves the final " + "weights without errors. Detected tie_word_embeddings=True. " + "This may cause issues with the one-shot algorithm on save." + ) + + def _post_process(self): + """ + Executes post-calibration steps. + + This method saves the model and resets lifecycle actions if the `output_dir` + is not the default directory. + + Raises: + ValueError: If saving fails due to invalid configurations. + """ + if ( + isinstance(self.model_args.model, str) + or self.output_dir != DEFAULT_OUTPUT_DIR + ): + self.save() diff --git a/src/llmcompressor/transformers/compression/sparsity_config.py b/src/llmcompressor/transformers/compression/sparsity_config.py index d35ddadd1..881c067e5 100644 --- a/src/llmcompressor/transformers/compression/sparsity_config.py +++ b/src/llmcompressor/transformers/compression/sparsity_config.py @@ -47,7 +47,9 @@ def infer_global_sparsity( return global_sparsity @staticmethod - def infer_sparsity_structure(model: Optional[Module] = None) -> str: + def infer_sparsity_structure( + model: Optional[Module] = None, + ) -> str: """ Determines what sparsity structure, if any, was applied. @@ -107,7 +109,7 @@ def from_pretrained( return None sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure( - model=model + model=model, ) if ( disable_sparse_compression diff --git a/src/llmcompressor/transformers/finetune/__init__.py b/src/llmcompressor/transformers/finetune/__init__.py index aad70ae2c..6c75b902b 100644 --- a/src/llmcompressor/transformers/finetune/__init__.py +++ b/src/llmcompressor/transformers/finetune/__init__.py @@ -1,7 +1,5 @@ # flake8: noqa -from .data import DataTrainingArguments, TextGenerationDataset -from .model_args import ModelArguments +from .data import TextGenerationDataset from .session_mixin import SessionManagerMixIn from .text_generation import apply, compress, eval, oneshot, train -from .training_args import TrainingArguments diff --git a/src/llmcompressor/transformers/finetune/data/__init__.py b/src/llmcompressor/transformers/finetune/data/__init__.py index ddf0b2364..a53caed1b 100644 --- a/src/llmcompressor/transformers/finetune/data/__init__.py +++ b/src/llmcompressor/transformers/finetune/data/__init__.py @@ -4,7 +4,6 @@ from .c4 import C4Dataset from .cnn_dailymail import CNNDailyMailDataset from .custom import CustomDataset -from .data_args import DataTrainingArguments from .evolcodealpaca import EvolCodeAlpacaDataset from .flickr_30k import Flickr30K from .gsm8k import GSM8KDataset diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index fa8e434d4..30c97df7a 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -8,12 +8,12 @@ from datasets.formatting.formatting import LazyRow from loguru import logger -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.transformers.finetune.data.data_helpers import ( LABELS_MASK_VALUE, get_custom_datasets_from_path, get_raw_dataset, ) +from llmcompressor.transformers.utils.arg_parser import DatasetArguments from llmcompressor.transformers.utils.preprocessing_functions import ( PreprocessingFunctionRegistry, ) @@ -41,7 +41,7 @@ class TextGenerationDataset(RegistryMixin): def __init__( self, - data_args: DataTrainingArguments, + data_args: DatasetArguments, split: str, processor: Processor, ): diff --git a/src/llmcompressor/transformers/finetune/data/c4.py b/src/llmcompressor/transformers/finetune/data/c4.py index e50d4d0c6..bf3feeee7 100644 --- a/src/llmcompressor/transformers/finetune/data/c4.py +++ b/src/llmcompressor/transformers/finetune/data/c4.py @@ -5,7 +5,7 @@ from llmcompressor.typing import Processor if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments as DataArgs + from llmcompressor.transformers.utils.arg_parser import DatasetArguments @TextGenerationDataset.register(name="c4") @@ -18,7 +18,7 @@ class C4Dataset(TextGenerationDataset): :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "allenai/c4" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py index 06ad3ecfa..506f760d0 100644 --- a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py +++ b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py @@ -5,7 +5,7 @@ from llmcompressor.typing import Processor if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments as DataArgs + from llmcompressor.transformers.utils.arg_parser import DatasetArguments @TextGenerationDataset.register(name="cnn_dailymail") @@ -20,7 +20,7 @@ class CNNDailyMailDataset(TextGenerationDataset): SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n" - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "cnn_dailymail" data_args.dataset_config_name = "3.0.0" diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 23c70e561..6020cd17d 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -1,9 +1,11 @@ import logging import os +import re from typing import Any, Callable, Dict, List, Optional import torch from datasets import Dataset, load_dataset +from loguru import logger from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers.data import default_data_collator @@ -15,6 +17,7 @@ "get_raw_dataset", "make_dataset_splits", "get_custom_datasets_from_path", + "get_calibration_dataloader", ] @@ -243,3 +246,76 @@ def do_transform(candidate: str) -> bool: transform_dataset_key(dataset_key) return data_files + + +def get_calibration_dataloader( + data_args, + processor, + add_labels: bool = False, # for oneshot + do_oneshot=True, +): + """ + Loads datasets for each flow based on data_args, stores a Dataset for each + enabled flow in self.datasets + + :param processor: processor or tokenizer to use for dataset tokenization + :param add_labels: if True, add labels column to dataset splits + """ + if data_args.dataset is None: + logger.info( + "Running oneshot without calibration data. This is expected for " + "weight-only and dynamic quantization" + ) + return + + splits = data_args.splits + tokenized_datasets = {} + + def _get_split_name(inp_str): + # strip out split name, for ex train[60%:] -> train + match = re.match(r"(\w*)\[.*\]", inp_str) + if match is not None: + return match.group(1) + return inp_str + + if splits is None: + splits = {"all": None} + elif isinstance(splits, str): + splits = {_get_split_name(splits): splits} + elif isinstance(splits, List): + splits = {_get_split_name(s): s for s in splits} + + # default to custom dataset if dataset provided isn't a string + registry_id = data_args.dataset if isinstance(data_args.dataset, str) else "custom" + for split_name, split_str in splits.items(): + dataset = data_args.dataset + if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names: + # dataset is already tokenized + tokenized_datasets[split_name] = dataset + else: + # dataset needs to be tokenized + from llmcompressor.transformers.finetune.data.base import ( + TextGenerationDataset, + ) + + dataset_manager = TextGenerationDataset.load_from_registry( + registry_id, + data_args=data_args, + split=split_str, + processor=processor, + ) + tokenized_datasets[split_name] = dataset_manager(add_labels=add_labels) + + datasets = make_dataset_splits( + tokenized_datasets, + do_oneshot=do_oneshot, + ) + + calibration_dataset = datasets.get("calibration") + + return format_calibration_data( + tokenized_dataset=calibration_dataset, + num_calibration_samples=data_args.num_calibration_samples, + do_shuffle=data_args.shuffle_calibration_samples, + collate_fn=data_args.data_collator, + ) diff --git a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py index 932bfa54c..ca3caec03 100644 --- a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py +++ b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py @@ -5,7 +5,7 @@ from llmcompressor.typing import Processor if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments as DataArgs + from llmcompressor.transformers.utils.arg_parser import DatasetArguments @TextGenerationDataset.register(name="evolcodealpaca") @@ -25,7 +25,7 @@ class EvolCodeAlpacaDataset(TextGenerationDataset): "\n\n### Response:\n" ) - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "theblackcat102/evol-codealpaca-v1" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/data/flickr_30k.py b/src/llmcompressor/transformers/finetune/data/flickr_30k.py index f19b053e1..4528c5340 100644 --- a/src/llmcompressor/transformers/finetune/data/flickr_30k.py +++ b/src/llmcompressor/transformers/finetune/data/flickr_30k.py @@ -7,7 +7,7 @@ from llmcompressor.typing import Processor if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments as DataArgs + from llmcompressor.transformers.utils.arg_parser import DatasetArguments @TextGenerationDataset.register(name="flickr", alias="flickr30k") @@ -31,7 +31,7 @@ class Flickr30K(TextGenerationDataset): "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" ) - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "lmms-lab/flickr30k" diff --git a/src/llmcompressor/transformers/finetune/data/gsm8k.py b/src/llmcompressor/transformers/finetune/data/gsm8k.py index beae5dfec..8ee26145d 100644 --- a/src/llmcompressor/transformers/finetune/data/gsm8k.py +++ b/src/llmcompressor/transformers/finetune/data/gsm8k.py @@ -5,7 +5,7 @@ from llmcompressor.typing import Processor if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments as DataArgs + from llmcompressor.transformers.utils.arg_parser import DatasetArguments @TextGenerationDataset.register(name="gsm8k") @@ -20,7 +20,7 @@ class GSM8KDataset(TextGenerationDataset): GSM_TEMPLATE = "Question: {question}\nAnswer:" - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "gsm8k" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/data/open_platypus.py b/src/llmcompressor/transformers/finetune/data/open_platypus.py index 3b25986ca..0dbf064e5 100644 --- a/src/llmcompressor/transformers/finetune/data/open_platypus.py +++ b/src/llmcompressor/transformers/finetune/data/open_platypus.py @@ -5,7 +5,7 @@ from llmcompressor.typing import Processor if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments as DataArgs + from llmcompressor.transformers.utils.arg_parser import DatasetArguments @TextGenerationDataset.register(name="open_platypus") @@ -28,7 +28,7 @@ class OpenPlatypusDataset(TextGenerationDataset): "instruction}\n\n### Response:\n", } - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "garage-bAInd/Open-Platypus" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/data/ptb.py b/src/llmcompressor/transformers/finetune/data/ptb.py index c7f0bbac1..db0be0599 100644 --- a/src/llmcompressor/transformers/finetune/data/ptb.py +++ b/src/llmcompressor/transformers/finetune/data/ptb.py @@ -5,7 +5,7 @@ from llmcompressor.typing import Processor if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments as DataArgs + from llmcompressor.transformers.utils.arg_parser import DatasetArguments @TextGenerationDataset.register(name="ptb") @@ -18,7 +18,7 @@ class PtbDataset(TextGenerationDataset): :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "ptb_text_only" data_args.text_column = "sentence" diff --git a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py index 62c012e83..f914ae5d4 100644 --- a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py +++ b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py @@ -7,7 +7,7 @@ from llmcompressor.typing import Processor if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments as DataArgs + from llmcompressor.transformers.utils.arg_parser import DatasetArguments @TextGenerationDataset.register(name="ultrachat_200k") @@ -33,7 +33,7 @@ class UltraChatDataset(TextGenerationDataset): "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" ) - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "HuggingFaceH4/ultrachat_200k" data_args.text_column = "messages" diff --git a/src/llmcompressor/transformers/finetune/data/wikitext.py b/src/llmcompressor/transformers/finetune/data/wikitext.py index a559399d8..5e58c3c94 100644 --- a/src/llmcompressor/transformers/finetune/data/wikitext.py +++ b/src/llmcompressor/transformers/finetune/data/wikitext.py @@ -5,7 +5,7 @@ from llmcompressor.typing import Processor if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments as DataArgs + from llmcompressor.transformers.utils.arg_parser import DatasetArguments @TextGenerationDataset.register(name="wikitext") @@ -18,7 +18,7 @@ class WikiTextDataset(TextGenerationDataset): :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "Salesforce/wikitext" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index 0a07c45eb..c1aec5164 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -16,13 +16,20 @@ from llmcompressor.pytorch.utils import tensors_to_device from llmcompressor.recipe import Recipe, StageRunType from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.transformers.finetune.data.data_helpers import ( format_calibration_data, make_dataset_splits, ) -from llmcompressor.transformers.finetune.model_args import ModelArguments -from llmcompressor.transformers.finetune.training_args import TrainingArguments +from llmcompressor.transformers.utils.arg_parser import ( + DatasetArguments, + ModelArguments, + RecipeArguments, + TrainingArguments, +) +from llmcompressor.transformers.utils.arg_parser.training_arguments import ( + DEFAULT_OUTPUT_DIR, +) +from llmcompressor.transformers.utils.arg_parser.utils import get_dataclass_as_dict from llmcompressor.typing import Processor from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_model_and_recipe @@ -46,13 +53,15 @@ class StageRunner: def __init__( self, - data_args: "DataTrainingArguments", + data_args: "DatasetArguments", model_args: "ModelArguments", training_args: "TrainingArguments", + recipe_args: "RecipeArguments", ): self._data_args = data_args self._model_args = model_args self._training_args = training_args + self._recipe_args = recipe_args self.datasets = {} self.trainer = None @@ -214,7 +223,7 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None): :param checkpoint: optional checkpoint to pick up a stage from """ - recipe_obj = Recipe.create_instance(self._training_args.recipe) + recipe_obj = Recipe.create_instance(self._recipe_args.recipe) with self.trainer.accelerator.main_process_first(): checkpoint_dir = self._model_args.model completed_stages = get_completed_stages(checkpoint_dir) @@ -251,21 +260,30 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None): # run stage if run_type is StageRunType.ONESHOT: - self.one_shot(stage=stage_name) + from llmcompressor.transformers.calibration import Oneshot + + model = get_session_model() + self._model_args.model = model + + oneshot = Oneshot( + output_dir=self._training_args.output_dir, + **get_dataclass_as_dict(self._model_args, ModelArguments), + **get_dataclass_as_dict(self._data_args, DatasetArguments), + **get_dataclass_as_dict(self._recipe_args, RecipeArguments), + ) + + oneshot.run(stage_name=stage_name) elif run_type is StageRunType.TRAIN: self.train(checkpoint=checkpoint, stage=stage_name) checkpoint = None - if ( - self._training_args.output_dir - != TrainingArguments.__dataclass_fields__["output_dir"].default - ): + if self._training_args.output_dir != DEFAULT_OUTPUT_DIR: save_model_and_recipe( model=self.trainer.model, save_path=self._output_dir, processor=self.processor, save_safetensors=self._training_args.save_safetensors, - save_compressed=self._training_args.save_compressed, + save_compressed=self._model_args.save_compressed, ) # save stage to checkpoint dir diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py index 27860aeb4..07b9ba1ef 100644 --- a/src/llmcompressor/transformers/finetune/session_mixin.py +++ b/src/llmcompressor/transformers/finetune/session_mixin.py @@ -7,13 +7,12 @@ import torch from loguru import logger from torch.nn import Module -from torch.utils.data import DataLoader, IterableDataset +from torch.utils.data import IterableDataset from transformers.trainer_callback import TrainerState from transformers.trainer_utils import get_last_checkpoint from llmcompressor.core import ( active_session, - apply, callbacks, create_session, finalize, @@ -36,8 +35,10 @@ from llmcompressor.utils.pytorch import qat_active if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments - + from llmcompressor.transformers.utils.arg_parser import ( + DatasetArguments, + ModelArguments, + ) __all__ = [ "SessionManagerMixIn", @@ -68,12 +69,14 @@ def __init__( self, recipe: Optional[str] = None, recipe_args: Optional[Union[Dict[str, Any], str]] = None, - data_args: Optional["DataTrainingArguments"] = None, + data_args: Optional["DatasetArguments"] = None, + model_args: Optional["ModelArguments"] = None, teacher: Optional[Union[Module, str]] = None, **kwargs, ): self.recipe = recipe self.recipe_args = recipe_args + self.model_args = model_args self.teacher = teacher # parse training and metadata args @@ -374,8 +377,8 @@ def train(self, *args, stage: Optional[str] = None, **kwargs): self.initialize_session(epoch=epoch, checkpoint=checkpoint, stage=stage) # do not save checkpoints as compressed - original_save_compressed = self.args.save_compressed - self.args.save_compressed = False + original_save_compressed = self.model_args.save_compressed + self.model_args.save_compressed = False # train with accelerator self.accelerator.wait_for_everyone() @@ -383,7 +386,7 @@ def train(self, *args, stage: Optional[str] = None, **kwargs): self.accelerator.wait_for_everyone() # restore original setting for saving final model - self.args.save_compressed = original_save_compressed + self.model_args.save_compressed = original_save_compressed # lifecycle self.finalize_session() @@ -428,31 +431,6 @@ def predict(self, *args, **kwargs): return output - def one_shot( - self, calibration_data: Optional[DataLoader] = None, stage: Optional[str] = None - ): - """ - Run oneshot calibration on the active model - - :param stage: which stage of the recipe to run, or None to run whole recipe - :param calib_data: dataloader of calibration data - """ - apply( - recipe=self.recipe, - recipe_stage=stage, - recipe_args=self.recipe_args, - model=self.model, - calib_data=calibration_data, - start=-1, - copy_data=False, - accelerator=self.accelerator, - min_tokens_per_module=self.min_tokens_per_module, - ) - - # log model sparsity - # self.maybe_log_model_sparsification() - self.accelerator.wait_for_everyone() - def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False): """ Override of the save_model function and expects it to exist in the parent. @@ -474,7 +452,7 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False): if not is_fsdp_model(self.model): self.model.save_pretrained( output_dir, - save_compressed=self.args.save_compressed, + save_compressed=self.model_args.save_compressed, safe_serialization=self.args.save_safetensors, ) else: # FSDP model @@ -482,7 +460,7 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False): model=self.model, accelerator=self.accelerator, output_dir=output_dir, - save_compressed=self.args.save_compressed, + save_compressed=self.model_args.save_compressed, save_safetensors=self.metadata.get("save_safetensors", False), ) diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index 61e6441bb..6c71610a9 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -20,6 +20,7 @@ import os import warnings from pathlib import PosixPath +from typing import Optional from loguru import logger from transformers import ( @@ -40,18 +41,22 @@ parse_dtype, ) from llmcompressor.recipe import Recipe, StageRunType -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments -from llmcompressor.transformers.finetune.model_args import ModelArguments from llmcompressor.transformers.finetune.runner import StageRunner from llmcompressor.transformers.finetune.trainer import Trainer -from llmcompressor.transformers.finetune.training_args import TrainingArguments from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( modify_fsdp_model_save_pretrained, modify_save_pretrained, patch_tied_tensors_bug, ) from llmcompressor.transformers.sparsification.sparse_model import ( - get_shared_processor_src, + get_processor_from_model, +) +from llmcompressor.transformers.utils.arg_parser import ( + DEFAULT_OUTPUT_DIR, + DatasetArguments, + ModelArguments, + RecipeArguments, + TrainingArguments, ) from llmcompressor.transformers.utils.helpers import ( detect_last_checkpoint, @@ -65,27 +70,33 @@ def train(**kwargs): """ CLI entrypoint for running training """ - model_args, data_args, training_args = parse_args(**kwargs) + model_args, data_args, recipe_args, training_args, _ = parse_args( + include_training_args=True, **kwargs + ) training_args.do_train = True - main(model_args, data_args, training_args) + main(model_args, data_args, recipe_args, training_args) def eval(**kwargs): """ CLI entrypoint for running evaluation """ - model_args, data_args, training_args = parse_args(**kwargs) + model_args, data_args, recipe_args, training_args, _ = parse_args( + include_training_args=True, **kwargs + ) training_args.do_eval = True - main(model_args, data_args, training_args) + main(model_args, data_args, recipe_args, training_args) def oneshot(**kwargs): + from llmcompressor.transformers.calibration.oneshot import Oneshot + """ CLI entrypoint for running oneshot calibration """ - model_args, data_args, training_args = parse_args(**kwargs) - training_args.do_oneshot = True - main(model_args, data_args, training_args) + oneshot = Oneshot(**kwargs) + oneshot.run() + return oneshot # alias @@ -97,12 +108,15 @@ def apply(**kwargs): CLI entrypoint for any of training, eval, predict or oneshot """ report_to = kwargs.get("report_to", None) - model_args, data_args, training_args = parse_args(**kwargs) + model_args, data_args, recipe_args, training_args, _ = parse_args( + include_training_args=True, **kwargs + ) + training_args.run_stages = True if report_to is None: # user didn't specify any reporters # get rid of the reporters inferred from hugging face training_args.report_to = [] - main(model_args, data_args, training_args) + main(model_args, data_args, recipe_args, training_args) def compress(**kwargs): @@ -111,60 +125,100 @@ def compress(**kwargs): def load_dataset(dataset_name: str, **kwargs): parser = HfArgumentParser( - (ModelArguments, DataTrainingArguments, TrainingArguments) + (ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments) ) - model_args, data_args, training_args = parser.parse_dict(kwargs) + _, data_args, _, _ = parser.parse_dict(kwargs) data_args["dataset_name"] = dataset_name -def parse_args(**kwargs): +def parse_args(include_training_args: bool = False, **kwargs): """ Parses kwargs by grouping into model, data or training arg groups: - * model_args in src/llmcompressor/transformers/finetune/model_args.py - * data_args in src/llmcompressor/transformers/finetune/data/data_args.py - * training_args in src/llmcompressor/transformers/finetune/training_args.py + * model_args in + src/llmcompressor/transformers/utils/arg_parser/model_args.py + * data_args in + src/llmcompressor/transformers/utils/arg_parser/data_args.py + * recipe_args in + src/llmcompressor/transformers/utils/arg_parser/recipe_args.py + * training_args in + src/llmcompressor/transformers/utils/arg_parser/training_args.py + + Throws deprecation warnings + + :param include_training_args: Add training_args in the output if set to True. + Note that instantiatng trainng_args will reset HF accelerator and change its + internal state. This dataclass should be instantiated only once to avoid + conflict with Accelerate library's accelerator. - Throws depreciation warnings """ - parser = HfArgumentParser( - (ModelArguments, DataTrainingArguments, TrainingArguments) - ) - if not kwargs: - model_args, data_args, training_args = parser.parse_args_into_dataclasses() + output_dir = kwargs.pop("output_dir", DEFAULT_OUTPUT_DIR) + + if include_training_args: + parser = HfArgumentParser( + (ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments) + ) else: - model_args, data_args, training_args = parser.parse_dict(kwargs) + parser = HfArgumentParser((ModelArguments, DatasetArguments, RecipeArguments)) + + if not kwargs: + # if output_dir passed from cli, pop to avoid using training_args + def _get_output_dir_from_argv() -> Optional[str]: + import sys + + output_dir = None + if "--output_dir" in sys.argv: + index = sys.argv.index("--output_dir") + sys.argv.pop(index) + if index < len(sys.argv): # Check if value exists afer the flag + output_dir = sys.argv.pop(index) + + return output_dir - if training_args.recipe_args is not None: - if not isinstance(training_args.recipe_args, dict): - arg_dict = {} - for recipe_arg in training_args.recipe_args: - key, value = recipe_arg.split("=") - arg_dict[key] = value - training_args.recipe_args = arg_dict + output_dir = _get_output_dir_from_argv() or output_dir - # raise depreciation warnings + parsed_args = parser.parse_args_into_dataclasses() + else: + parsed_args = parser.parse_dict(kwargs) + + # Unpack parsed arguments based on the presence of training arguments + if include_training_args: + model_args, data_args, recipe_args, training_args = parsed_args + if output_dir is not None: + training_args.output_dir = output_dir + else: + model_args, data_args, recipe_args = parsed_args + training_args = None + + if recipe_args.recipe_args is not None: + if not isinstance(recipe_args.recipe_args, dict): + recipe_args.recipe_args = { + key: value + for arg in recipe_args.recipe_args + for key, value in [arg.split("=")] + } + + # Raise deprecation warnings if data_args.remove_columns is not None: warnings.warn( - "`remove_columns` argument is depreciated. When tokenizing datasets, all " - "columns which are invalid inputs the tokenizer will be removed", + "`remove_columns` argument is deprecated. When tokenizing datasets, all " + "columns which are invalid inputs to the tokenizer will be removed.", DeprecationWarning, ) - # silently assign tokenizer to processor + # Silently assign tokenizer to processor if model_args.tokenizer: if model_args.processor: - raise ValueError("Cannot use both a tokenizer and processor") + raise ValueError("Cannot use both a tokenizer and processor.") model_args.processor = model_args.tokenizer - model_args.tokenizer = None + model_args.tokenizer = None - return model_args, data_args, training_args + return model_args, data_args, recipe_args, training_args, output_dir def initialize_model_from_path( model_args: ModelArguments, - training_args: TrainingArguments, + training_args: Optional[TrainingArguments] = None, ): - last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args) # Load pretrained model # The .from_pretrained methods guarantee that only one local process can # concurrently download model & vocab. @@ -177,16 +231,23 @@ def initialize_model_from_path( tie_word_embeddings=model_args.tie_word_embeddings, trust_remote_code=model_args.trust_remote_code_model, ) - teacher_config = ( - AutoConfig.from_pretrained( - model_args.distill_teacher, - use_auth_token=True if model_args.use_auth_token else None, - tie_word_embeddings=model_args.tie_word_embeddings, - trust_remote_code=model_args.trust_remote_code_model, + + last_checkpoint = None + + if training_args is not None: + teacher_config = ( + AutoConfig.from_pretrained( + model_args.distill_teacher, + use_auth_token=True if model_args.use_auth_token else None, + tie_word_embeddings=model_args.tie_word_embeddings, + trust_remote_code=model_args.trust_remote_code_model, + ) + if model_args.distill_teacher + else None ) - if model_args.distill_teacher - else None - ) + last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args) + # Set seed before initializing model. + set_seed(training_args.seed) model_path = ( last_checkpoint or model_args.model @@ -194,21 +255,18 @@ def initialize_model_from_path( else model_args.model_name_or_path ) - # Set seed before initializing model. - set_seed(training_args.seed) - # Fallback to CPU if GPU requested and not available - training_args.oneshot_device = fallback_to_cpu(training_args.oneshot_device) + model_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device) # Trainer handles device assignment for FSDP and training, don't do mapping here # if running oneshot outside of FSDP, apply user device settings - device_map = None + fsdp_enabled = os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" - if not fsdp_enabled and training_args.do_oneshot: - device_map = training_args.oneshot_device - logger.warning(f"Moving {model_path} to device {device_map} for One-Shot") - elif not fsdp_enabled: + + device_map = model_args.oneshot_device + if not fsdp_enabled and training_args is not None and training_args.do_train: device_map = "auto" + model_kwargs = { "config": config, "cache_dir": model_args.cache_dir, @@ -218,15 +276,7 @@ def initialize_model_from_path( "device_map": device_map, "trust_remote_code": model_args.trust_remote_code_model, } - teacher_device_map = None if fsdp_enabled else "auto" - teacher_kwargs = { - "config": teacher_config, - "cache_dir": model_args.cache_dir, - "use_auth_token": True if model_args.use_auth_token else None, - "torch_dtype": parse_dtype(model_args.precision), - "device_map": teacher_device_map, - "trust_remote_code": model_args.trust_remote_code_model, - } + # this calls from_pretrained under the hood so should be FSDP safe # optimized models must be decompressed to carry out oneshot/train/etc @@ -242,25 +292,38 @@ def initialize_model_from_path( if "sequence_length" in model_kwargs: model.seqlen = model_kwargs["sequence_length"] - teacher = ( - AutoModelForCausalLM.from_pretrained( - model_args.distill_teacher, - **teacher_kwargs, + teacher = None + if training_args is not None: + teacher_device_map = None if fsdp_enabled else "auto" + teacher_kwargs = { + "config": teacher_config, + "cache_dir": model_args.cache_dir, + "use_auth_token": True if model_args.use_auth_token else None, + "torch_dtype": parse_dtype(model_args.precision), + "device_map": teacher_device_map, + "trust_remote_code": model_args.trust_remote_code_model, + } + + teacher = ( + AutoModelForCausalLM.from_pretrained( + model_args.distill_teacher, + **teacher_kwargs, + ) + if model_args.distill_teacher is not None + else None ) - if model_args.distill_teacher is not None - else None - ) - if teacher is not None and "sequence_length" in teacher_kwargs: - teacher.seqlen = teacher_kwargs["sequence_length"] + if teacher is not None and "sequence_length" in teacher_kwargs: + teacher.seqlen = teacher_kwargs["sequence_length"] - return teacher, model_path, model + return model, teacher def initialize_processor_from_path( - model_args: ModelArguments, model: PreTrainedModel, teacher: PreTrainedModel + model_args: ModelArguments, + model: PreTrainedModel, + teacher: Optional[PreTrainedModel] = None, ) -> Processor: - processor_src = model_args.processor - processor_src = processor_src or get_shared_processor_src(model, teacher) + processor_src = model_args.processor or get_processor_from_model(model, teacher) # The use_fast=True option is not currently supported safely in Transformers # See: https://github.com/huggingface/transformers/pull/34836#issuecomment-2491809727 # noqa: E501 try: @@ -288,7 +351,8 @@ def initialize_processor_from_path( def main( model_args: ModelArguments, - data_args: DataTrainingArguments, + data_args: DatasetArguments, + recipe_args: RecipeArguments, training_args: TrainingArguments, ): """ @@ -323,8 +387,8 @@ def main( ) # Setup based on stage types if running stage mode - if training_args.run_stages and training_args.recipe is not None: - recipe_obj = Recipe.create_instance(training_args.recipe) + if training_args.run_stages and recipe_args.recipe is not None: + recipe_obj = Recipe.create_instance(recipe_args.recipe) for stage in recipe_obj.stages: run_type = stage.infer_run_type() if run_type is StageRunType.ONESHOT: @@ -348,7 +412,7 @@ def main( model = model_args.model if isinstance(model, str) or isinstance(model, PosixPath): - (teacher, _model_path, model) = initialize_model_from_path( + (model, teacher) = initialize_model_from_path( model_args, training_args, ) @@ -371,7 +435,10 @@ def main( # Load datasets stage_runner = StageRunner( - model_args=model_args, data_args=data_args, training_args=training_args + model_args=model_args, + data_args=data_args, + training_args=training_args, + recipe_args=recipe_args, ) add_labels = training_args.do_train or training_args.run_stages stage_runner.populate_datasets(processor=processor, add_labels=add_labels) @@ -379,13 +446,13 @@ def main( eval_dataset = stage_runner.get_dataset_split("validation") calib_dataset = stage_runner.get_dataset_split("calibration") - # Initialize our Trainer trainer = Trainer( model_init=get_session_model, teacher=teacher, - recipe=training_args.recipe, - recipe_args=training_args.recipe_args, + recipe=recipe_args.recipe, + recipe_args=recipe_args.recipe_args, args=training_args, + model_args=model_args, data_args=data_args, train_dataset=train_dataset or calib_dataset, eval_dataset=eval_dataset, @@ -437,13 +504,13 @@ def main( != TrainingArguments.__dataclass_fields__["output_dir"].default ): model.save_pretrained( - training_args.output_dir, save_compressed=training_args.save_compressed + training_args.output_dir, save_compressed=model_args.save_compressed ) if processor is not None: processor.save_pretrained(training_args.output_dir) # Clean up the CompressionSession before exit if requested - if training_args.clear_sparse_session: + if recipe_args.clear_sparse_session: reset_session() diff --git a/src/llmcompressor/transformers/finetune/training_args.py b/src/llmcompressor/transformers/finetune/training_args.py deleted file mode 100644 index c04fa2807..000000000 --- a/src/llmcompressor/transformers/finetune/training_args.py +++ /dev/null @@ -1,71 +0,0 @@ -from dataclasses import dataclass, field -from typing import List, Optional - -from transformers import TrainingArguments as HFTrainingArgs - -__all__ = ["TrainingArguments"] - - -@dataclass -class TrainingArguments(HFTrainingArgs): - """ - Training arguments specific to LLM Compressor Transformers workflow - - :param best_model_after_epoch (`int`, *optional*, defaults to None): - The epoch after which best model will be saved; used in conjunction - with `load_best_model_at_end` and `metric_for_best_model` training - arguments - """ - - recipe: Optional[str] = field( - default=None, - metadata={ - "help": "Path to a LLM Compressor sparsification recipe", - }, - ) - recipe_args: Optional[List[str]] = field( - default=None, - metadata={ - "help": ( - "List of recipe arguments to evaluate, of the format key1=value1 " - "key2=value2" - ) - }, - ) - save_compressed: Optional[bool] = field( - default=True, - metadata={"help": "Whether to compress sparse models during save"}, - ) - do_oneshot: Optional[bool] = field( - default=False, - metadata={"help": "Whether to run one-shot calibration"}, - ) - run_stages: Optional[bool] = field( - default=False, metadata={"help": "Whether to trigger recipe stage by stage"} - ) - oneshot_device: Optional[str] = field( - default="cuda:0", - metadata={"help": "Device to run oneshot calibration on"}, - ) - clear_sparse_session: Optional[bool] = field( - default=False, - metadata={"help": "Whether to clear CompressionSession data between runs."}, - ) - save_safetensors: Optional[bool] = field( - default=True, - metadata={ - "help": "Use safetensors saving and loading for state dicts instead of " - "default torch.load and torch.save." - }, - ) - output_dir: str = field( - default="./output", - metadata={ - "help": "The output directory where the model predictions and " - "checkpoints will be written." - }, - ) - - @property - def place_model_on_device(self): - return False diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index ec9951f6a..4cae242e5 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -100,7 +100,9 @@ def save_pretrained_wrapper( ) -def modify_save_pretrained(model: torch.nn.Module): +def modify_save_pretrained( + model: torch.nn.Module, +): """ Overrides a PreTrainedModel's save_pretrained() method with a wrapped version that supports compression @@ -209,8 +211,9 @@ def skip(*args, **kwargs): save_pretrained_wrapper._overriden = True return save_pretrained_wrapper - # wrap save_pretrained - model.save_pretrained = save_pretrained_compressed(model.save_pretrained) + # wrap save_pretrained if not already + if not getattr(model.save_pretrained, "_overriden", False): + model.save_pretrained = save_pretrained_compressed(model.save_pretrained) # HACK: Override the dtype_byte_size function in transformers to support float8 types diff --git a/src/llmcompressor/transformers/sparsification/sparse_model.py b/src/llmcompressor/transformers/sparsification/sparse_model.py index d7abc323a..57a9dbb78 100644 --- a/src/llmcompressor/transformers/sparsification/sparse_model.py +++ b/src/llmcompressor/transformers/sparsification/sparse_model.py @@ -7,7 +7,7 @@ __all__ = [ "SparseAutoModelForCausalLM", - "get_shared_processor_src", + "get_processor_from_model", ] @@ -20,7 +20,7 @@ def from_pretrained(*args, **kwargs): return AutoModelForCausalLM.from_pretrained(*args, **kwargs) -def get_shared_processor_src(student: Module, teacher: Optional[Module]) -> str: +def get_processor_from_model(student: Module, teacher: Optional[Module]) -> str: """ Get a processor/tokenizer source used for both student and teacher, assuming that they could be shared diff --git a/src/llmcompressor/transformers/utils/arg_parser/__init__.py b/src/llmcompressor/transformers/utils/arg_parser/__init__.py new file mode 100644 index 000000000..cbb9224af --- /dev/null +++ b/src/llmcompressor/transformers/utils/arg_parser/__init__.py @@ -0,0 +1,6 @@ +# flake8: noqa + +from .data_arguments import DatasetArguments +from .model_arguments import ModelArguments +from .recipe_arguments import RecipeArguments +from .training_arguments import DEFAULT_OUTPUT_DIR, TrainingArguments diff --git a/src/llmcompressor/transformers/finetune/data/data_args.py b/src/llmcompressor/transformers/utils/arg_parser/data_arguments.py similarity index 97% rename from src/llmcompressor/transformers/finetune/data/data_args.py rename to src/llmcompressor/transformers/utils/arg_parser/data_arguments.py index 7d0bc14ce..50d3277f4 100644 --- a/src/llmcompressor/transformers/finetune/data/data_args.py +++ b/src/llmcompressor/transformers/utils/arg_parser/data_arguments.py @@ -5,7 +5,7 @@ @dataclass -class DVCDatasetTrainingArguments: +class DVCDatasetArguments: """ Arguments for training using DVC """ @@ -17,7 +17,7 @@ class DVCDatasetTrainingArguments: @dataclass -class CustomDataTrainingArguments(DVCDatasetTrainingArguments): +class CustomDatasetArguments(DVCDatasetArguments): """ Arguments for training using custom datasets """ @@ -67,10 +67,10 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments): @dataclass -class DataTrainingArguments(CustomDataTrainingArguments): +class DatasetArguments(CustomDatasetArguments): """ Arguments pertaining to what data we are going to input our model for - training and eval + calibration, training or eval Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the command line diff --git a/src/llmcompressor/transformers/finetune/model_args.py b/src/llmcompressor/transformers/utils/arg_parser/model_arguments.py similarity index 85% rename from src/llmcompressor/transformers/finetune/model_args.py rename to src/llmcompressor/transformers/utils/arg_parser/model_arguments.py index c81900ee2..ce424812a 100644 --- a/src/llmcompressor/transformers/finetune/model_args.py +++ b/src/llmcompressor/transformers/utils/arg_parser/model_arguments.py @@ -5,7 +5,9 @@ @dataclass class ModelArguments: """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune from + Model variables used for oneshot calibration, training or finetuning and + stage runners (combination of oneshot and finetune going back and forth) + """ model: str = field( @@ -44,17 +46,7 @@ class ModelArguments: default=None, metadata={"help": "Where to store the pretrained data from huggingface.co"}, ) - use_fast_tokenizer: bool = field( - default=True, - metadata={"help": "Whether to use one of the fast tokenizers. Default True"}, - ) - model_revision: str = field( - default="main", - metadata={ - "help": "The specific model version to use " - "(can be a branch name, tag name or commit id)" - }, - ) + use_auth_token: bool = field( default=False, metadata={ @@ -83,3 +75,18 @@ class ModelArguments: "repositories you trust and in which you have read the code" }, ) + save_compressed: Optional[bool] = field( + default=True, + metadata={"help": "Whether to compress sparse models during save"}, + ) + oneshot_device: Optional[str] = field( + default="cuda:0", + metadata={"help": "Device to run oneshot calibration on"}, + ) + model_revision: str = field( + default="main", + metadata={ + "help": "The specific model version to use " + "(can be a branch name, tag name or commit id)" + }, + ) diff --git a/src/llmcompressor/transformers/utils/arg_parser/recipe_arguments.py b/src/llmcompressor/transformers/utils/arg_parser/recipe_arguments.py new file mode 100644 index 000000000..fbe535d7e --- /dev/null +++ b/src/llmcompressor/transformers/utils/arg_parser/recipe_arguments.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass, field +from typing import List, Optional + + +@dataclass +class RecipeArguments: + """Recipe and session variables""" + + recipe: Optional[str] = field( + default=None, + metadata={ + "help": "Path to a LLM Compressor sparsification recipe", + }, + ) + recipe_args: Optional[List[str]] = field( + default=None, + metadata={ + "help": ( + "List of recipe arguments to evaluate, of the format key1=value1 " + "key2=value2" + ) + }, + ) + clear_sparse_session: Optional[bool] = field( + default=False, + metadata={ + "help": ( + "Whether to clear CompressionSession/CompressionLifecycle ", + "data between runs.", + ) + }, + ) diff --git a/src/llmcompressor/transformers/utils/arg_parser/training_arguments.py b/src/llmcompressor/transformers/utils/arg_parser/training_arguments.py new file mode 100644 index 000000000..7b61193b0 --- /dev/null +++ b/src/llmcompressor/transformers/utils/arg_parser/training_arguments.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass, field +from typing import Optional + +from transformers import TrainingArguments as HFTrainingArgs + +__all__ = ["TrainingArguments", "DEFAULT_OUTPUT_DIR"] + +DEFAULT_OUTPUT_DIR = "./output" + + +@dataclass +class TrainingArguments(HFTrainingArgs): + """ + Training arguments specific to LLM Compressor Transformers workflow using + HFTrainingArgs as base class + + """ + + do_oneshot: Optional[bool] = field( + default=False, + metadata={"help": "Whether to run one-shot calibration in stages"}, + ) + run_stages: Optional[bool] = field( + default=False, metadata={"help": "Whether to trigger recipe stage by stage"} + ) + output_dir: str = field( + default=DEFAULT_OUTPUT_DIR, + metadata={ + "help": "The output directory where the model predictions and " + "checkpoints will be written." + }, + ) diff --git a/src/llmcompressor/transformers/utils/arg_parser/utils.py b/src/llmcompressor/transformers/utils/arg_parser/utils.py new file mode 100644 index 000000000..48455fa15 --- /dev/null +++ b/src/llmcompressor/transformers/utils/arg_parser/utils.py @@ -0,0 +1,30 @@ +from dataclasses import fields +from typing import Any, Dict, Union + +from .data_arguments import DatasetArguments +from .model_arguments import ModelArguments +from .recipe_arguments import RecipeArguments +from .training_arguments import TrainingArguments + +__all__ = [ + "get_dataclass_as_dict", +] + + +def get_dataclass_as_dict( + dataclass_instance: Union[ + "ModelArguments", "RecipeArguments", "DatasetArguments", "TrainingArguments" + ], + dataclass_class: Union[ + "ModelArguments", "RecipeArguments", "DatasetArguments", "TrainingArguments" + ], +) -> Dict[str, Any]: + """ + Get the dataclass instance attributes as a dict, neglicting the inherited class. + Ex. dataclass_class=TrainingArguments will ignore HFTrainignArguments + + """ + return { + field.name: getattr(dataclass_instance, field.name) + for field in fields(dataclass_class) + } diff --git a/src/llmcompressor/transformers/utils/helpers.py b/src/llmcompressor/transformers/utils/helpers.py index c1dcef119..80c4b446e 100644 --- a/src/llmcompressor/transformers/utils/helpers.py +++ b/src/llmcompressor/transformers/utils/helpers.py @@ -14,7 +14,10 @@ from transformers.trainer_utils import get_last_checkpoint if TYPE_CHECKING: - from llmcompressor.transformers import ModelArguments, TrainingArguments + from llmcompressor.transformers.utils.arg_parser import ( + ModelArguments, + TrainingArguments, + ) __all__ = [ "RECIPE_FILE_NAME", diff --git a/tests/llmcompressor/transformers/compression/test_quantization.py b/tests/llmcompressor/transformers/compression/test_quantization.py index 13eab66c9..cefcdaa54 100644 --- a/tests/llmcompressor/transformers/compression/test_quantization.py +++ b/tests/llmcompressor/transformers/compression/test_quantization.py @@ -13,7 +13,7 @@ from llmcompressor.pytorch.utils import tensors_to_device from llmcompressor.transformers import oneshot from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments +from llmcompressor.transformers.utils.arg_parser import DatasetArguments from tests.testing_utils import parse_params, requires_gpu CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/compression/configs" @@ -59,10 +59,9 @@ def _run_oneshot(model, recipe, dataset, output_dir): max_seq_length = 512 pad_to_max_length = False - oneshot( + oneshot_run = oneshot( model=model, dataset=dataset, - overwrite_output_dir=True, output_dir=output_dir, max_seq_length=max_seq_length, num_calibration_samples=num_calibration_samples, @@ -72,10 +71,8 @@ def _run_oneshot(model, recipe, dataset, output_dir): splits={"calibration": "train_gen[:5%]"}, save_compressed=False, ) - from llmcompressor.pytorch.model_load.helpers import get_session_model - # note: get_session_model() is None outside of function scope - return get_session_model() + return oneshot_run.model def _get_quant_info(self, model): quant_info_weights = {} @@ -147,7 +144,7 @@ def _get_dataloader(self, data_args, tokenizer): @torch.no_grad() def test_perplexity(self): tokenizer = AutoTokenizer.from_pretrained(self.model_stub) - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset="ultrachat-200k", max_seq_length=self.max_seq_length, ) diff --git a/tests/llmcompressor/transformers/finetune/data/conftest.py b/tests/llmcompressor/transformers/finetune/data/conftest.py index a7a347d99..a4182721d 100644 --- a/tests/llmcompressor/transformers/finetune/data/conftest.py +++ b/tests/llmcompressor/transformers/finetune/data/conftest.py @@ -1,7 +1,7 @@ import pytest from transformers import AutoTokenizer -from llmcompressor.transformers.finetune.model_args import ModelArguments +from llmcompressor.transformers.utils.arg_parser import ModelArguments @pytest.fixture diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py index 812b26a56..4b907b6a0 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py @@ -1,15 +1,15 @@ import pytest -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.transformers.finetune.data.data_helpers import ( get_raw_dataset, make_dataset_splits, ) +from llmcompressor.transformers.utils.arg_parser import DatasetArguments @pytest.mark.unit def test_combined_datasets(): - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1" ) raw_wikitext2 = get_raw_dataset(data_args) @@ -33,7 +33,7 @@ def test_combined_datasets(): @pytest.mark.unit def test_separate_datasets(): splits = {"train": "train[:10%]", "validation": "train[10%:20%]"} - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1" ) datasets = {} diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py index 64514b252..75be8102c 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py @@ -5,22 +5,23 @@ from datasets import IterableDataset, load_dataset from parameterized import parameterized -from llmcompressor.transformers import ( - DataTrainingArguments, - ModelArguments, - TextGenerationDataset, - TrainingArguments, -) +from llmcompressor.transformers import TextGenerationDataset from llmcompressor.transformers.finetune.data.data_helpers import ( format_calibration_data, ) from llmcompressor.transformers.finetune.runner import StageRunner +from llmcompressor.transformers.utils.arg_parser import ( + DatasetArguments, + ModelArguments, + RecipeArguments, + TrainingArguments, +) @pytest.mark.unit class TestConcentrationTokenization(unittest.TestCase): def setUp(self): - self.data_args = DataTrainingArguments( + self.data_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1", concatenate_data=True, @@ -53,7 +54,7 @@ def test_concatenation_tokenization(self): @pytest.mark.unit class TestNoPaddingTokenization(unittest.TestCase): def setUp(self): - self.data_args = DataTrainingArguments( + self.data_args = DatasetArguments( dataset="open_platypus", pad_to_max_length=False ) @@ -96,9 +97,7 @@ def test_no_padding_tokenization(self): @pytest.mark.unit class TestMaxSeqLenClipped(unittest.TestCase): def setUp(self): - self.data_args = DataTrainingArguments( - dataset="open_platypus", max_seq_length=4096 - ) + self.data_args = DatasetArguments(dataset="open_platypus", max_seq_length=4096) @pytest.fixture(autouse=True) def prepare_fixture(self, tiny_llama_tokenizer): @@ -120,7 +119,7 @@ def test_max_seq_len_clipped(self): @pytest.mark.unit class TestDatasetKwargsAndPercent(unittest.TestCase): def setUp(self): - self.data_args = DataTrainingArguments( + self.data_args = DatasetArguments( dataset="wikitext", raw_kwargs={ "data_files": { @@ -167,7 +166,7 @@ def prepare_fixture(self, tiny_llama_tokenizer): ] ) def test_datasets(self, dataset_key, dataset_config, split, do_concat): - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset=dataset_key, dataset_config_name=dataset_config, concatenate_data=do_concat, @@ -206,7 +205,7 @@ def prepare_fixture(self, tiny_llama_tokenizer): self.tiny_llama_tokenizer = tiny_llama_tokenizer def setUp(self): - self.data_args = DataTrainingArguments( + self.data_args = DatasetArguments( dataset="evolcodealpaca", dataset_config_name=None, concatenate_data=False, @@ -235,7 +234,7 @@ def test_evol(self): @pytest.mark.unit class TestStreamLoading(unittest.TestCase): def setUp(self): - self.data_args = DataTrainingArguments( + self.data_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1", concatenate_data=True, @@ -276,15 +275,19 @@ def prepare_fixture(self, tiny_llama_tokenizer): [["train"], ["train[60%:]"], [{"train": "train[:20%]"}], [None]] ) def test_split_loading(self, split_def): - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset="open_platypus", splits=split_def, trust_remote_code_data=True, ) training_args = TrainingArguments(do_train=True, output_dir="dummy") model_args = ModelArguments(model=None) + recipe_args = RecipeArguments() stage_runner = StageRunner( - model_args=model_args, data_args=data_args, training_args=training_args + model_args=model_args, + data_args=data_args, + training_args=training_args, + recipe_args=recipe_args, ) stage_runner.populate_datasets(processor=self.tiny_llama_tokenizer) @@ -318,10 +321,11 @@ def preprocess(sample): ) stage_runner = StageRunner( model_args=None, - data_args=DataTrainingArguments( + data_args=DatasetArguments( dataset=tokenized_dataset, shuffle_calibration_samples=False ), training_args=TrainingArguments(do_oneshot=True), + recipe_args=RecipeArguments(), ) stage_runner.populate_datasets(processor=None) calib_dataset = stage_runner.get_dataset_split("calibration") diff --git a/tests/llmcompressor/transformers/finetune/data/test_registry.py b/tests/llmcompressor/transformers/finetune/data/test_registry.py index 9aee4c20f..11dc9034f 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_registry.py +++ b/tests/llmcompressor/transformers/finetune/data/test_registry.py @@ -6,12 +6,12 @@ TextGenerationDataset, WikiTextDataset, ) -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments +from llmcompressor.transformers.utils.arg_parser import DatasetArguments @pytest.mark.usefixtures("tiny_llama_tokenizer") def test_c4_initializes(tiny_llama_tokenizer): - data_args = DataTrainingArguments(dataset="c4", concatenate_data=True) + data_args = DatasetArguments(dataset="c4", concatenate_data=True) c4_manager = TextGenerationDataset.load_from_registry( data_args.dataset, data_args=data_args, @@ -27,7 +27,7 @@ def test_c4_initializes(tiny_llama_tokenizer): @pytest.mark.usefixtures("tiny_llama_tokenizer") def test_wikitext_initializes(tiny_llama_tokenizer): - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1" ) wiki_manager = TextGenerationDataset.load_from_registry( @@ -45,7 +45,7 @@ def test_wikitext_initializes(tiny_llama_tokenizer): @pytest.mark.usefixtures("tiny_llama_tokenizer") def test_open_platypus_initializes(tiny_llama_tokenizer): - data_args = DataTrainingArguments(dataset="open_platypus", pad_to_max_length=False) + data_args = DatasetArguments(dataset="open_platypus", pad_to_max_length=False) op_manager = TextGenerationDataset.load_from_registry( data_args.dataset, data_args=data_args, diff --git a/tests/llmcompressor/transformers/gptq/test_oneshot.py b/tests/llmcompressor/transformers/gptq/test_oneshot.py index 7f1a1ec99..d75386b94 100644 --- a/tests/llmcompressor/transformers/gptq/test_oneshot.py +++ b/tests/llmcompressor/transformers/gptq/test_oneshot.py @@ -75,7 +75,6 @@ def test_oneshot_application(self): model=self.model, dataset=self.dataset, output_dir=self.output, - overwrite_output_dir=True, recipe=self.recipe, oneshot_device=self.device, num_calibration_samples=9, diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py index fe699570a..6a6fc9bf3 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py @@ -23,12 +23,10 @@ def labeled_dataloader(self, dataset_name, model_name): from transformers import AutoTokenizer, DefaultDataCollator from llmcompressor.transformers.finetune.data import TextGenerationDataset - from llmcompressor.transformers.finetune.data.data_args import ( - DataTrainingArguments, - ) + from llmcompressor.transformers.utils.arg_parser import DatasetArguments tokenizer = AutoTokenizer.from_pretrained(model_name) - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset=dataset_name, max_seq_length=512, pad_to_max_length=False, diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_sparsity.py b/tests/llmcompressor/transformers/obcq/test_obcq_sparsity.py index 0ef7f872d..f370d5ee1 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_sparsity.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_sparsity.py @@ -26,11 +26,10 @@ def setUp(self): self.output = "./oneshot_output" def test_sparsities(self): - from llmcompressor.pytorch.model_load.helpers import get_session_model from llmcompressor.pytorch.utils.helpers import tensor_sparsity from llmcompressor.transformers import oneshot - oneshot( + oneshot = oneshot( model=self.model, dataset=self.dataset, oneshot_device=self.device, @@ -42,7 +41,7 @@ def test_sparsities(self): output_dir=self.output, ) - model = get_session_model() + model = oneshot.model layer_1_sparse = tensor_sparsity(model.model.layers[1].self_attn.k_proj.weight) assert math.isclose(layer_1_sparse.item(), self.sparsity, rel_tol=1e-4) diff --git a/tests/llmcompressor/transformers/oneshot/test_cli.py b/tests/llmcompressor/transformers/oneshot/test_cli.py index 5780ca46f..803d624a3 100644 --- a/tests/llmcompressor/transformers/oneshot/test_cli.py +++ b/tests/llmcompressor/transformers/oneshot/test_cli.py @@ -41,16 +41,20 @@ def test_one_shot_cli(self): "--recipe", self.recipe, "--num_calibration_samples", - "10", + "16", "--pad_to_max_length", "False", ] if len(self.additional_args) > 0: cmd.extend(self.additional_args) + res = run_cli_command(cmd) - self.assertEqual(res.returncode, 0) - print(res.stdout) + + # oneshot has return arg + self.assertIsNone(res.stderr) def tearDown(self): - shutil.rmtree(self.output) + # if a test case was skipped + if hasattr(self, "output"): + shutil.rmtree(self.output) diff --git a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py index e2f51ab8e..c1fc65a51 100644 --- a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py +++ b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py @@ -90,7 +90,7 @@ def test_sparse_model_reload(compressed, config, dtype, tmp_path): rel_tol=1e-3, ) - inferred_structure = SparsityConfigMetadata.infer_sparsity_structure() + inferred_structure = SparsityConfigMetadata.infer_sparsity_structure(model) assert inferred_structure == "0:0" model.save_pretrained( @@ -167,8 +167,6 @@ def test_dense_model_save(tmp_path, skip_compression_stats, save_compressed): ], ) def test_quant_model_reload(format, dtype, tmp_path): - from llmcompressor.pytorch.model_load.helpers import get_session_model - recipe_str = ( "tests/llmcompressor/transformers/compression/recipes/new_quant_simple.yaml" ) @@ -182,7 +180,7 @@ def test_quant_model_reload(format, dtype, tmp_path): splits = {"calibration": "train[:10%]"} # create a quantized model - oneshot( + oneshot_run = oneshot( model=model_path, dataset=dataset, num_calibration_samples=num_calibration_samples, @@ -195,7 +193,7 @@ def test_quant_model_reload(format, dtype, tmp_path): ) # Fetch the oneshot model - model = get_session_model() + model = oneshot_run.model og_state_dict = model.state_dict() save_path_compressed = tmp_path / "compressed"