From 349057a4607ff9159b1e6a59c72c836e124d6582 Mon Sep 17 00:00:00 2001
From: Ian Colbert <88047104+i-colbert@users.noreply.github.com>
Date: Sat, 4 Jan 2025 11:20:19 -0800
Subject: [PATCH] Fix (setup): fix LLM entry point (#1145)

---
 src/brevitas_examples/llm/main.py   | 10 +++++++---
 tests/brevitas_examples/test_llm.py |  6 +++---
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 1d2c86d88..ed2ebc2c8 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -195,7 +195,7 @@ def validate(args):
             "or decreasing the sequence length (seqlen)")
 
 
-def main(args):
+def quantize_llm(args):
     validate(args)
     set_seed(args.seed)
     if args.export_prefix is None:
@@ -850,7 +850,11 @@ def parse_args(args, override_defaults={}):
     return parser.parse_args(args)
 
 
-if __name__ == '__main__':
+def main():
     overrides = override_defaults(sys.argv[1:])
     args = parse_args(sys.argv[1:], override_defaults=overrides)
-    main(args)
+    quantize_llm(args)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py
index 1a425296f..c02a3e320 100644
--- a/tests/brevitas_examples/test_llm.py
+++ b/tests/brevitas_examples/test_llm.py
@@ -18,8 +18,8 @@
 
 from brevitas import config
 from brevitas import torch_version
-from brevitas_examples.llm.main import main
 from brevitas_examples.llm.main import parse_args
+from brevitas_examples.llm.main import quantize_llm
 from tests.marker import jit_disabled_for_export
 from tests.marker import requires_pt_ge
@@ -49,12 +49,12 @@ def validate_args(args):
     a = vars(args)
     da = vars(parse_args([]))
     for k in a.keys():
-        assert k in da.keys(), f"Key {k} does not seem to be a valid argument for `main`"
+        assert k in da.keys(), f"Key {k} does not seem to be a valid argument for `quantize_llm`"
 
 
 def validate_args_and_run_main(args):
     validate_args(args)
-    float_ppl, quant_ppl, model = main(args)
+    float_ppl, quant_ppl, model = quantize_llm(args)
     return float_ppl, quant_ppl, model
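
Note (not part of the patch): the diff splits the old main(args) into an argument-taking quantize_llm(args) and a zero-argument main(), presumably so that a packaged entry point, which calls its target with no arguments, can resolve to main() while the quantization routine stays importable and testable. The sketch below is a minimal illustration of driving the refactored API programmatically; parse_args([]) and the three return values of quantize_llm are taken from the test helper in this patch, everything else (the print call, running with parser defaults) is illustrative only and assumes brevitas_examples and its dependencies are installed.

    # Minimal sketch, assuming the post-patch brevitas_examples.llm.main module.
    from brevitas_examples.llm.main import parse_args
    from brevitas_examples.llm.main import quantize_llm

    # parse_args([]) builds an argparse.Namespace with the parser defaults,
    # exactly as the test helper above does; CLI-style tokens could be passed
    # instead of the empty list to override individual flags.
    args = parse_args([])

    # quantize_llm() validates the arguments, runs the quantization flow, and
    # (per the patched test helper) returns the float perplexity, the quantized
    # perplexity, and the quantized model.
    float_ppl, quant_ppl, model = quantize_llm(args)
    print(float_ppl, quant_ppl)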