add tri_stage and inverse_sqrt schedule (#10)
* add decoder_layers_num

* fix convert scripts

* fix name

* add tri_stage and inverse_sqrt schedule

* back

* add tri_stage and inverse_sqrt schedule

* update license

Co-authored-by: janinezhao <[email protected]>
JINGZIjingzi and janinezhao authored Nov 25, 2022
1 parent 5874168 commit c80351e
Showing 3 changed files with 112 additions and 2 deletions.
5 changes: 4 additions & 1 deletion LICENSE
@@ -585,6 +585,9 @@ Copyright (c) Microsoft Corporation.
4. CLUECorpus2020
Copyright (c) 2022 CLUE benchmark

5. fairseq
Copyright (c) Facebook, Inc. and its affiliates.


Terms of the MIT License:
--------------------------------------------------------------------
@@ -1333,4 +1336,4 @@ specific requirements.
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU AGPL, see
<https://www.gnu.org/licenses/>.
<https://www.gnu.org/licenses/>.
3 changes: 2 additions & 1 deletion tencentpretrain/utils/__init__.py
@@ -28,7 +28,8 @@
str2scheduler = {"linear": get_linear_schedule_with_warmup, "cosine": get_cosine_schedule_with_warmup,
"cosine_with_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
"polynomial": get_polynomial_decay_schedule_with_warmup,
"constant": get_constant_schedule, "constant_with_warmup": get_constant_schedule_with_warmup}
"constant": get_constant_schedule, "constant_with_warmup": get_constant_schedule_with_warmup,
"inverse_sqrt": get_inverse_square_root_schedule_with_warmup, "tri_stage": get_tri_stage_schedule}

str2adv = {"fgm": FGM, "pgd": PGD}

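The two new registry entries can be selected by name just like the existing schedulers. A minimal usage sketch (not part of the commit; the step counts are made up, and it assumes TencentPretrain and PyTorch are importable):

import torch
from tencentpretrain.utils import str2scheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# "inverse_sqrt" follows the usual (optimizer, warmup steps, total steps) convention.
scheduler = str2scheduler["inverse_sqrt"](optimizer, num_warmup_steps=1000, num_training_steps=100000)

# "tri_stage" additionally takes the length of the final decay phase.
scheduler = str2scheduler["tri_stage"](optimizer, num_warmup_steps=1000, num_decay_steps=20000,
                                       num_training_steps=100000)
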
106 changes: 106 additions & 0 deletions tencentpretrain/utils/optimizers.py
@@ -86,6 +86,67 @@ def lr_lambda(current_step: int):
return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_tri_stage_schedule(optimizer, num_warmup_steps, num_decay_steps, num_training_steps, init_lr_scale=0.01, final_lr_scale=0.05, last_epoch=-1):
"""
Create a schedule with a learning rate that have three stages: a warmup stage, a hold stage and a decay stage.
Implement the learning rate scheduler in https://arxiv.org/pdf/1904.08779.pdf
- warmup stage, starting from `lr` * `init_lr_scale`, linearly
increased to `lr` in `warmup_steps` iterations
- hold stage, after `warmup_steps`, keep the LR as `lr` for `hold_steps`
iterations
- decay stage, after hold stage, decay LR exponetially to
`lr` * `final_lr_scale` in `decay_steps`;
after that LR is keep as `final_lr_scale` * `lr`
During warmup::
init_lr = arg.init_lr_scale * arg.lr
lrs = torch.linspace(init_lr, arg.lr, arg.warmup_steps)
lr = lrs[update_num]
During hold::
lr = arg.lr
During decay::
decay_factor = - math.log(arg.final_lr_scale) / arg.decay_steps
lr = arg.lr * exp(- (update_num - warmup_steps - decay_steps) * decay_factor)
After that::
lr = arg.lr * arg.final_lr_scale
Args:
optimizer (:class:`~torch.optim.Optimizer`):
The optimizer for which to schedule the learning rate.
num_warmup_steps (:obj:`int`):
The number of steps for the warmup phase.
num_decay_steps (:obj:`int`):
The number of steps for the decay phase.
num_training_steps (:obj:`int`):
The total number of training steps.
decay_scale (:obj:`float`):
last_epoch (:obj:`int`, `optional`, defaults to -1):
The index of the last epoch when resuming training.
Return:
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
    lr_hold = optimizer.defaults["lr"]
    lr_init = lr_hold * init_lr_scale
    lr_end = lr_hold * final_lr_scale

    def lr_lambda(current_step: int):
        # Slope of the linear warmup and rate of the exponential decay.
        warmup_rate = (lr_hold - lr_init) / num_warmup_steps
        decay_factor = -math.log(final_lr_scale) / max(num_decay_steps, 1)

        if current_step < num_warmup_steps:
            # Warmup stage: linear ramp from lr_init to the base lr.
            return (lr_init + current_step * warmup_rate) / lr_hold
        elif current_step < num_training_steps - num_decay_steps:
            # Hold stage: keep the base learning rate.
            return 1.0
        elif current_step <= num_training_steps:
            # Decay stage: exponential decay towards lr_end.
            return math.exp(-decay_factor * (current_step - num_training_steps + num_decay_steps))
        else:
            # After training: keep the final learning rate.
            return lr_end / lr_hold

    return LambdaLR(optimizer, lr_lambda, last_epoch)
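
A quick sanity check of the three stages (a sketch, not part of the commit; the step counts are hypothetical): the multiplier returned by lr_lambda should ramp up during warmup, stay at 1.0 during the hold stage, and reach final_lr_scale at the last training step.

import torch
from tencentpretrain.utils.optimizers import get_tri_stage_schedule

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = get_tri_stage_schedule(optimizer, num_warmup_steps=10,
                                   num_decay_steps=40, num_training_steps=100)
for step in (0, 5, 10, 50, 80, 100):
    # LambdaLR keeps the multiplier functions in lr_lambdas.
    print(step, scheduler.lr_lambdas[0](step))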


def get_cosine_schedule_with_warmup(
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
@@ -196,6 +257,51 @@ def lr_lambda(current_step: int):
return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_inverse_square_root_schedule_with_warmup(
optimizer, num_warmup_steps, num_training_steps, warmup_init_lr=0.0, last_epoch=-1
):
"""
Create a schedule with a learning rate that Decay the LR based on the inverse square root of the update number.
After a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
Args:
optimizer (:class:`~torch.optim.Optimizer`):
The optimizer for which to schedule the learning rate.
num_warmup_steps (:obj:`int`):
The number of steps for the warmup phase.
num_training_steps (:obj:`int`):
The total number of training steps.
warmup_init_lr (:obj:`float`, `optional`, defaults to 0):
The initial LR for warmup.
last_epoch (:obj:`int`, `optional`, defaults to -1):
The index of the last epoch when resuming training.
During warmup::
lrs = torch.linspace(arg.warmup_init_lr, arg.lr, arg.warmup_updates)
lr = lrs[update_num]
After warmup::
decay_factor = arg.lr * sqrt(arg.warmup_updates)
lr = decay_factor / sqrt(update_num)
Return:
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""

    lr = optimizer.defaults["lr"]
    assert lr > warmup_init_lr, f"lr ({lr}) must be bigger than the warmup initial lr ({warmup_init_lr})"

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            # Warmup: linear ramp from warmup_init_lr to the base lr.
            lr_step = (lr - warmup_init_lr) / num_warmup_steps
            return (warmup_init_lr + current_step * lr_step) / lr
        elif current_step > num_training_steps:
            # After training: return a near-zero multiplier, as LambdaLR multiplies by the initial lr.
            return 1e-7 / lr
        else:
            # Inverse square root decay of the update number.
            decay_factor = lr * num_warmup_steps**0.5
            return (decay_factor * current_step**-0.5) / lr

    return LambdaLR(optimizer, lr_lambda, last_epoch)
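
The same kind of check for the inverse square root schedule (again a sketch with illustrative numbers, not part of the commit): after warmup the effective learning rate should follow lr * sqrt(num_warmup_steps) / sqrt(step), as described in the docstring above.

import torch
from tencentpretrain.utils.optimizers import get_inverse_square_root_schedule_with_warmup

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=5e-4)
scheduler = get_inverse_square_root_schedule_with_warmup(optimizer, num_warmup_steps=4000,
                                                         num_training_steps=300000)
for step in (1000, 4000, 16000, 64000):
    multiplier = scheduler.lr_lambdas[0](step)
    print(step, multiplier * optimizer.defaults["lr"])  # effective learning rate at this step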


class AdamW(Optimizer):
"""
Implements Adam algorithm with weight decay fix as introduced in `Decoupled Weight Decay Regularization