From 1a73135a8d10a07cd973380527d981ba6f04f185 Mon Sep 17 00:00:00 2001
From: David Corvoysier
Date: Mon, 30 Sep 2024 12:19:56 +0000
Subject: [PATCH] feat(examples): use QuantizedModelForCausalLM

---
 examples/nlp/text-generation/quantize_causal_lm_model.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/examples/nlp/text-generation/quantize_causal_lm_model.py b/examples/nlp/text-generation/quantize_causal_lm_model.py
index 4099286c..da279ab2 100644
--- a/examples/nlp/text-generation/quantize_causal_lm_model.py
+++ b/examples/nlp/text-generation/quantize_causal_lm_model.py
@@ -19,7 +19,7 @@
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from optimum.quanto import Calibration, freeze, qfloat8, qint4, qint8, quantize
+from optimum.quanto import Calibration, QuantizedModelForCausalLM, qfloat8, qint4, qint8
 
 
 @torch.no_grad()
@@ -133,15 +133,14 @@ def main():
     print(f"{args.model} (w: {args.weights}, a: {args.activations})")
     weights = keyword_to_itype(args.weights)
     activations = keyword_to_itype(args.activations)
-    quantize(model, weights=weights, activations=activations)
+    qmodel = QuantizedModelForCausalLM.quantize(model, weights=weights, activations=activations)
     if activations is not None:
         print("Calibrating ...")
         cal_dataset.shuffle(args.seed)
         with Calibration(streamline=args.no_streamline, debug=args.debug):
             cal_samples = args.batch_size * args.validation_batch
-            calibrate(model, tokenizer, cal_dataset, device, args.batch_size, samples=cal_samples)
-    freeze(model)
-    generate(model, tokenizer, device, args.prompt, args.max_new_tokens)
+            calibrate(qmodel, tokenizer, cal_dataset, device, args.batch_size, samples=cal_samples)
+    generate(qmodel, tokenizer, device, args.prompt, args.max_new_tokens)
 
 
 if __name__ == "__main__":
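
For reference, a minimal sketch (not part of the patch) of the QuantizedModelForCausalLM
workflow this example migrates to. It assumes the optimum-quanto API at the time of this
commit; the model id and output directory below are placeholders:

    from transformers import AutoModelForCausalLM
    from optimum.quanto import QuantizedModelForCausalLM, qint4

    # Load a regular transformers model, then quantize it. The classmethod returns
    # a quantized wrapper, replacing the old in-place quantize() + freeze() pair.
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    qmodel = QuantizedModelForCausalLM.quantize(model, weights=qint4)

    # The wrapper can be serialized and reloaded without re-quantizing.
    qmodel.save_pretrained("./opt-125m-quantized")
    reloaded = QuantizedModelForCausalLM.from_pretrained("./opt-125m-quantized")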