From a7a47d21e8a0c39e55d4ad4b9dc0c3997ced8104 Mon Sep 17 00:00:00 2001 From: shidongxing Date: Fri, 20 Dec 2024 13:29:11 +0800 Subject: [PATCH] add use_fp32_logits flag --- internlm/model/ops/cross_entropy.py | 6 ++++++ internlm/train/pipeline.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/internlm/model/ops/cross_entropy.py b/internlm/model/ops/cross_entropy.py index 99bf1e04..ac9bce8f 100644 --- a/internlm/model/ops/cross_entropy.py +++ b/internlm/model/ops/cross_entropy.py @@ -67,6 +67,12 @@ def new_cross_entropy( except KeyError: raise KeyError(f"op_type only support: {cross_entropy_op_name_map.keys()}") + if not gpc.config.get("use_fp32_logits", True): + assert op_type in [ + CrossEntropyOpType.flash_vocab_parallel, + CrossEntropyOpType.apex_naive, + ], "use_fp32_logits=False only support 'flash_vocab_parallel' or 'apex_naive' loss function" + if internlm_accelerator.get_accelerator_backend() is not AcceleratorType.GPU: assert op_type in [ CrossEntropyOpType.torch_naive, diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 79e9caf4..242d7c6c 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -300,7 +300,7 @@ def inject_model(model): else: model = NaiveAMPModel( model=model, - output_to_fp32=gpc.is_no_pp_or_last_stage(), + output_to_fp32=gpc.is_no_pp_or_last_stage() and gpc.config.get("use_fp32_logits", True), dtype=gpc.config.model.get("dtype", torch.half), sync_buffer=False, )