From eb281bef3f5d96a71fd59a1ee1dd4b3c94f2158c Mon Sep 17 00:00:00 2001
From: David Corvoysier
Date: Mon, 30 Sep 2024 12:19:15 +0000
Subject: [PATCH 1/2] fix(transformers): add missing call implementation

---
 optimum/quanto/models/transformers_models.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/optimum/quanto/models/transformers_models.py b/optimum/quanto/models/transformers_models.py
index 2d031a69..24580a0d 100644
--- a/optimum/quanto/models/transformers_models.py
+++ b/optimum/quanto/models/transformers_models.py
@@ -56,6 +56,9 @@ def __getattr__(self, name: str) -> Any:
     def forward(self, *args, **kwargs):
         return self._wrapped.forward(*args, **kwargs)
 
+    def __call__(self, *args, **kwargs):
+        return self._wrapped.forward(*args, **kwargs)
+
     @staticmethod
     def _qmap_name():
         return f"{QuantizedTransformersModel.BASE_NAME}_qmap.json"

From 1a73135a8d10a07cd973380527d981ba6f04f185 Mon Sep 17 00:00:00 2001
From: David Corvoysier
Date: Mon, 30 Sep 2024 12:19:56 +0000
Subject: [PATCH 2/2] feat(examples): use QuantizedModelForCausalLM

---
 examples/nlp/text-generation/quantize_causal_lm_model.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/examples/nlp/text-generation/quantize_causal_lm_model.py b/examples/nlp/text-generation/quantize_causal_lm_model.py
index 4099286c..da279ab2 100644
--- a/examples/nlp/text-generation/quantize_causal_lm_model.py
+++ b/examples/nlp/text-generation/quantize_causal_lm_model.py
@@ -19,7 +19,7 @@
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from optimum.quanto import Calibration, freeze, qfloat8, qint4, qint8, quantize
+from optimum.quanto import Calibration, QuantizedModelForCausalLM, qfloat8, qint4, qint8
 
 
 @torch.no_grad()
@@ -133,15 +133,14 @@ def main():
     print(f"{args.model} (w: {args.weights}, a: {args.activations})")
     weights = keyword_to_itype(args.weights)
     activations = keyword_to_itype(args.activations)
-    quantize(model, weights=weights, activations=activations)
+    qmodel = QuantizedModelForCausalLM.quantize(model, weights=weights, activations=activations)
     if activations is not None:
         print("Calibrating ...")
         cal_dataset.shuffle(args.seed)
         with Calibration(streamline=args.no_streamline, debug=args.debug):
             cal_samples = args.batch_size * args.validation_batch
-            calibrate(model, tokenizer, cal_dataset, device, args.batch_size, samples=cal_samples)
-    freeze(model)
-    generate(model, tokenizer, device, args.prompt, args.max_new_tokens)
+            calibrate(qmodel, tokenizer, cal_dataset, device, args.batch_size, samples=cal_samples)
+    generate(qmodel, tokenizer, device, args.prompt, args.max_new_tokens)
 
 
 if __name__ == "__main__":
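-- 
Usage note (placed below the "-- " signature marker so `git am` ignores it):
after PATCH 2/2, the example drives quantization through the wrapper API
rather than the low-level quantize()/freeze() helpers. A minimal sketch of
the resulting flow; the "facebook/opt-125m" checkpoint and the qint8 weight
type are arbitrary choices for illustration, not taken from the patches:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from optimum.quanto import QuantizedModelForCausalLM, qint8

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

    # The classmethod quantize() wraps the model and quantizes its weights.
    # PATCH 2/2 drops the explicit freeze(model) call, which suggests the
    # wrapper takes care of freezing internally.
    qmodel = QuantizedModelForCausalLM.quantize(model, weights=qint8)

    # __getattr__ forwards unknown attributes (e.g. generate) to the wrapped
    # model, and PATCH 1/2 adds the __call__ needed so qmodel(...) works like
    # a direct call on the underlying model.
    inputs = tokenizer("Hello, my name is", return_tensors="pt")
    outputs = qmodel.generate(**inputs, max_new_tokens=20, do_sample=False)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))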