add use_fp32_logits flag
shidongxing committed Dec 20, 2024
1 parent 141e9eb commit a7a47d2
Showing 2 changed files with 7 additions and 1 deletion.
6 changes: 6 additions & 0 deletions internlm/model/ops/cross_entropy.py
@@ -67,6 +67,12 @@ def new_cross_entropy(
     except KeyError:
         raise KeyError(f"op_type only support: {cross_entropy_op_name_map.keys()}")
 
+    if not gpc.config.get("use_fp32_logits", True):
+        assert op_type in [
+            CrossEntropyOpType.flash_vocab_parallel,
+            CrossEntropyOpType.apex_naive,
+        ], "use_fp32_logits=False only support 'flash_vocab_parallel' or 'apex_naive' loss function"
+
     if internlm_accelerator.get_accelerator_backend() is not AcceleratorType.GPU:
         assert op_type in [
             CrossEntropyOpType.torch_naive,
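
As a minimal standalone sketch (not the repository's code; it uses plain strings instead of the CrossEntropyOpType enum and an illustrative helper name), the guard added above behaves roughly like this:

    # Simplified mirror of the assert added in new_cross_entropy; names are illustrative.
    SUPPORTED_LOW_PRECISION_OPS = {"flash_vocab_parallel", "apex_naive"}

    def check_loss_op(op_type: str, use_fp32_logits: bool = True) -> None:
        # When logits stay in the training dtype (bf16/fp16), only loss ops known
        # to accept low-precision inputs are allowed.
        if not use_fp32_logits and op_type not in SUPPORTED_LOW_PRECISION_OPS:
            raise AssertionError(
                "use_fp32_logits=False only supports 'flash_vocab_parallel' or 'apex_naive'"
            )

    check_loss_op("flash_vocab_parallel", use_fp32_logits=False)  # passes
    # check_loss_op("torch_naive", use_fp32_logits=False)         # would raise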
2 changes: 1 addition & 1 deletion internlm/train/pipeline.py
@@ -300,7 +300,7 @@ def inject_model(model):
     else:
         model = NaiveAMPModel(
             model=model,
-            output_to_fp32=gpc.is_no_pp_or_last_stage(),
+            output_to_fp32=gpc.is_no_pp_or_last_stage() and gpc.config.get("use_fp32_logits", True),
            dtype=gpc.config.model.get("dtype", torch.half),
            sync_buffer=False,
        )
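
For context (an assumption based on the flag name, not stated in the commit): output_to_fp32 upcasts the last stage's output, so setting use_fp32_logits=False keeps the logits in the training dtype. A small self-contained sketch of the memory difference for a large-vocabulary logits tensor, using illustrative shapes:

    import torch

    # Illustrative shapes only; the vocabulary size is an example, not InternLM's.
    batch, seq_len, vocab = 1, 1024, 92544

    logits_bf16 = torch.empty(batch, seq_len, vocab, dtype=torch.bfloat16)
    logits_fp32 = logits_bf16.float()  # roughly what an fp32 output cast would produce

    mib = lambda t: t.element_size() * t.nelement() / 2**20
    print(f"bf16 logits: {mib(logits_bf16):.0f} MiB")  # ~181 MiB
    print(f"fp32 logits: {mib(logits_fp32):.0f} MiB")  # ~362 MiB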
