VLM: Model Tracing Guide (#1030)
## Purpose ##
This guide explains the concepts of tracing as they relate to LLM
Compressor and how to modify your model definition to support recipes that
require the Sequential Pipeline.

By reading this guide, you will learn:
1. Why tracing is required when compressing with recipes involving the
Sequential Pipeline and modifiers such as GPTQModifier (a sample recipe is
sketched below)
2. How to determine if your model is traceable for your dataset
3. How to modify your model definition to be traceable
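
For context, here is a minimal sketch of a run that exercises the sequential pipeline; the model stub, dataset, and quantization scheme are illustrative, not part of this change:

```python
# Minimal sketch of a GPTQ compression run that uses the sequential pipeline.
# Model stub, dataset, and scheme below are illustrative assumptions.
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

recipe = GPTQModifier(
    targets="Linear",          # quantize all Linear modules...
    scheme="W4A16",            # ...to 4-bit weights, 16-bit activations
    ignore=["re:.*lm_head"],   # leave the output head uncompressed
)

# oneshot calibrates and compresses layer by layer; this is where the
# sequential pipeline traces the model into subgraphs
oneshot(
    model="llava-hf/llava-1.5-7b-hf",
    dataset="flickr30k",
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)
```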

## Prerequisites ##
* #1031

## Changes ##
* Add a model tracing guide
`src/llmcompressor/transformers/tracing/GUIDE.md` with pictures
* Add a readme for the sequential pipeline which points to the Tracing
Guide `src/llmcompressor/pipelines/sequential/README.md`
* Add a debug script to help users debug their models for traceability
`src/llmcompressor/transformers/tracing/debug.py`
  * Add the `llmcompressor.trace` entrypoint for ease of use
* Swap the order of arguments in `llava_example.py` and
`pixtral_example.py` to match the order of arguments on the modifier

## Testing ##
Use the `llmcompressor.trace` debug script:
```bash
llmcompressor.trace \
    --model_id llava-hf/llava-1.5-7b-hf \
    --model_class TraceableLlavaForConditionalGeneration \
    --sequential_targets LlamaDecoderLayer \
    --ignore "re:.*lm_head" "re:vision_tower.*" "re:multi_modal_projector.*" \
    --modality vision
```
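
On success, the script prints the resolved trace settings and ends with
`Successfully traced model into N subgraphs!`; a `torch.fx.proxy.TraceError`
instead means the model definition needs the changes described in the Tracing
Guide.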

## Stretch ##
It might be nice if this tracing debug tool also printed the model graph
to an SVG, e.g. along the lines of the sketch below.
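
A hedged sketch of how that could look with stock `torch.fx` utilities (not
part of this PR; assumes `pydot` and graphviz are installed):

```python
# Hypothetical follow-up: render a traced module's graph to SVG.
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.graph_drawer import FxGraphDrawer


def dump_graph_svg(module: torch.nn.Module, name: str, path: str) -> None:
    traced = symbolic_trace(module)        # torch.fx GraphModule
    drawer = FxGraphDrawer(traced, name)   # builds a pydot graph of the fx graph
    drawer.get_dot_graph().write_svg(path)


dump_graph_svg(torch.nn.Sequential(torch.nn.Linear(8, 8)), "toy", "toy.svg")
```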

---------

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
3 people authored Jan 23, 2025
1 parent 6377f1e commit e48d9db
Showing 10 changed files with 5,908 additions and 1 deletion.
1 change: 1 addition & 0 deletions setup.py
@@ -94,6 +94,7 @@
"llmcompressor.transformers.text_generation.finetune=llmcompressor.transformers.finetune.text_generation:train", # noqa 501
"llmcompressor.transformers.text_generation.eval=llmcompressor.transformers.finetune.text_generation:eval", # noqa 501
"llmcompressor.transformers.text_generation.oneshot=llmcompressor.transformers.finetune.text_generation:oneshot", # noqa 501
"llmcompressor.trace=llmcompressor.transformers.tracing.debug:main",
]
},
python_requires=">=3.8",
6 changes: 5 additions & 1 deletion src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -244,7 +244,11 @@ def on_initialize(self, state: State, **kwargs) -> bool:

        except Exception as exception:
            if isinstance(exception, torch.fx.proxy.TraceError):
-               warnings.warn(f"Failed to trace {model_name} with inputs {input_names}")
+               warnings.warn(
+                   f"Failed to trace {model_name} with inputs {input_names}. For more "
+                   "information on tracing with the sequential pipeline, see "
+                   "`src/llmcompressor/transformers/tracing/GUIDE.md`"
+               )
            if isinstance(exception, unfixable_errors):
                raise exception

6 changes: 6 additions & 0 deletions src/llmcompressor/pipelines/sequential/README.md
@@ -0,0 +1,6 @@
# Sequential Pipeline #
The sequential pipeline is a data pipeline, primarily used for compressing models with the
[GPTQModifier](/src/llmcompressor/modifiers/quantization/gptq/base.py).

If, when using this pipeline, you encounter a `torch.fx.proxy.TraceError`, see the
[Model Tracing Guide](/src/llmcompressor/transformers/tracing/GUIDE.md).
441 changes: 441 additions & 0 deletions src/llmcompressor/transformers/tracing/GUIDE.md

Large diffs are not rendered by default.

5,319 changes: 5,319 additions & 0 deletions src/llmcompressor/transformers/tracing/assets/Llama_3.2-Vision.svg
136 changes: 136 additions & 0 deletions src/llmcompressor/transformers/tracing/debug.py
@@ -0,0 +1,136 @@
from typing import List, Type, Union, Optional, Dict

import argparse

import torch
import transformers
from transformers import AutoProcessor, PreTrainedModel

from llmcompressor.transformers import tracing
from llmcompressor.utils.pytorch.module import get_no_split_params
from llmcompressor.pipelines.sequential.helpers import trace_subgraphs
from llmcompressor.transformers import DataTrainingArguments, TextGenerationDataset


def parse_args():
    parser = argparse.ArgumentParser(description="Trace a model into subgraphs")
    parser.add_argument("--model_id", type=str, required=True, help="The stub of the model to load")  # noqa: E501
    parser.add_argument("--model_class", type=str, required=True, help="The class name of the model")  # noqa: E501
    parser.add_argument("--sequential_targets", type=str, nargs="*", default=None, metavar="TARGET", help="List of targets for sequential tracing")  # noqa: E501
    parser.add_argument("--ignore", type=str, nargs="*", default=[], metavar="PATTERN", help="List of patterns to ignore during tracing")  # noqa: E501
    parser.add_argument("--modality", type=str, default="text", help="Modality of calibration dataset, defaults to text")  # noqa: E501
    return parser.parse_args()


def trace(
    model_id: str,
    model_class: Type[PreTrainedModel],
    sequential_targets: Optional[Union[List[str], str]] = None,
    ignore: Union[List[str], str] = [],
    modality: str = "text",
):
    """
    Debug traceability by tracing a pre-trained model into subgraphs

    :param model_id: stub of the model to load
    :param model_class: class constructor of the pre-trained model. Can use either
        HF transformers classes or `Traceable` classes defined by LLM Compressor
    :param sequential_targets: targets for sequential tracing, defaults to automatic
        inference
    :param ignore: patterns to ignore during tracing
    :param modality: data modality for dummy tracing data, defaults to 'text'

    Example usage from CLI
    llmcompressor.trace \
        --model_id Qwen/Qwen2-VL-2B-Instruct \
        --model_class Qwen2VLForConditionalGeneration \
        --sequential_targets Qwen2VLDecoderLayer \
        --ignore "lm_head" "re:visual.*" \
        --modality text
    """
    # Load model
    model = model_class.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype="auto",
    )
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    print("Loaded model")

    # Prepare sample data
    data_args = DataTrainingArguments(**get_dataset_kwargs(modality))
    dataset = TextGenerationDataset.load_from_registry(
        data_args.dataset,
        data_args=data_args,
        split=data_args.splits["calibration"],
        processor=processor,
    )(add_labels=False)
    sample_input = next(iter(dataset))
    sample_input = {k: torch.tensor(v) for k, v in sample_input.items()}
    print("Loaded sample data")

    # infer sequential targets
    if sequential_targets is None:
        sequential_targets = get_no_split_params(model)
    if isinstance(sequential_targets, str):
        sequential_targets = [sequential_targets]

    # infer ignore
    if isinstance(ignore, str):
        ignore = [ignore]

    # Attempt trace
    print(
        "\nAttempting trace\n"
        f"    model_id={model_id}\n"
        f"    model_class={model_class.__name__}\n"
        f"    dataset={data_args.dataset}\n"
        f"    split={dataset.split}\n"
        f"    inputs={sample_input.keys()}\n"
        f"    sequential_targets={sequential_targets}\n"
        f"    ignore={ignore}\n"
    )
    subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore)
    print(f"Successfully traced model into {len(subgraphs)} subgraphs!\n")


def get_model_class(model_class: str) -> Type[PreTrainedModel]:
    model_cls = getattr(tracing, model_class, getattr(transformers, model_class, None))
    if model_cls is None:
        raise ValueError(f"Could not import model class {model_class}")

    return model_cls


def get_dataset_kwargs(modality: str) -> Dict[str, str]:
    dataset_kwargs = {
        "text": {
            "dataset": "ultrachat-200k",
            "splits": {"calibration": "test_sft[:1]"},
        },
        "vision": {
            "dataset": "flickr",
            "splits": {"calibration": "test[:1]"},
        },
    }

    if modality not in dataset_kwargs:
        raise ValueError(f"Modality must be one of {list(dataset_kwargs.keys())}")

    return dataset_kwargs[modality]


def main():
    args = parse_args()

    trace(
        model_id=args.model_id,
        model_class=get_model_class(args.model_class),
        sequential_targets=args.sequential_targets,
        ignore=args.ignore,
        modality=args.modality,
    )


if __name__ == "__main__":
    main()

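For reference, the same check can be run from Python by calling `trace` directly. A minimal sketch, using the traceable LLaVA class from the Testing section above:

```python
# Sketch: run the traceability check programmatically instead of via the CLI.
from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration
from llmcompressor.transformers.tracing.debug import trace

trace(
    model_id="llava-hf/llava-1.5-7b-hf",
    model_class=TraceableLlavaForConditionalGeneration,
    sequential_targets=["LlamaDecoderLayer"],
    ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"],
    modality="vision",  # uses the flickr calibration sample from get_dataset_kwargs
)
```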