## Purpose ##

This guide explains the concepts of tracing as they relate to LLM Compressor and how to modify your model to support recipes which require the Sequential Pipeline. Through reading this guide, you will learn:

1. Why tracing is required when compressing with recipes involving the Sequential Pipeline and modifiers such as GPTQModifier
2. How to determine if your model is traceable for your dataset
3. How to modify your model definition to be traceable

## Prerequisites ##

* #1031

## Changes ##

* Add a model tracing guide `src/llmcompressor/transformers/tracing/README.md` with pictures
* Add a readme for the sequential pipeline which points to the Tracing Guide: `src/llmcompressor/pipelines/sequential/README.md`
* Add a debug script to help users debug their models for traceability: `src/llmcompressor/transformers/tracing/debug.py`
* Add the `llmcompressor.attempt_trace` entrypoint for ease of use
* Swap the order of arguments in `llava_example.py` and `pixtral_example.py` to match the order of arguments on the modifier

## Testing ##

Use the `llmcompressor.attempt_trace` debug script (flags here match the argparse definitions in `debug.py` below):

```bash
llmcompressor.attempt_trace \
    --model_id llava-hf/llava-1.5-7b-hf \
    --model_class TraceableLlavaForConditionalGeneration \
    --sequential_targets LlamaDecoderLayer \
    --ignore "re:.*lm_head" "re:vision_tower.*" "re:multi_modal_projector.*" \
    --modality vision
```

## Stretch ##

It might be nice if this tracing debugger tool also printed the model graph to an SVG (see the sketch below).

---------

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
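As a sketch of that stretch goal (not implemented in this commit): torch.fx ships a Graphviz-based drawer that could render a traced graph module to an SVG, assuming `pydot` and Graphviz are installed. The helper name below is hypothetical.

```python
# Hypothetical sketch of the stretch goal: render a traced torch.fx graph to
# an SVG file. FxGraphDrawer is part of torch.fx; it requires pydot/Graphviz.
from torch.fx import GraphModule
from torch.fx.passes.graph_drawer import FxGraphDrawer


def dump_svg(graph_module: GraphModule, path: str) -> None:
    # Build a dot graph of the module's nodes and write it out as SVG
    drawer = FxGraphDrawer(graph_module, "traced_model")
    drawer.get_dot_graph().write_svg(path)
```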
Commit e48d9db, parent 6377f1e. 10 changed files: 5,908 additions, 1 deletion.
`src/llmcompressor/pipelines/sequential/README.md` (new file, 6 lines):
# Sequential Pipeline #

The sequential pipeline is a data pipeline, primarily used for compressing models with the
[GPTQModifier](/src/llmcompressor/modifiers/quantization/gptq/base.py).

If, when using this pipeline, you encounter a `torch.fx.proxy.TraceError`, see the
[Model Tracing Guide](/src/llmcompressor/transformers/tracing/GUIDE.md).
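For context, a minimal sketch of the kind of run that exercises this pipeline, assuming the `oneshot` entrypoint and the import paths below; the model stub and calibration settings are illustrative, not part of this commit:

```python
# A minimal sketch, assuming these import paths and the oneshot API.
# GPTQModifier routes calibration through the sequential pipeline, so a
# torch.fx.proxy.TraceError raised by a run like this is what the guide covers.
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

# Quantize all Linear layers to W4A16, leaving the lm_head unquantized
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # illustrative model stub
    dataset="ultrachat-200k",
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)
```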
`src/llmcompressor/transformers/tracing/README.md` (the tracing guide): large diff, not rendered.

`src/llmcompressor/transformers/tracing/assets/Llama_3.2-Vision.svg` (5,319 additions) and the guide's other SVG assets: image files, not displayed.
`src/llmcompressor/transformers/tracing/debug.py` (new file, 136 lines):
```python
from typing import List, Type, Union, Optional, Dict

import argparse

import torch
import transformers
from transformers import AutoProcessor, PreTrainedModel

from llmcompressor.transformers import tracing
from llmcompressor.utils.pytorch.module import get_no_split_params
from llmcompressor.pipelines.sequential.helpers import trace_subgraphs
from llmcompressor.transformers import DataTrainingArguments, TextGenerationDataset


def parse_args():
    parser = argparse.ArgumentParser(description="Trace a model into subgraphs")
    parser.add_argument("--model_id", type=str, required=True, help="The stub of the model to load")  # noqa: E501
    parser.add_argument("--model_class", type=str, required=True, help="The class name of the model")  # noqa: E501
    parser.add_argument("--sequential_targets", type=str, nargs="*", default=None, metavar="TARGET", help="List of targets for sequential tracing")  # noqa: E501
    parser.add_argument("--ignore", type=str, nargs="*", default=[], metavar="PATTERN", help="List of patterns to ignore during tracing")  # noqa: E501
    parser.add_argument("--modality", type=str, default="text", help="Modality of calibration dataset, defaults to text")  # noqa: E501
    return parser.parse_args()


def trace(
    model_id: str,
    model_class: Type[PreTrainedModel],
    sequential_targets: Optional[Union[List[str], str]] = None,
    ignore: Union[List[str], str] = [],
    modality: str = "text",
):
    """
    Debug traceability by tracing a pre-trained model into subgraphs

    :param model_id: stub of the model to load
    :param model_class: class constructor of the pre-trained model. Can use either
        HF transformers classes or `Traceable` classes defined by LLM Compressor
    :param sequential_targets: targets for sequential tracing, defaults to automatic
        inference
    :param ignore: patterns to ignore during tracing
    :param modality: data modality for dummy tracing data, defaults to 'text'

    Example usage from CLI
    llmcompressor.trace \
        --model_id Qwen/Qwen2-VL-2B-Instruct \
        --model_class Qwen2VLForConditionalGeneration \
        --sequential_targets Qwen2VLDecoderLayer \
        --ignore "lm_head" "re:visual.*" \
        --modality text
    """
    # Load model
    model = model_class.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype="auto",
    )
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    print("Loaded model")

    # Prepare sample data
    data_args = DataTrainingArguments(**get_dataset_kwargs(modality))
    dataset = TextGenerationDataset.load_from_registry(
        data_args.dataset,
        data_args=data_args,
        split=data_args.splits["calibration"],
        processor=processor,
    )(add_labels=False)
    sample_input = next(iter(dataset))
    sample_input = {k: torch.tensor(v) for k, v in sample_input.items()}
    print("Loaded sample data")

    # infer sequential targets
    if sequential_targets is None:
        sequential_targets = get_no_split_params(model)
    if isinstance(sequential_targets, str):
        sequential_targets = [sequential_targets]

    # infer ignore
    if isinstance(ignore, str):
        ignore = [ignore]

    # Attempt trace
    print(
        "\nAttempting trace\n"
        f"    model_id={model_id}\n"
        f"    model_class={model_class.__name__}\n"
        f"    dataset={data_args.dataset}\n"
        f"    split={dataset.split}\n"
        f"    inputs={sample_input.keys()}\n"
        f"    sequential_targets={sequential_targets}\n"
        f"    ignore={ignore}\n"
    )
    subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore)
    print(f"Successfully traced model into {len(subgraphs)} subgraphs!\n")


def get_model_class(model_class: str) -> Type[PreTrainedModel]:
    model_cls = getattr(tracing, model_class, getattr(transformers, model_class, None))
    if model_cls is None:
        raise ValueError(f"Could not import model class {model_class}")

    return model_cls


def get_dataset_kwargs(modality: str) -> Dict[str, str]:
    dataset_kwargs = {
        "text": {
            "dataset": "ultrachat-200k",
            "splits": {"calibration": "test_sft[:1]"},
        },
        "vision": {
            "dataset": "flickr",
            "splits": {"calibration": "test[:1]"},
        },
    }

    if modality not in dataset_kwargs:
        raise ValueError(f"Modality must be one of {list(dataset_kwargs.keys())}")

    return dataset_kwargs[modality]


def main():
    args = parse_args()

    trace(
        model_id=args.model_id,
        model_class=get_model_class(args.model_class),
        sequential_targets=args.sequential_targets,
        ignore=args.ignore,
        modality=args.modality,
    )


if __name__ == "__main__":
    main()
```
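When this script surfaces a `torch.fx.proxy.TraceError`, the fix happens in the model definition. One common pattern for making a definition traceable, whether or not it is the exact fix the Tracing Guide recommends for a given model, is to keep untraceable helper logic out of the trace; a minimal sketch using `torch.fx.wrap` (the function and its mask logic are illustrative, not taken from the guide):

```python
# A minimal sketch: torch.fx.wrap marks a free function as a leaf call, so the
# tracer records the call itself rather than tracing the data-dependent logic
# inside it. The function below is illustrative.
import torch


@torch.fx.wrap
def build_causal_mask(seq_len: int) -> torch.Tensor:
    # Tensor construction like this can raise a TraceError when traced through
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
```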