Skip to content

Commit

Permalink
Enable specifying a custom number of samples for rotation optimization
Browse files: browse the repository at this point in the history
  • Loading branch information
pablomlago committed Jan 16, 2025
1 parent 99c196f commit 5f736a8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
7 changes: 5 additions & 2 deletions src/brevitas_examples/llm/llm_quant/rotation_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,14 @@ def apply_rotation_optimization(
# Prepare dataset and model for training
train_dataset = _prepare_train_dataset(train_dataset)
model = _prepare_model(model)
# Get training arguments
training_args = parse_optimization_rotation_args(unknown_args)
# Enable skipping optimization
if training_args.max_steps <= 0:
return
# Remove hooks and empty cache before starting optimization
remove_hooks(model)
torch.cuda.empty_cache()
# Get training arguments
training_args = parse_optimization_rotation_args(unknown_args)
# Set to False the model parameters
for param in model.parameters():
param.requires_grad = False
Expand Down
21 changes: 20 additions & 1 deletion src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,20 @@ def quantize_llm(args, unknown_args=None):
device=None,
fuse_sequences=args.fuse_sequences)

if args.rotation in ["fused_no_fx_optimize"]:
# Load the data for rotation optimization
rot_calibration_loader = get_dataset_for_model(
args.model,
dataset_name=args.dataset,
tokenizer=tokenizer,
nsamples=args.nsamples_rot_calibration,
seqlen=args.seqlen,
split="train",
seed=args.seed,
require_fx=require_fx and args.export_target is not None,
device=None,
fuse_sequences=args.fuse_sequences)

device = next(iter(model.parameters())).device
print("Data loaded.")

Expand Down Expand Up @@ -467,7 +481,7 @@ def quantize_llm(args, unknown_args=None):
apply_rotation_optimization(
model=model,
tokenizer=tokenizer,
train_dataset=calibration_loader,
train_dataset=rot_calibration_loader,
unknown_args=unknown_args,
)
# Remove hooks from optimization
Expand Down Expand Up @@ -625,6 +639,11 @@ def parse_args(args, override_defaults={}):
type=int,
default=128,
help='Number of calibration data samples. Default: 128.')
parser.add_argument(
'--nsamples-rot-calibration',
type=int,
default=800,
help='Number of calibration data samples for rotation. Default: %(default)d.')
parser.add_argument('--seqlen', type=int, default=2048, help='Sequence length. Default: 2048.')
parser.add_argument('--eval', action='store_true', help='Eval model PPL on the chosen Dataset.')
parser.add_argument(
Expand Down

0 comments on commit 5f736a8

Please sign in to comment.