diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md
index 8471c84ee..8a160834f 100644
--- a/src/brevitas_examples/llm/README.md
+++ b/src/brevitas_examples/llm/README.md
@@ -50,7 +50,9 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
                [--replace-mha] [--weight-equalization]
                [--rotation {fx,layerwise,fused_no_fx}] [--rotation-mode {had,ort}]
                [--rotation-orphan-sink]
-               [--act-equalization {None,layerwise,fx}] [--load-awq LOAD_AWQ]
+               [--act-equalization {None,layerwise,fx}]
+               [--act-equalization-alpha ACT_EQUALIZATION_ALPHA]
+               [--load-awq LOAD_AWQ]
                [--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}]
                [--export-prefix EXPORT_PREFIX]
                [--checkpoint-name CHECKPOINT_NAME] [--fuse-sequences]
@@ -182,6 +184,9 @@ options:
                         introduces standalone mul nodes,while fx merges them
                         whenever possible into previous tensors, which is
                         possible on ReLU based models (e.g. OPT).
+  --act-equalization-alpha ACT_EQUALIZATION_ALPHA
+                        If activation equalization is enabled, decide what
+                        alpha to use
   --load-awq LOAD_AWQ   Load the awq search results.
   --export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}
                         Model export.
@@ -203,4 +208,5 @@ options:
   --learned-round-fast-update
                         Whether to use fast update with learned round.
                         Prototype (default: False)
+
 ```
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 3a678bdf8..65c5334d7 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -302,8 +302,9 @@ def main(args):
 
     if args.act_equalization is not None:
         offload_model(model)
-        print("Apply act equalization (SmoothQuant)...")
-        apply_act_equalization(model, args.act_equalization, calibration_loader)
+        print(f"Apply act equalization (SmoothQuant) with alpha {args.act_equalization_alpha}")
+        apply_act_equalization(
+            model, args.act_equalization, calibration_loader, alpha=args.act_equalization_alpha)
         print("Act equalization applied.")
         remove_hooks(model)
 
@@ -685,6 +686,11 @@ def parse_args(args):
         help='Apply activation equalization (SmoothQuant). Layerwise introduces standalone mul nodes,'
         'while fx merges them whenever possible into previous tensors, which is possible on ReLU based models (e.g. OPT).'
     )
+    parser.add_argument(
+        '--act-equalization-alpha',
+        default=0.5,
+        type=float,
+        help='If activation equalization is enabled, decide what alpha to use')
     parser.add_argument('--load-awq', type=str, default=None, help="Load the awq search results.")
     parser.add_argument(
         '--export-target',
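For context on what the new `--act-equalization-alpha` flag controls: in SmoothQuant-style activation equalization, alpha sets how much of the quantization difficulty is migrated from activations to weights (alpha=0.5 splits it evenly, matching the new default). The snippet below is a minimal, illustrative sketch of that trade-off only; it is not the Brevitas `apply_act_equalization` implementation, and the tensor names, shapes and statistics are assumptions made for the example.

```python
# Illustrative sketch of the SmoothQuant-style alpha trade-off.
# NOT the Brevitas implementation: names, shapes and per-channel statistics
# below are assumptions made for this example only.
import torch


def smoothquant_scale(act_absmax: torch.Tensor,
                      weight_absmax: torch.Tensor,
                      alpha: float = 0.5) -> torch.Tensor:
    """Per-input-channel scale s_j = max|X_j|^alpha / max|W_j|^(1 - alpha).

    Larger alpha pushes more of the smoothing onto the weights (activations
    are divided by a larger scale); alpha = 0.5 splits the difficulty evenly.
    """
    return (act_absmax.clamp(min=1e-5).pow(alpha) /
            weight_absmax.clamp(min=1e-5).pow(1.0 - alpha))


# Toy usage: equalize a linear layer y = x @ W.T by rewriting
# x' = x / s and W' = W * s, which leaves the product unchanged.
x = torch.randn(32, 64) * torch.linspace(0.1, 20.0, 64)  # a few outlier channels
w = torch.randn(128, 64)
s = smoothquant_scale(x.abs().amax(dim=0), w.abs().amax(dim=0), alpha=0.5)
x_eq, w_eq = x / s, w * s
assert torch.allclose(x @ w.T, x_eq @ w_eq.T, atol=1e-4)
```

With the diff applied, different values can then be tried from the command line via the flags defined above, e.g. `--act-equalization layerwise --act-equalization-alpha 0.6`.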