Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix missing call in QuantizedTransformersModel #325

Merged
merged 2 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions examples/nlp/text-generation/quantize_causal_lm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.quanto import Calibration, freeze, qfloat8, qint4, qint8, quantize
from optimum.quanto import Calibration, QuantizedModelForCausalLM, qfloat8, qint4, qint8


@torch.no_grad()
Expand Down Expand Up @@ -133,15 +133,14 @@ def main():
print(f"{args.model} (w: {args.weights}, a: {args.activations})")
weights = keyword_to_itype(args.weights)
activations = keyword_to_itype(args.activations)
quantize(model, weights=weights, activations=activations)
qmodel = QuantizedModelForCausalLM.quantize(model, weights=weights, activations=activations)
if activations is not None:
print("Calibrating ...")
cal_dataset.shuffle(args.seed)
with Calibration(streamline=args.no_streamline, debug=args.debug):
cal_samples = args.batch_size * args.validation_batch
calibrate(model, tokenizer, cal_dataset, device, args.batch_size, samples=cal_samples)
freeze(model)
generate(model, tokenizer, device, args.prompt, args.max_new_tokens)
calibrate(qmodel, tokenizer, cal_dataset, device, args.batch_size, samples=cal_samples)
generate(qmodel, tokenizer, device, args.prompt, args.max_new_tokens)


if __name__ == "__main__":
Expand Down
3 changes: 3 additions & 0 deletions optimum/quanto/models/transformers_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def __getattr__(self, name: str) -> Any:
def forward(self, *args, **kwargs):
return self._wrapped.forward(*args, **kwargs)

def __call__(self, *args, **kwargs):
return self._wrapped.forward(*args, **kwargs)

@staticmethod
def _qmap_name():
return f"{QuantizedTransformersModel.BASE_NAME}_qmap.json"
Expand Down
Loading