Add rotation optimization tests
pablomlago committed Jan 16, 2025
1 parent 5f736a8 commit 0f74e0c
tests/brevitas_examples/test_llm.py: 131 additions, 4 deletions
@@ -47,14 +47,14 @@ def transformers_version_ge(required_version: str):
 # Check that all args in args are used
 def validate_args(args):
     a = vars(args)
-    da = vars(parse_args([]))
+    da = vars(parse_args([])[0])
     for k in a.keys():
         assert k in da.keys(), f"Key {k} does not seem to be a valid argument for `quantize_llm`"


-def validate_args_and_run_main(args):
+def validate_args_and_run_main(args, unknown_args=None):
     validate_args(args)
-    float_ppl, quant_ppl, model = quantize_llm(args)
+    float_ppl, quant_ppl, model = quantize_llm(args, unknown_args=unknown_args)
     return float_ppl, quant_ppl, model


@@ -131,7 +131,7 @@ def small_models_with_ppl(request):

 @pytest_cases.fixture()
 def default_run_args(request):
-    args = UpdatableNamespace(**vars(parse_args([])))
+    args = UpdatableNamespace(**vars(parse_args([])[0]))
     args.nsamples = 2
     args.seqlen = 2
     args.model = "hf-internal-testing/tiny-random-MistralForCausalLM"
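
`UpdatableNamespace` is not defined in this diff; it is presumably a small helper in the same module that adds dict-style updating to an argparse Namespace, roughly:

import argparse

class UpdatableNamespace(argparse.Namespace):
    # Hypothetical sketch: merge keyword overrides into the namespace,
    # which is what args.update(**run_dict) relies on further down.
    def update(self, **kwargs):
        self.__dict__.update(kwargs)
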
@@ -156,6 +156,11 @@ def run_test_models_run_args(args, model_with_ppl):
     float_ppl, quant_ppl, model = validate_args_and_run_main(args)
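
The fixture added below is autouse and session-scoped: it hides every GPU except device index 1 from the whole test session by setting CUDA_VISIBLE_DEVICES before the first test runs.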


+@pytest.fixture(scope="session", autouse=True)
+def set_env():
+    os.environ["CUDA_VISIBLE_DEVICES"] = "1"


 # yapf: disable
 @pytest_cases.fixture(
     ids=[
@@ -825,3 +830,125 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
     quant_ppl = quant_ppl.detach().cpu().numpy()
     assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
     assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"


+@pytest_cases.fixture(
+    ids=[
+        "llama_rotation_optimization_ort",
+        "llama_rotation_optimization_ort_no_orphan",
+        "llama_rotation_optimization_had",
+        "llama_rotation_optimization_had_no_orphan",],
+    params=[
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "act_calibration": False,
+            "weight_bit_width": 4,
+            "input_bit_width": None,
+            "replace_rmsnorm": True,
+            "rotation": "fused_no_fx_optimize",
+            "rotation_orphan_sink": True,
+            "rotation_mode": "ort",
+            "nsamples_rot_calibration": 2,
+            "no_float16": True,
+            "unknown_args": [
+                "--max_steps",
+                "2",
+                "--per_device_train_batch_size",
+                "1",
+                "--gradient_accumulation_steps",
+                "1"],
+            "float_ppl": 33238.8984375,
+            "quant_ppl": 33232.65234375},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "act_calibration": False,
+            "weight_bit_width": 4,
+            "input_bit_width": None,
+            "replace_rmsnorm": True,
+            "rotation": "fused_no_fx_optimize",
+            "rotation_orphan_sink": False,
+            "rotation_mode": "ort",
+            "nsamples_rot_calibration": 2,
+            "no_float16": True,
+            "unknown_args": [
+                "--max_steps",
+                "2",
+                "--per_device_train_batch_size",
+                "1",
+                "--gradient_accumulation_steps",
+                "1"],
+            "float_ppl": 33238.8984375,
+            "quant_ppl": 33420.65234375},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "act_calibration": False,
+            "weight_bit_width": 4,
+            "input_bit_width": None,
+            "replace_rmsnorm": True,
+            "rotation": "fused_no_fx_optimize",
+            "rotation_orphan_sink": True,
+            "rotation_mode": "had",
+            "nsamples_rot_calibration": 2,
+            "no_float16": True,
+            "unknown_args": [
+                "--max_steps",
+                "2",
+                "--per_device_train_batch_size",
+                "1",
+                "--gradient_accumulation_steps",
+                "1"],
+            "float_ppl": 33238.8984375,
+            "quant_ppl": 33290.48046875},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "act_calibration": False,
+            "weight_bit_width": 4,
+            "input_bit_width": None,
+            "replace_rmsnorm": True,
+            "rotation": "fused_no_fx_optimize",
+            "rotation_orphan_sink": False,
+            "rotation_mode": "had",
+            "nsamples_rot_calibration": 2,
+            "no_float16": True,
+            "unknown_args": [
+                "--max_steps",
+                "2",
+                "--per_device_train_batch_size",
+                "1",
+                "--gradient_accumulation_steps",
+                "1"],
+            "float_ppl": 33238.8984375,
+            "quant_ppl": 33204.80859375},])
+def rotation_optimization_args_and_ppl(default_run_args, request):
+    args = default_run_args
+    run_dict = request.param
+    unknown_args = run_dict["unknown_args"]
+    float_ppl = run_dict["float_ppl"]
+    quant_ppl = run_dict["quant_ppl"]
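+    # Drop the expectations and pass-through CLI flags before merging the rest
+    # into args: validate_args would reject keys that quantize_llm does not accept.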
+    del run_dict["float_ppl"]
+    del run_dict["quant_ppl"]
+    del run_dict["unknown_args"]
+    args.update(**run_dict)
+    yield args, unknown_args, float_ppl, quant_ppl


+@requires_pt_ge('2.4')
+def test_small_models_rotation_optimization_ppl(caplog, rotation_optimization_args_and_ppl):
+    if platform.system() == "Windows":
+        pytest.skip("Skipping dynamo + windows")
+    caplog.set_level(logging.INFO)
+    args, unknown_args, exp_float_ppl, exp_quant_ppl = rotation_optimization_args_and_ppl
+    float_ppl, quant_ppl, model = validate_args_and_run_main(args, unknown_args)
+    float_ppl = float_ppl.detach().cpu().numpy()
+    quant_ppl = quant_ppl.detach().cpu().numpy()
+    assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
+    assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
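
To run just these cases locally, a pytest keyword filter along the following lines should work (the substring matches the fixture ids above):

pytest tests/brevitas_examples/test_llm.py -k "rotation_optimization" -v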
