Skip to content

Commit

Permalink
Fix (setup): fix LLM entry point (#1145)
Browse files Browse the repository at this point in the history
  • Loading branch information
i-colbert authored Jan 4, 2025
1 parent 3aadf17 commit 349057a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
10 changes: 7 additions & 3 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def validate(args):
"or decreasing the sequence length (seqlen)")


def main(args):
def quantize_llm(args):
validate(args)
set_seed(args.seed)
if args.export_prefix is None:
Expand Down Expand Up @@ -850,7 +850,11 @@ def parse_args(args, override_defaults={}):
return parser.parse_args(args)


if __name__ == '__main__':
def main():
overrides = override_defaults(sys.argv[1:])
args = parse_args(sys.argv[1:], override_defaults=overrides)
main(args)
quantize_llm(args)


if __name__ == '__main__':
main()
6 changes: 3 additions & 3 deletions tests/brevitas_examples/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

from brevitas import config
from brevitas import torch_version
from brevitas_examples.llm.main import main
from brevitas_examples.llm.main import parse_args
from brevitas_examples.llm.main import quantize_llm
from tests.marker import jit_disabled_for_export
from tests.marker import requires_pt_ge

Expand Down Expand Up @@ -49,12 +49,12 @@ def validate_args(args):
a = vars(args)
da = vars(parse_args([]))
for k in a.keys():
assert k in da.keys(), f"Key {k} does not seem to be a valid argument for `main`"
assert k in da.keys(), f"Key {k} does not seem to be a valid argument for `quantize_llm`"


def validate_args_and_run_main(args):
validate_args(args)
float_ppl, quant_ppl, model = main(args)
float_ppl, quant_ppl, model = quantize_llm(args)
return float_ppl, quant_ppl, model


Expand Down

0 comments on commit 349057a

Please sign in to comment.