Skip to content

Commit

Permalink
feat(examples): use QuantizedModelForCausalLM
Browse files · Browse the repository at this point in the history
  • Loading branch information
dacorvo committed Sep 30, 2024
1 parent dfe8261 commit 13b2b0f
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions examples/nlp/text-generation/quantize_causal_lm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.quanto import Calibration, freeze, qfloat8, qint4, qint8, quantize
from optimum.quanto import Calibration, QuantizedModelForCausalLM, qfloat8, qint4, qint8


@torch.no_grad()
Expand Down Expand Up @@ -133,15 +133,14 @@ def main():
print(f"{args.model} (w: {args.weights}, a: {args.activations})")
weights = keyword_to_itype(args.weights)
activations = keyword_to_itype(args.activations)
quantize(model, weights=weights, activations=activations)
qmodel = QuantizedModelForCausalLM.quantize(model, weights=weights, activations=activations)
if activations is not None:
print("Calibrating ...")
cal_dataset.shuffle(args.seed)
with Calibration(streamline=args.no_streamline, debug=args.debug):
cal_samples = args.batch_size * args.validation_batch
calibrate(model, tokenizer, cal_dataset, device, args.batch_size, samples=cal_samples)
freeze(model)
generate(model, tokenizer, device, args.prompt, args.max_new_tokens)
calibrate(qmodel, tokenizer, cal_dataset, device, args.batch_size, samples=cal_samples)
generate(qmodel, tokenizer, device, args.prompt, args.max_new_tokens)


if __name__ == "__main__":
Expand Down

0 comments on commit 13b2b0f

Please sign in to comment.