Rename HF args
pablomlago committed Jan 17, 2025
1 parent 80d247a commit b7209d2
Showing 3 changed files with 34 additions and 35 deletions.
11 changes: 6 additions & 5 deletions src/brevitas_examples/llm/llm_quant/rotation_optimization.py
@@ -20,11 +20,14 @@
class TrainingArguments(transformers.TrainingArguments):
# By default, arguments are saved in the current working directory
output_dir: Optional[str] = field(default=os.getcwd())
# NOTE: Currently, there is no infrastructure to resume training
# from a checkpoint, so related files are not saved by default
save_strategy: Optional[str] = field(default="no")


def parse_optimization_rotation_args(unknown_args=None) -> None:
def parse_rotation_optimization_args(extra_args: Optional[List[str]] = None) -> TrainingArguments:
parser = transformers.HfArgumentParser(TrainingArguments)
training_args = parser.parse_args_into_dataclasses(args=unknown_args)
training_args = parser.parse_args_into_dataclasses(args=extra_args)
# When running in a single process, only one GPU should be made visible
# to the Trainer, to prevent it from using DataParallel, which caused
# errors due to tensors being placed on different devices.
@@ -83,14 +86,12 @@ def apply_rotation_optimization(
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
train_dataset: DatasetToDevice,
unknown_args: List[str] = None,
training_args: TrainingArguments,
) -> None:

# Prepare dataset and model for training
train_dataset = _prepare_train_dataset(train_dataset)
model = _prepare_model(model)
# Get training arguments
training_args = parse_optimization_rotation_args(unknown_args)
# Enable skipping optimization
if training_args.max_steps <= 0:
return
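For context, a minimal sketch (not part of this commit) of how the renamed helper is expected to behave: `transformers.HfArgumentParser` maps leftover CLI tokens directly onto `TrainingArguments` fields. The `--output_dir` flag and the example values below are illustrative assumptions; the real helper lives in `rotation_optimization.py` and uses the `TrainingArguments` subclass shown in the diff above.

```python
# Minimal sketch, assuming a stock transformers install.
from typing import List, Optional

import transformers


def parse_rotation_optimization_args(
        extra_args: Optional[List[str]] = None) -> transformers.TrainingArguments:
    parser = transformers.HfArgumentParser(transformers.TrainingArguments)
    # parse_args_into_dataclasses returns one instance per dataclass handed to the parser.
    (training_args,) = parser.parse_args_into_dataclasses(args=extra_args)
    return training_args


# Tokens left over from the main argument parser are forwarded verbatim, e.g.:
training_args = parse_rotation_optimization_args(
    ["--output_dir", "out", "--learning_rate", "1.5", "--max_steps", "2"])
print(training_args.learning_rate, training_args.max_steps)  # 1.5 2
```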
16 changes: 11 additions & 5 deletions src/brevitas_examples/llm/main.py
@@ -6,6 +6,7 @@
from copy import deepcopy
import functools
import sys
from typing import List, Optional
from warnings import warn

from lm_eval import evaluator
@@ -56,6 +57,7 @@
from brevitas_examples.llm.llm_quant.prepare_for_quantize import \
replace_sdpa_with_quantizable_layers
from brevitas_examples.llm.llm_quant.rotation_optimization import apply_rotation_optimization
from brevitas_examples.llm.llm_quant.rotation_optimization import parse_rotation_optimization_args
from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32
from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter
from brevitas_examples.llm.llm_quant.run_utils import get_fx
@@ -158,7 +160,9 @@ def model_export(model, ref_input, args):
export_torch_qcdq(model, ref_input['input_ids'], export_path=f"{args.export_prefix}.pt")


def validate(args):
def validate(args, extra_args: Optional[List[str]] = None):
if args.rotation != "fused_no_fx_optimize":
assert extra_args is None or len(extra_args) == 0, f"The following unknown arguments were passed: {[extra_arg for extra_arg in extra_args if extra_arg.startswith('--')]}"
if args.functional_sdpa_quant:
assert args.input_scale_type == 'dynamic' or args.input_bit_width is None, "Functional SDPA Quant requires dynamic activation quantization"
if args.rotation == 'fx':
@@ -225,7 +229,7 @@ def validate(args):
"or decreasing the sequence length (seqlen)")


def quantize_llm(args, unknown_args=None):
def quantize_llm(args, extra_args=None):
validate(args)
set_seed(args.seed)
if args.export_prefix is None:
@@ -288,6 +292,8 @@ def quantize_llm(args, unknown_args=None):
fuse_sequences=args.fuse_sequences)

if args.rotation in ["fused_no_fx_optimize"]:
# Extra arguments should be used as training arguments for rotation optimization
rot_optimization_args = parse_rotation_optimization_args(extra_args=extra_args)
# Load the data for rotation optimization
rot_calibration_loader = get_dataset_for_model(
args.model,
@@ -495,7 +501,7 @@ def quantize_llm(args, unknown_args=None):
model=model,
tokenizer=tokenizer,
train_dataset=rot_calibration_loader,
unknown_args=unknown_args,
training_args=rot_optimization_args,
)
# Remove hooks from optimization
remove_hooks(model)
@@ -964,8 +970,8 @@ def parse_args(args, override_defaults={}):

def main():
overrides = override_defaults(sys.argv[1:])
args, unknown_args = parse_args(sys.argv[1:], override_defaults=overrides)
quantize_llm(args, unknown_args)
args, extra_args = parse_args(sys.argv[1:], override_defaults=overrides)
quantize_llm(args, extra_args)


if __name__ == '__main__':
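The renamed `extra_args` come out of `parse_args` as the tokens the main parser does not recognise. A hedged sketch of that calling convention, assuming `parse_known_args` under the hood; the flag set and the `facebook/opt-125m` default below are illustrative only, as the real `parse_args` in `main.py` defines many more options.

```python
# Hypothetical, trimmed-down version of the known/extra argument split in main().
import argparse
from typing import List, Tuple


def parse_args(argv: List[str]) -> Tuple[argparse.Namespace, List[str]]:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--rotation", type=str, default=None)
    # parse_known_args returns (known namespace, list of unrecognised tokens);
    # the unrecognised tokens become `extra_args` for quantize_llm.
    return parser.parse_known_args(argv)


args, extra_args = parse_args(
    ["--rotation", "fused_no_fx_optimize", "--learning_rate", "1.5", "--max_steps", "2"])
print(extra_args)  # ['--learning_rate', '1.5', '--max_steps', '2']
```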
42 changes: 17 additions & 25 deletions tests/brevitas_examples/test_llm.py
@@ -48,9 +48,9 @@ def validate_args(args):
assert k in da.keys(), f"Key {k} does not seem to be a valid argument for `quantize_llm`"


def validate_args_and_run_main(args, unknown_args=None):
def validate_args_and_run_main(args, extra_args=None):
validate_args(args)
float_ppl, quant_ppl, model = quantize_llm(args, unknown_args=unknown_args)
float_ppl, quant_ppl, model = quantize_llm(args, extra_args=extra_args)
return float_ppl, quant_ppl, model


@@ -841,17 +841,15 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"rotation_mode": "ort",
"nsamples_rot_calibration": 2,
"no_float16": True,
"unknown_args": [
"extra_args": [
"--learning_rate",
"1.5",
"--max_steps",
"2",
"--per_device_train_batch_size",
"1",
"--gradient_accumulation_steps",
"1",
"--save_strategy",
"no"],
"1"],
"float_ppl": 33238.8984375,
"quant_ppl": 33278.98828125,
"exp_layer_types_count": {
@@ -870,17 +868,15 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"rotation_mode": "ort",
"nsamples_rot_calibration": 2,
"no_float16": True,
"unknown_args": [
"extra_args": [
"--learning_rate",
"1.5",
"--max_steps",
"2",
"--per_device_train_batch_size",
"1",
"--gradient_accumulation_steps",
"1",
"--save_strategy",
"no"],
"1"],
"float_ppl": 33238.8984375,
"quant_ppl": 33424.73046875,
"exp_layer_types_count": {
@@ -899,17 +895,15 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"rotation_mode": "had",
"nsamples_rot_calibration": 2,
"no_float16": True,
"unknown_args": [
"extra_args": [
"--learning_rate",
"1.5",
"--max_steps",
"2",
"--per_device_train_batch_size",
"1",
"--gradient_accumulation_steps",
"1",
"--save_strategy",
"no"],
"1"],
"float_ppl": 33238.8984375,
"quant_ppl": 33339.21875,
"exp_layer_types_count": {
@@ -928,17 +922,15 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"rotation_mode": "had",
"nsamples_rot_calibration": 2,
"no_float16": True,
"unknown_args": [
"extra_args": [
"--learning_rate",
"1.5",
"--max_steps",
"2",
"--per_device_train_batch_size",
"1",
"--gradient_accumulation_steps",
"1",
"--save_strategy",
"no"],
"1"],
"float_ppl": 33238.8984375,
"quant_ppl": 33219.08984375,
"exp_layer_types_count": {
@@ -949,16 +941,16 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
def rotation_optimization_args_layer_count_and_ppl(default_run_args, request):
args = default_run_args
run_dict = request.param
unknown_args = run_dict["unknown_args"]
extra_args = run_dict["extra_args"]
float_ppl = run_dict["float_ppl"]
quant_ppl = run_dict["quant_ppl"]
exp_layer_types_count = run_dict["exp_layer_types_count"]
del run_dict["float_ppl"]
del run_dict["quant_ppl"]
del run_dict["unknown_args"]
del run_dict["extra_args"]
del run_dict["exp_layer_types_count"]
args.update(**run_dict)
yield args, unknown_args, float_ppl, quant_ppl, exp_layer_types_count
yield args, extra_args, float_ppl, quant_ppl, exp_layer_types_count


@requires_pt_ge('2.4')
@@ -970,8 +962,8 @@ def test_small_models_rotation_optimization_ppl(
# with non-optimized quantized perplexities
RTOL_ROT, ATOL_ROT = 1e-05, 2.
caplog.set_level(logging.INFO)
args, unknown_args, exp_float_ppl, exp_quant_ppl, _ = rotation_optimization_args_layer_count_and_ppl
float_ppl, quant_ppl, _ = validate_args_and_run_main(args, unknown_args)
args, extra_args, exp_float_ppl, exp_quant_ppl, _ = rotation_optimization_args_layer_count_and_ppl
float_ppl, quant_ppl, _ = validate_args_and_run_main(args, extra_args)
float_ppl = float_ppl.detach().cpu().numpy()
quant_ppl = quant_ppl.detach().cpu().numpy()
assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
@@ -986,7 +978,7 @@ def test_small_models_rotation_optimization_layer_count(
# Tolerances are stricter for this test, to ensure that it does not pass
# with non-optimized quantized perplexities
caplog.set_level(logging.INFO)
args, unknown_args, _, _, exp_layer_types_count = rotation_optimization_args_layer_count_and_ppl
args, extra_args, _, _, exp_layer_types_count = rotation_optimization_args_layer_count_and_ppl
args, extra_args, _, _, exp_layer_types_count = rotation_optimization_args_layer_count_and_ppl
with patch('brevitas_examples.llm.main.fuse_parametrized_rotations', lambda model: model):
_, _, model = validate_args_and_run_main(args, unknown_args)
_, _, model = validate_args_and_run_main(args, extra_args)
assert_layer_types_count(model, exp_layer_types_count)
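The tests can drop `--save_strategy` / `no` from `extra_args` because the `TrainingArguments` subclass now defaults `save_strategy` to `"no"`. A small sketch of that behaviour; the class below merely mirrors the defaults from the diff, and the exact string form of the parsed value depends on the installed transformers version.

```python
# Sketch only: re-declares the defaults from rotation_optimization.py to show that
# checkpoint saving is disabled without passing --save_strategy explicitly.
import os
from dataclasses import dataclass, field
from typing import Optional

import transformers


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    output_dir: Optional[str] = field(default=os.getcwd())
    save_strategy: Optional[str] = field(default="no")


parser = transformers.HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(args=[
    "--learning_rate", "1.5",
    "--max_steps", "2",
    "--per_device_train_batch_size", "1",
    "--gradient_accumulation_steps", "1"])
# save_strategy is normalised by transformers to an enum whose value is "no"
print(training_args.save_strategy)
```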
