
Commit

squash
Signed-off-by: Kyle Sayers <[email protected]>
kylesayrs committed Jan 23, 2025
1 parent d5984db commit 59bdb66
Showing 38 changed files with 7,022 additions and 1,335 deletions.
11 changes: 9 additions & 2 deletions examples/multimodal_vision/llava_example.py
@@ -1,11 +1,11 @@
 import requests
+import torch
 from PIL import Image
 from transformers import AutoProcessor
 
 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.transformers import oneshot
 from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration
-from llmcompressor.transformers.utils.data_collator import llava_data_collator
 
 # Load model.
 model_id = "llava-hf/llava-1.5-7b-hf"
@@ -20,6 +20,13 @@
 NUM_CALIBRATION_SAMPLES = 512
 MAX_SEQUENCE_LENGTH = 2048
 
+
+# Define a oneshot data collator for multimodal inputs.
+def data_collator(batch):
+    assert len(batch) == 1
+    return {key: torch.tensor(value) for key, value in batch[0].items()}
+
+
 # Recipe
 recipe = [
     GPTQModifier(
@@ -40,7 +47,7 @@
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
-    data_collator=llava_data_collator,
+    data_collator=data_collator,
 )
 
 # Confirm generations of the quantized model look sane.
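Each example in this commit swaps a per-model collator import for the same inline definition. Oneshot calibration feeds the model one sample at a time, so the collator only needs to turn that sample's list-valued features (the output of ds.map) back into tensors. A minimal sketch of the behavior, with invented feature shapes:

import torch


# Hypothetical processor output: after dataset mapping, each feature is a
# plain nested list rather than a tensor.
sample = {
    "input_ids": [[101, 2009, 2003, 102]],
    "attention_mask": [[1, 1, 1, 1]],
    "pixel_values": [[[0.5] * 4] * 4],
}


def data_collator(batch):
    assert len(batch) == 1  # oneshot calibration uses batch_size=1
    return {key: torch.tensor(value) for key, value in batch[0].items()}


collated = data_collator([sample])
print({key: value.shape for key, value in collated.items()})
# {'input_ids': torch.Size([1, 4]), 'attention_mask': torch.Size([1, 4]),
#  'pixel_values': torch.Size([1, 4, 4])}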
11 changes: 9 additions & 2 deletions examples/multimodal_vision/mllama_example.py
@@ -1,11 +1,11 @@
 import requests
+import torch
 from PIL import Image
 from transformers import AutoProcessor
 
 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.transformers import oneshot
 from llmcompressor.transformers.tracing import TraceableMllamaForConditionalGeneration
-from llmcompressor.transformers.utils.data_collator import mllama_data_collator
 
 # Load model.
 model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
@@ -20,6 +20,13 @@
 NUM_CALIBRATION_SAMPLES = 512
 MAX_SEQUENCE_LENGTH = 2048
 
+
+# Define a oneshot data collator for multimodal inputs.
+def data_collator(batch):
+    assert len(batch) == 1
+    return {key: torch.tensor(value) for key, value in batch[0].items()}
+
+
 # Recipe
 recipe = [
     GPTQModifier(
@@ -39,7 +46,7 @@
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
-    data_collator=mllama_data_collator,
+    data_collator=data_collator,
 )
 
 # Confirm generations of the quantized model look sane.
12 changes: 10 additions & 2 deletions examples/multimodal_vision/phi3_vision_example.py
@@ -1,9 +1,10 @@
+import torch
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoProcessor
 
 from llmcompressor.modifiers.quantization import GPTQModifier
+from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
 from llmcompressor.transformers import oneshot
-from llmcompressor.transformers.utils.data_collator import phi3_vision_data_collator
 
 # Load model.
 model_id = "microsoft/Phi-3-vision-128k-instruct"
@@ -59,8 +60,15 @@ def tokenize(sample):
 ds = ds.map(tokenize, writer_batch_size=1, remove_columns=ds.column_names)
 
 
+# Define a oneshot data collator for multimodal inputs.
+def data_collator(batch):
+    assert len(batch) == 1
+    return {key: torch.tensor(value) for key, value in batch[0].items()}
+
+
 # Recipe
 recipe = [
+    SmoothQuantModifier(smoothing_strength=0.8),
     GPTQModifier(
         targets="Linear",
         scheme="W4A16",
@@ -77,7 +85,7 @@ def tokenize(sample):
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
-    data_collator=phi3_vision_data_collator,
+    data_collator=data_collator,
 )
 
 # Confirm generations of the quantized model look sane.
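Beyond the collator change, the phi3_vision recipe now runs SmoothQuantModifier ahead of GPTQModifier. Recipe modifiers are applied in list order, so activation outliers are first migrated into the weights and GPTQ then quantizes the smoothed weights. A sketch of the composed recipe (the ignore list here is illustrative, not taken from this diff):

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier

# Applied in order: smooth activation outliers into the weights first,
# then quantize the smoothed Linear weights to W4A16.
recipe = [
    SmoothQuantModifier(smoothing_strength=0.8),
    GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]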
15 changes: 13 additions & 2 deletions examples/multimodal_vision/pixtral_example.py
@@ -1,11 +1,11 @@
 import requests
+import torch
 from PIL import Image
 from transformers import AutoProcessor
 
 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.transformers import oneshot
 from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration
-from llmcompressor.transformers.utils.data_collator import pixtral_data_collator
 
 # Load model.
 model_id = "mgoin/pixtral-12b"
@@ -20,6 +20,17 @@
 NUM_CALIBRATION_SAMPLES = 512
 MAX_SEQUENCE_LENGTH = 2048
 
+
+# Define a oneshot data collator for multimodal inputs.
+def data_collator(batch):
+    assert len(batch) == 1
+    return {
+        "input_ids": torch.LongTensor(batch[0]["input_ids"]),
+        "attention_mask": torch.tensor(batch[0]["attention_mask"]),
+        "pixel_values": torch.tensor(batch[0]["pixel_values"])[0],
+    }
+
+
 # Recipe
 recipe = [
     GPTQModifier(
@@ -40,7 +51,7 @@
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
-    data_collator=pixtral_data_collator,
+    data_collator=data_collator,
 )
 
 # Confirm generations of the quantized model look sane.
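Pixtral is the one example that keeps a hand-written collator body rather than the generic dict comprehension: input_ids is cast explicitly to LongTensor, and pixel_values appears to carry an extra per-sample list level that the [0] index strips. A toy illustration of that unwrapping (shapes are invented):

import torch

# Invented pixtral-style feature: pixel_values nests a list of images
# inside the sample, adding one extra leading dimension.
pixel_values = [[[[0.5] * 2] * 2]]  # [sample][image][h][w]

as_tensor = torch.tensor(pixel_values)
print(as_tensor.shape)     # torch.Size([1, 1, 2, 2])
print(as_tensor[0].shape)  # torch.Size([1, 2, 2]) -- extra level removed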
11 changes: 9 additions & 2 deletions examples/multimodal_vision/qwen2_vl_example.py
@@ -1,14 +1,14 @@
 import base64
 from io import BytesIO
 
+import torch
 from datasets import load_dataset
 from qwen_vl_utils import process_vision_info
 from transformers import AutoProcessor
 
 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.transformers import oneshot
 from llmcompressor.transformers.tracing import TraceableQwen2VLForConditionalGeneration
-from llmcompressor.transformers.utils.data_collator import qwen2_vl_data_collator
 
 # Load model.
 model_id = "Qwen/Qwen2-VL-2B-Instruct"
@@ -65,6 +65,13 @@ def preprocess_and_tokenize(example):
 
 ds = ds.map(preprocess_and_tokenize, remove_columns=ds["calibration"].column_names)
 
+
+# Define a oneshot data collator for multimodal inputs.
+def data_collator(batch):
+    assert len(batch) == 1
+    return {key: torch.tensor(value) for key, value in batch[0].items()}
+
+
 # Recipe
 recipe = [
     GPTQModifier(
@@ -84,7 +91,7 @@ def preprocess_and_tokenize(example):
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
-    data_collator=qwen2_vl_data_collator,
+    data_collator=data_collator,
 )
 
 # Confirm generations of the quantized model look sane.
1 change: 1 addition & 0 deletions setup.py
@@ -94,6 +94,7 @@
             "llmcompressor.transformers.text_generation.finetune=llmcompressor.transformers.finetune.text_generation:train", # noqa 501
             "llmcompressor.transformers.text_generation.eval=llmcompressor.transformers.finetune.text_generation:eval", # noqa 501
             "llmcompressor.transformers.text_generation.oneshot=llmcompressor.transformers.finetune.text_generation:oneshot", # noqa 501
+            "llmcompressor.trace=llmcompressor.transformers.tracing.debug:main",
         ]
     },
     python_requires=">=3.8",
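The setup.py change registers a new console script: after installation, an llmcompressor.trace executable invokes main() from llmcompressor/transformers/tracing/debug.py. Per the setuptools console_scripts contract, the generated script behaves like this small launcher (assuming main() does its own argument parsing, as entry-point targets typically do):

# Equivalent of the generated `llmcompressor.trace` executable.
import sys

from llmcompressor.transformers.tracing.debug import main

if __name__ == "__main__":
    sys.exit(main())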
