From c33a860267120a587683216a6db4fa87105fc9c1 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Thu, 21 Nov 2024 13:48:58 +0000 Subject: [PATCH 01/17] .custom load_state_dict that enables CPU offload --- src/nanotron/helpers.py | 139 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 137 insertions(+), 2 deletions(-) diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 73ca3484..6a46551e 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -4,10 +4,23 @@ import math import os import time +from collections import defaultdict +from copy import deepcopy from datetime import datetime from functools import partial +from itertools import chain from math import ceil -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import ( + Any, + DefaultDict, + Dict, + Hashable, + Iterable, + List, + Optional, + Tuple, + Union, +) import numpy as np import torch @@ -15,6 +28,7 @@ from torch.nn.parallel import DistributedDataParallel from torch.optim.lr_scheduler import LambdaLR from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler +from typing_extensions import TypeAlias from nanotron import distributed as dist from nanotron import logging @@ -44,6 +58,11 @@ from nanotron.scaling.parametrization import LearningRateForSP, LearningRateForSpectralMup, ParametrizationMethod from nanotron.serialize.metadata import TrainingMetadata +Args: TypeAlias = Tuple[Any, ...] +Kwargs: TypeAlias = Dict[str, Any] +StateDict: TypeAlias = Dict[str, Any] + + logger = logging.get_logger(__name__) @@ -291,6 +310,117 @@ def merge_named_param_groups( return named_param_groups +# Modified from torch.optim.Optimizer._process_value_according_to_param_policy +@staticmethod +def _process_value_according_to_param_policy( + param: torch.Tensor, + value: torch.Tensor, + param_id: int, + param_groups: List[Dict[Any, Any]], + map_location: Optional[Union[str, torch.device]], + key: Hashable = None, +) -> torch.Tensor: + # If map_location is specified, use it instead of param.device + target_device = map_location if map_location is not None else param.device + + fused = False + capturable = False + assert param_groups is not None + for pg in param_groups: + if param_id in pg["params"]: + fused = pg["fused"] if "fused" in pg else False + capturable = pg["capturable"] if "capturable" in pg else False + break + + if key == "step": + if capturable or fused: + return value.to(dtype=torch.float32, device=target_device) + else: + return value + else: + if param.is_floating_point(): + return value.to(dtype=param.dtype, device=target_device) + else: + return value.to(device=target_device) + + +# Modified from torch.optim.Optimizer.load_state_dict +@torch._disable_dynamo +def custom_load_state_dict( + self, state_dict: StateDict, map_location: Optional[Union[str, torch.device]] = "cpu" +) -> None: + r"""Loads the optimizer state. + + Args: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. + map_location (str or torch.device, optional): Device where to load the optimizer states. + If None, states will be loaded to the same device as their corresponding parameters. 
+ Default: None + """ + + # shallow copy, to be consistent with module API + state_dict = state_dict.copy() + + for pre_hook in self._optimizer_load_state_dict_pre_hooks.values(): + hook_result = pre_hook(self, state_dict) + if hook_result is not None: + state_dict = hook_result + + # Validate the state_dict + groups = self.param_groups + + # Deepcopy as we write into saved_groups later to update state + saved_groups = deepcopy(state_dict["param_groups"]) + + if len(groups) != len(saved_groups): + raise ValueError("loaded state dict has a different number of " "parameter groups") + param_lens = (len(g["params"]) for g in groups) + saved_lens = (len(g["params"]) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError( + "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group" + ) + + # Update the state + id_map = dict( + zip(chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups)) + ) + + def _cast(param, value, param_id=None, param_groups=None, key=None): + r"""Make a deep copy of value, casting all tensors to device of param.""" + if isinstance(value, torch.Tensor): + return _process_value_according_to_param_policy(param, value, param_id, param_groups, map_location, key) + elif isinstance(value, dict): + return {k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()} + elif isinstance(value, Iterable): + return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) + else: + return value + + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). 
+ state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict) + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"]) + else: + state[k] = v + + # Update parameter groups, setting their 'params' value + def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]: + new_group["params"] = group["params"] + return new_group + + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({"state": state, "param_groups": param_groups}) + + for post_hook in self._optimizer_load_state_dict_post_hooks.values(): + post_hook(self) + + def init_optimizer_and_grad_accumulator( parametrization_method: ParametrizationMethod, model: nn.Module, @@ -328,7 +458,7 @@ def basic_optimizer_builder(named_param_groups): if optimizer_args.optimizer_factory.name == "adamW": def optimizer(param_groups): - return torch.optim.AdamW( + base_optimizer = torch.optim.AdamW( param_groups, lr=optimizer_args.learning_rate_scheduler.learning_rate, weight_decay=optimizer_args.weight_decay, @@ -336,6 +466,11 @@ def optimizer(param_groups): betas=(optimizer_args.optimizer_factory.adam_beta1, optimizer_args.optimizer_factory.adam_beta2), fused=optimizer_args.optimizer_factory.torch_adam_is_fused, ) + # Replace the load_state_dict method with our custom implementation that enables CPU offload + base_optimizer.load_state_dict = lambda state_dict, map_location=None: custom_load_state_dict( + base_optimizer, state_dict, map_location=map_location + ) + return base_optimizer elif optimizer_args.optimizer_factory.name == "sgd": From 312b7597b01a27a84493eaca6fd33dcfdf92675a Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Thu, 21 Nov 2024 13:55:34 +0000 Subject: [PATCH 02/17] . --- src/nanotron/sanity_checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/sanity_checks.py b/src/nanotron/sanity_checks.py index 650607da..43da1fc1 100644 --- a/src/nanotron/sanity_checks.py +++ b/src/nanotron/sanity_checks.py @@ -263,7 +263,7 @@ def check_optim_state_in_sync(optimizer: optim.BaseOptimizer, pg: dist.ProcessGr for _, optim_state in sorted(optimizer.state_dict()["state"].items(), key=lambda x: x[0]): for name, tensor in optim_state.items(): if name == "step": - tensor = tensor.to("cuda") + continue assert_tensor_synced_across_pg( tensor=tensor, pg=pg, msg=lambda err: f"{name} are not synced across DP {err}" From 9e1d76f3b4124e10d9ae361dd565219970f1ce79 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Thu, 21 Nov 2024 15:26:12 +0000 Subject: [PATCH 03/17] load optim states in CPU and move them to GPU after 1st fwd-bwd to avoid peak memory --- src/nanotron/optim/base.py | 4 +- .../optim/inherit_from_other_optimizer.py | 6 +-- src/nanotron/optim/named_optimizer.py | 6 +-- .../optimizer_from_gradient_accumulator.py | 4 +- src/nanotron/sanity_checks.py | 9 ++-- src/nanotron/serialize/optimizer.py | 43 +++++++++++++------ src/nanotron/trainer.py | 27 ++++++++---- 7 files changed, 64 insertions(+), 35 deletions(-) diff --git a/src/nanotron/optim/base.py b/src/nanotron/optim/base.py index fb77f124..34c33f42 100644 --- a/src/nanotron/optim/base.py +++ b/src/nanotron/optim/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Set, TypeVar +from typing import Any, Callable, Dict, List, Optional, Set, TypeVar, Union import torch @@ -34,7 +34,7 @@ def state_dict(self) -> dict: ... 
@abstractmethod - def load_state_dict(self, state_dict: dict) -> None: + def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None: ... @abstractmethod diff --git a/src/nanotron/optim/inherit_from_other_optimizer.py b/src/nanotron/optim/inherit_from_other_optimizer.py index 2ddd36d0..53b57284 100644 --- a/src/nanotron/optim/inherit_from_other_optimizer.py +++ b/src/nanotron/optim/inherit_from_other_optimizer.py @@ -1,5 +1,5 @@ from functools import cache -from typing import Callable, Dict, Optional, Set +from typing import Callable, Dict, Optional, Set, Union import torch @@ -33,8 +33,8 @@ def state_dict_additional_keys(self) -> Set[str]: def state_dict(self) -> dict: return self.optimizer.state_dict() - def load_state_dict(self, state_dict: dict) -> None: - return self.optimizer.load_state_dict(state_dict) + def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None: + return self.optimizer.load_state_dict(state_dict, map_location=map_location) def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: return self.optimizer.step(closure=closure) diff --git a/src/nanotron/optim/named_optimizer.py b/src/nanotron/optim/named_optimizer.py index 5f11710a..07363caa 100644 --- a/src/nanotron/optim/named_optimizer.py +++ b/src/nanotron/optim/named_optimizer.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, Tuple, Union +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union import torch @@ -58,7 +58,7 @@ def state_dict(self) -> dict: } return optim_state_dict - def load_state_dict(self, state_dict: dict) -> None: + def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None: assert set(self.id_to_name.values()) == set( state_dict["names"].values() ), f"Elements don't match:\n - Elements in `self.id_to_name` that aren't in the other one: {set(self.id_to_name.values()) - set(state_dict['names'].values())}\n - Elements in `state_dict[\"names\"]` that aren't in the other one: {set(state_dict['names'].values()) - set(self.id_to_name.values())}" @@ -71,4 +71,4 @@ def load_state_dict(self, state_dict: dict) -> None: key in state ), f"Key {key} not found in state dict: {state} which corresponds to param_name: {state_dict['names'][k]}" - return super().load_state_dict(state_dict) + return super().load_state_dict(state_dict, map_location=map_location) diff --git a/src/nanotron/optim/optimizer_from_gradient_accumulator.py b/src/nanotron/optim/optimizer_from_gradient_accumulator.py index 01be7cb5..06426b0b 100644 --- a/src/nanotron/optim/optimizer_from_gradient_accumulator.py +++ b/src/nanotron/optim/optimizer_from_gradient_accumulator.py @@ -67,7 +67,7 @@ def state_dict(self) -> dict: state_dict["gradient_accumulator"] = self.gradient_accumulator.state_dict() return state_dict - def load_state_dict(self, state_dict: dict) -> None: + def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None: gradient_accumulator_state_dict = state_dict.pop("gradient_accumulator") - super().load_state_dict(state_dict) + super().load_state_dict(state_dict, map_location=map_location) self.gradient_accumulator.load_state_dict(gradient_accumulator_state_dict) diff --git a/src/nanotron/sanity_checks.py b/src/nanotron/sanity_checks.py index 43da1fc1..56ef1e2e 100644 --- a/src/nanotron/sanity_checks.py +++ b/src/nanotron/sanity_checks.py @@ -164,6 +164,7 @@ def 
before_optim_step_sanity_checks( parallel_context: ParallelContext, unwrapped_model: NanotronModel, grad_accumulator: GradientAccumulator, + optimizer: optim.BaseOptimizer, ) -> None: if not config.general.ignore_sanity_checks: # SANITY CHECK: Test tied weights gradients are synchronized @@ -232,6 +233,9 @@ def before_optim_step_sanity_checks( msg=lambda err: f"[Before optimizer step] Tied weights {name} are not synchronized. {err}", ) + # SANITY CHECK: Check that optimizer states are synchronized across DP + check_optim_state_in_sync(optimizer.state_dict(), parallel_context.dp_pg) + # SANITY CHECK: run model specific sanity checks unwrapped_model.before_optim_step_sanity_checks() @@ -259,12 +263,11 @@ def after_optim_step_sanity_checks( unwrapped_model.after_optim_step_sanity_checks() -def check_optim_state_in_sync(optimizer: optim.BaseOptimizer, pg: dist.ProcessGroup): - for _, optim_state in sorted(optimizer.state_dict()["state"].items(), key=lambda x: x[0]): +def check_optim_state_in_sync(optim_state_dict: dict, pg: dist.ProcessGroup): + for _, optim_state in sorted(optim_state_dict["state"].items(), key=lambda x: x[0]): for name, tensor in optim_state.items(): if name == "step": continue - assert_tensor_synced_across_pg( tensor=tensor, pg=pg, msg=lambda err: f"{name} are not synced across DP {err}" ) diff --git a/src/nanotron/serialize/optimizer.py b/src/nanotron/serialize/optimizer.py index c9d2da6b..867900cd 100644 --- a/src/nanotron/serialize/optimizer.py +++ b/src/nanotron/serialize/optimizer.py @@ -2,7 +2,7 @@ import warnings from collections import defaultdict from pathlib import Path -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple import torch from torch import nn @@ -19,7 +19,6 @@ ) from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter -from nanotron.sanity_checks import check_optim_state_in_sync from nanotron.serialize.metadata import TensorMetadata from nanotron.serialize.utils import ObjectType, merge_and_shard_tp_tensors @@ -125,18 +124,34 @@ def save_lr_scheduler( ) +# Helper functions to move optimizer states +@torch.no_grad() +def state_dict_to_device(state_dict: Dict, device: str) -> Dict: + assert ( + state_dict["state"][0]["exp_avg"].device.type == "cpu" + ), "Optimizer states should be on CPU to avoid extra memory usage when loading from checkpoint" + torch.cuda.empty_cache() + + for _, optim_state in sorted(state_dict["state"].items(), key=lambda x: x[0]): + for name, tensor in optim_state.items(): + optim_state[name] = tensor.to(device) + + assert ( + state_dict["state"][0]["exp_avg"].device.type == "cuda" + ), "Optimizer states should be on GPU because model is on GPU" + torch.cuda.empty_cache() + + @torch.no_grad() def load_optimizer( optimizer: optim.BaseOptimizer, parallel_context: ParallelContext, root_folder: Path, - map_location: Optional[str] = None, + map_location: Optional[str] = "cpu", param_shard_metadata: Tuple[Tuple[int, int], TensorMetadata] = None, # (pp_rank, tp_rank) -> TensorMetadata model: Optional[nn.Module] = None, ): root_folder = root_folder / "optimizer" - # `load_state_dict` copies the state dict which can be very large in case of Zero-0 so we load to cpu and then move to the right device - map_location = "cpu" if not optimizer.inherit_from(optim.ZeroDistributedOptimizer) else map_location ckp_optimizer_config_path = root_folder / "optimizer_config.json" with open(ckp_optimizer_config_path, "r") as file: ckp_optimizer_config = json.load(file) @@ -149,9 
+164,10 @@ def load_optimizer( if int(ckp_tp_size) != int(parallel_context.tp_pg.size()) or int(ckp_pp_size) != int( parallel_context.pp_pg.size() ): - warnings.warn( - "You are resuming in a different PP size, so optimizer states need to be checked. Feel free to open a PR if you work on this!" - ) + if int(ckp_pp_size) != int(parallel_context.pp_pg.size()): + warnings.warn( + "You are resuming in a different PP size, so optimizer states need to be checked. Feel free to open a PR if you work on this!" + ) assert ( param_shard_metadata is not None ), f"You have to pass how the original parameters are sharded in order to resume in a different tensor parallel size, ckp_tp_size: {ckp_tp_size}, current tp_size: {parallel_context.tp_pg.size()}" @@ -241,8 +257,10 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) - # TODO(xrsrke): free the memory of the shards that isn't # corresponding to the current rank # TODO: maybe better to allocate memory for all states at once - buffer = torch.zeros_like(param, device="cuda", dtype=OPTIMIZER_STATE_DTYPE) - unsharded_buffer = torch.empty(new_unshared_shape, device="cuda", dtype=OPTIMIZER_STATE_DTYPE) + buffer = torch.zeros_like(param, device=map_location, dtype=OPTIMIZER_STATE_DTYPE) + unsharded_buffer = torch.empty( + new_unshared_shape, device=map_location, dtype=OPTIMIZER_STATE_DTYPE + ) for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items(): old_optim_state_index = find_optim_index_from_param_name( @@ -333,10 +351,7 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) - ) state_dict["state"][param_index][state_name] = sliced_tensor - optimizer.load_state_dict(state_dict) - - if not optimizer.inherit_from(optim.ZeroDistributedOptimizer): - check_optim_state_in_sync(optimizer, parallel_context.dp_pg) + optimizer.load_state_dict(state_dict, map_location="cpu") def load_lr_scheduler( diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 356c3910..44e4e32e 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -93,7 +93,7 @@ save_random_states, ) from nanotron.serialize.metadata import DataStageMetadata, TrainingMetadata -from nanotron.serialize.optimizer import load_optimizer +from nanotron.serialize.optimizer import load_optimizer, state_dict_to_device logger = logging.get_logger(__name__) @@ -432,10 +432,17 @@ def train( # Fix the root_model self.unwrapped_model.module_id_to_prefix[id(self.unwrapped_model)] = "" + self.initial_iter_step = self.metadata.last_train_step + 1 + self.last_iter_step = self.config.tokens.train_steps + prof = get_profiler(config=self.config) + # free memory + import gc + + gc.collect() torch.cuda.empty_cache() with prof: - for self.iteration_step in range(self.metadata.last_train_step + 1, self.config.tokens.train_steps + 1): + for self.iteration_step in range(self.initial_iter_step, self.last_iter_step + 1): if isinstance(prof, torch.profiler.profile): prof.step() @@ -474,7 +481,7 @@ def training_step( self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.lr_scheduler ) - if self.iteration_step < 5: + if self.iteration_step < self.initial_iter_step + 5: log_memory(logger=logger) outputs = self.pipeline_engine.train_batch_iter( @@ -485,7 +492,7 @@ def training_step( grad_accumulator=self.grad_accumulator, ) - if self.iteration_step < 5: + if self.iteration_step < self.initial_iter_step + 5: log_memory(logger=logger) after_tbi_sanity_checks(self.config, self.parallel_context, 
self.unwrapped_model, self.grad_accumulator) @@ -531,10 +538,6 @@ def training_step( max_norm=self.config.optimizer.clip_grad, ) - before_optim_step_sanity_checks( - self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator - ) - # Compute DP average loss and overlap with optimizer step if isinstance(outputs[0]["loss"], torch.Tensor): # This is an average on only one data rank. @@ -547,6 +550,14 @@ def training_step( loss_avg = None handle = None + # Move optimizer states back to GPU before optimizer step + if self.init_checkpoint_path is not None and self.iteration_step == self.initial_iter_step: + state_dict_to_device(self.optimizer.state_dict(), "cuda") + + before_optim_step_sanity_checks( + self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer + ) + # Apply gradient self.optimizer.step() self.optimizer.zero_grad() From bc25a353434bbaf5a9ec5ccbe2f6eceb5658ccff Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Thu, 21 Nov 2024 16:07:59 +0000 Subject: [PATCH 04/17] move load custom func to base --- src/nanotron/helpers.py | 125 +--------------------------------- src/nanotron/optim/base.py | 134 ++++++++++++++++++++++++++++++++++++- src/nanotron/trainer.py | 3 +- 3 files changed, 135 insertions(+), 127 deletions(-) diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 6a46551e..6ff564a8 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -4,22 +4,16 @@ import math import os import time -from collections import defaultdict -from copy import deepcopy from datetime import datetime from functools import partial -from itertools import chain from math import ceil from typing import ( Any, - DefaultDict, Dict, - Hashable, Iterable, List, Optional, Tuple, - Union, ) import numpy as np @@ -28,7 +22,6 @@ from torch.nn.parallel import DistributedDataParallel from torch.optim.lr_scheduler import LambdaLR from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler -from typing_extensions import TypeAlias from nanotron import distributed as dist from nanotron import logging @@ -36,7 +29,7 @@ from nanotron.distributed import ProcessGroup from nanotron.logging import LogItem, log_rank from nanotron.models.base import NanotronModel -from nanotron.optim.base import BaseOptimizer, Optimizer +from nanotron.optim.base import BaseOptimizer, Optimizer, custom_load_state_dict from nanotron.optim.gradient_accumulator import ( FP32GradBucketManager, FP32GradientAccumulator, @@ -58,11 +51,6 @@ from nanotron.scaling.parametrization import LearningRateForSP, LearningRateForSpectralMup, ParametrizationMethod from nanotron.serialize.metadata import TrainingMetadata -Args: TypeAlias = Tuple[Any, ...] 
-Kwargs: TypeAlias = Dict[str, Any] -StateDict: TypeAlias = Dict[str, Any] - - logger = logging.get_logger(__name__) @@ -310,117 +298,6 @@ def merge_named_param_groups( return named_param_groups -# Modified from torch.optim.Optimizer._process_value_according_to_param_policy -@staticmethod -def _process_value_according_to_param_policy( - param: torch.Tensor, - value: torch.Tensor, - param_id: int, - param_groups: List[Dict[Any, Any]], - map_location: Optional[Union[str, torch.device]], - key: Hashable = None, -) -> torch.Tensor: - # If map_location is specified, use it instead of param.device - target_device = map_location if map_location is not None else param.device - - fused = False - capturable = False - assert param_groups is not None - for pg in param_groups: - if param_id in pg["params"]: - fused = pg["fused"] if "fused" in pg else False - capturable = pg["capturable"] if "capturable" in pg else False - break - - if key == "step": - if capturable or fused: - return value.to(dtype=torch.float32, device=target_device) - else: - return value - else: - if param.is_floating_point(): - return value.to(dtype=param.dtype, device=target_device) - else: - return value.to(device=target_device) - - -# Modified from torch.optim.Optimizer.load_state_dict -@torch._disable_dynamo -def custom_load_state_dict( - self, state_dict: StateDict, map_location: Optional[Union[str, torch.device]] = "cpu" -) -> None: - r"""Loads the optimizer state. - - Args: - state_dict (dict): optimizer state. Should be an object returned - from a call to :meth:`state_dict`. - map_location (str or torch.device, optional): Device where to load the optimizer states. - If None, states will be loaded to the same device as their corresponding parameters. - Default: None - """ - - # shallow copy, to be consistent with module API - state_dict = state_dict.copy() - - for pre_hook in self._optimizer_load_state_dict_pre_hooks.values(): - hook_result = pre_hook(self, state_dict) - if hook_result is not None: - state_dict = hook_result - - # Validate the state_dict - groups = self.param_groups - - # Deepcopy as we write into saved_groups later to update state - saved_groups = deepcopy(state_dict["param_groups"]) - - if len(groups) != len(saved_groups): - raise ValueError("loaded state dict has a different number of " "parameter groups") - param_lens = (len(g["params"]) for g in groups) - saved_lens = (len(g["params"]) for g in saved_groups) - if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): - raise ValueError( - "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group" - ) - - # Update the state - id_map = dict( - zip(chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups)) - ) - - def _cast(param, value, param_id=None, param_groups=None, key=None): - r"""Make a deep copy of value, casting all tensors to device of param.""" - if isinstance(value, torch.Tensor): - return _process_value_according_to_param_policy(param, value, param_id, param_groups, map_location, key) - elif isinstance(value, dict): - return {k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()} - elif isinstance(value, Iterable): - return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) - else: - return value - - # Copy state assigned to params (and cast tensors to appropriate types). 
- # State that is not assigned to params is copied as is (needed for - # backward compatibility). - state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict) - for k, v in state_dict["state"].items(): - if k in id_map: - param = id_map[k] - state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"]) - else: - state[k] = v - - # Update parameter groups, setting their 'params' value - def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]: - new_group["params"] = group["params"] - return new_group - - param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] - self.__setstate__({"state": state, "param_groups": param_groups}) - - for post_hook in self._optimizer_load_state_dict_post_hooks.values(): - post_hook(self) - - def init_optimizer_and_grad_accumulator( parametrization_method: ParametrizationMethod, model: nn.Module, diff --git a/src/nanotron/optim/base.py b/src/nanotron/optim/base.py index 34c33f42..9418b44a 100644 --- a/src/nanotron/optim/base.py +++ b/src/nanotron/optim/base.py @@ -1,7 +1,28 @@ from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Set, TypeVar, Union +from collections import defaultdict +from copy import deepcopy +from itertools import chain +from typing import ( + Any, + Callable, + DefaultDict, + Dict, + Hashable, + Iterable, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) import torch +from typing_extensions import TypeAlias + +Args: TypeAlias = Tuple[Any, ...] +Kwargs: TypeAlias = Dict[str, Any] +StateDict: TypeAlias = Dict[str, Any] class BaseOptimizer(ABC): @@ -46,3 +67,114 @@ def inherit_from(self, cls) -> bool: Optimizer = TypeVar("Optimizer", BaseOptimizer, torch.optim.Optimizer) + + +# Modified from torch.optim.Optimizer._process_value_according_to_param_policy +@staticmethod +def _process_value_according_to_param_policy( + param: torch.Tensor, + value: torch.Tensor, + param_id: int, + param_groups: List[Dict[Any, Any]], + map_location: Optional[Union[str, torch.device]], + key: Hashable = None, +) -> torch.Tensor: + # If map_location is specified, use it instead of param.device + target_device = map_location if map_location is not None else param.device + + fused = False + capturable = False + assert param_groups is not None + for pg in param_groups: + if param_id in pg["params"]: + fused = pg["fused"] if "fused" in pg else False + capturable = pg["capturable"] if "capturable" in pg else False + break + + if key == "step": + if capturable or fused: + return value.to(dtype=torch.float32, device=target_device) + else: + return value + else: + if param.is_floating_point(): + return value.to(dtype=param.dtype, device=target_device) + else: + return value.to(device=target_device) + + +# Modified from torch.optim.Optimizer.load_state_dict +@torch._disable_dynamo +def custom_load_state_dict( + self, state_dict: StateDict, map_location: Optional[Union[str, torch.device]] = "cpu" +) -> None: + r"""Loads the optimizer state. + + Args: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. + map_location (str or torch.device, optional): Device where to load the optimizer states. + If None, states will be loaded to the same device as their corresponding parameters. 
+ Default: None + """ + + # shallow copy, to be consistent with module API + state_dict = state_dict.copy() + + for pre_hook in self._optimizer_load_state_dict_pre_hooks.values(): + hook_result = pre_hook(self, state_dict) + if hook_result is not None: + state_dict = hook_result + + # Validate the state_dict + groups = self.param_groups + + # Deepcopy as we write into saved_groups later to update state + saved_groups = deepcopy(state_dict["param_groups"]) + + if len(groups) != len(saved_groups): + raise ValueError("loaded state dict has a different number of " "parameter groups") + param_lens = (len(g["params"]) for g in groups) + saved_lens = (len(g["params"]) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError( + "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group" + ) + + # Update the state + id_map = dict( + zip(chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups)) + ) + + def _cast(param, value, param_id=None, param_groups=None, key=None): + r"""Make a deep copy of value, casting all tensors to device of param.""" + if isinstance(value, torch.Tensor): + return _process_value_according_to_param_policy(param, value, param_id, param_groups, map_location, key) + elif isinstance(value, dict): + return {k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()} + elif isinstance(value, Iterable): + return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) + else: + return value + + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). 
+ state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict) + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"]) + else: + state[k] = v + + # Update parameter groups, setting their 'params' value + def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]: + new_group["params"] = group["params"] + return new_group + + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({"state": state, "param_groups": param_groups}) + + for post_hook in self._optimizer_load_state_dict_post_hooks.values(): + post_hook(self) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 44e4e32e..20ff2ffe 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -1,4 +1,5 @@ import datetime +import gc import json import os import shutil @@ -437,8 +438,6 @@ def train( prof = get_profiler(config=self.config) # free memory - import gc - gc.collect() torch.cuda.empty_cache() with prof: From 3a2a6c70a4f0d811c3a8f4d06cf11fca4816c22b Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Fri, 22 Nov 2024 09:43:16 +0000 Subject: [PATCH 05/17] setup custom_load_state_dict for all torch optimizers --- src/nanotron/helpers.py | 18 +++--------------- .../optim/inherit_from_other_optimizer.py | 11 +++++++++-- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 6ff564a8..73ca3484 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -7,14 +7,7 @@ from datetime import datetime from functools import partial from math import ceil -from typing import ( - Any, - Dict, - Iterable, - List, - Optional, - Tuple, -) +from typing import Any, Dict, Iterable, List, Optional, Tuple import numpy as np import torch @@ -29,7 +22,7 @@ from nanotron.distributed import ProcessGroup from nanotron.logging import LogItem, log_rank from nanotron.models.base import NanotronModel -from nanotron.optim.base import BaseOptimizer, Optimizer, custom_load_state_dict +from nanotron.optim.base import BaseOptimizer, Optimizer from nanotron.optim.gradient_accumulator import ( FP32GradBucketManager, FP32GradientAccumulator, @@ -335,7 +328,7 @@ def basic_optimizer_builder(named_param_groups): if optimizer_args.optimizer_factory.name == "adamW": def optimizer(param_groups): - base_optimizer = torch.optim.AdamW( + return torch.optim.AdamW( param_groups, lr=optimizer_args.learning_rate_scheduler.learning_rate, weight_decay=optimizer_args.weight_decay, @@ -343,11 +336,6 @@ def optimizer(param_groups): betas=(optimizer_args.optimizer_factory.adam_beta1, optimizer_args.optimizer_factory.adam_beta2), fused=optimizer_args.optimizer_factory.torch_adam_is_fused, ) - # Replace the load_state_dict method with our custom implementation that enables CPU offload - base_optimizer.load_state_dict = lambda state_dict, map_location=None: custom_load_state_dict( - base_optimizer, state_dict, map_location=map_location - ) - return base_optimizer elif optimizer_args.optimizer_factory.name == "sgd": diff --git a/src/nanotron/optim/inherit_from_other_optimizer.py b/src/nanotron/optim/inherit_from_other_optimizer.py index 53b57284..7376a0b3 100644 --- a/src/nanotron/optim/inherit_from_other_optimizer.py +++ b/src/nanotron/optim/inherit_from_other_optimizer.py @@ -3,14 +3,21 @@ import torch -from nanotron.optim.base import BaseOptimizer, Optimizer +from nanotron.optim.base import BaseOptimizer, 
Optimizer, custom_load_state_dict class InheritFromOtherOptimizer(BaseOptimizer): def __init__(self, optimizer: Optimizer, id_to_name: Dict[int, str]): - self.optimizer: Optimizer = optimizer self.id_to_name = id_to_name + # if self.optimizer is from torch we replace load_state_dict with the one from torch + if isinstance(optimizer, torch.optim.Optimizer): + # Replace the load_state_dict method with our custom implementation that enables CPU offload + optimizer.load_state_dict = lambda state_dict, map_location=None: custom_load_state_dict( + optimizer, state_dict, map_location=map_location + ) + self.optimizer: Optimizer = optimizer + def __getstate__(self): return self.optimizer.__getstate__() From 0a7801ad73e21d4ab9d44272733453f2829f83b2 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Fri, 22 Nov 2024 09:57:53 +0000 Subject: [PATCH 06/17] . --- src/nanotron/optim/base.py | 4 +--- src/nanotron/optim/inherit_from_other_optimizer.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/nanotron/optim/base.py b/src/nanotron/optim/base.py index 9418b44a..44066d25 100644 --- a/src/nanotron/optim/base.py +++ b/src/nanotron/optim/base.py @@ -105,9 +105,7 @@ def _process_value_according_to_param_policy( # Modified from torch.optim.Optimizer.load_state_dict @torch._disable_dynamo -def custom_load_state_dict( - self, state_dict: StateDict, map_location: Optional[Union[str, torch.device]] = "cpu" -) -> None: +def custom_load_state_dict(self, state_dict: StateDict, map_location: Union[str, torch.device]) -> None: r"""Loads the optimizer state. Args: diff --git a/src/nanotron/optim/inherit_from_other_optimizer.py b/src/nanotron/optim/inherit_from_other_optimizer.py index 7376a0b3..d67cecdf 100644 --- a/src/nanotron/optim/inherit_from_other_optimizer.py +++ b/src/nanotron/optim/inherit_from_other_optimizer.py @@ -13,9 +13,15 @@ def __init__(self, optimizer: Optimizer, id_to_name: Dict[int, str]): # if self.optimizer is from torch we replace load_state_dict with the one from torch if isinstance(optimizer, torch.optim.Optimizer): # Replace the load_state_dict method with our custom implementation that enables CPU offload - optimizer.load_state_dict = lambda state_dict, map_location=None: custom_load_state_dict( - optimizer, state_dict, map_location=map_location + original_load_state_dict = optimizer.load_state_dict + optimizer.load_state_dict = ( + lambda state_dict, map_location=None: custom_load_state_dict( + optimizer, state_dict, map_location=map_location + ) + if map_location is not None + else original_load_state_dict(state_dict) ) + self.optimizer: Optimizer = optimizer def __getstate__(self): From 77ba96c3c51d9c7fd441a2242a72d3d249175009 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Fri, 22 Nov 2024 09:58:46 +0000 Subject: [PATCH 07/17] . --- src/nanotron/optim/inherit_from_other_optimizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/nanotron/optim/inherit_from_other_optimizer.py b/src/nanotron/optim/inherit_from_other_optimizer.py index d67cecdf..039bb710 100644 --- a/src/nanotron/optim/inherit_from_other_optimizer.py +++ b/src/nanotron/optim/inherit_from_other_optimizer.py @@ -21,7 +21,6 @@ def __init__(self, optimizer: Optimizer, id_to_name: Dict[int, str]): if map_location is not None else original_load_state_dict(state_dict) ) - self.optimizer: Optimizer = optimizer def __getstate__(self): From 26ac3e0ad74ba3b64432bf6b591987f150cadbfa Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Fri, 22 Nov 2024 09:58:57 +0000 Subject: [PATCH 08/17] . 
--- src/nanotron/optim/inherit_from_other_optimizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/nanotron/optim/inherit_from_other_optimizer.py b/src/nanotron/optim/inherit_from_other_optimizer.py index 039bb710..d67cecdf 100644 --- a/src/nanotron/optim/inherit_from_other_optimizer.py +++ b/src/nanotron/optim/inherit_from_other_optimizer.py @@ -21,6 +21,7 @@ def __init__(self, optimizer: Optimizer, id_to_name: Dict[int, str]): if map_location is not None else original_load_state_dict(state_dict) ) + self.optimizer: Optimizer = optimizer def __getstate__(self): From 62fa6263bf8db4260962d58e80b6821db45d9680 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Fri, 22 Nov 2024 10:07:17 +0000 Subject: [PATCH 09/17] fix map_location --- src/nanotron/serialize/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/serialize/optimizer.py b/src/nanotron/serialize/optimizer.py index 867900cd..ca52d2df 100644 --- a/src/nanotron/serialize/optimizer.py +++ b/src/nanotron/serialize/optimizer.py @@ -351,7 +351,7 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) - ) state_dict["state"][param_index][state_name] = sliced_tensor - optimizer.load_state_dict(state_dict, map_location="cpu") + optimizer.load_state_dict(state_dict, map_location=map_location) def load_lr_scheduler( From ef931bc66b399d159c44adf04a395aab30ec5b52 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Fri, 22 Nov 2024 10:11:13 +0000 Subject: [PATCH 10/17] step can be on cpu or gpu --- tests/test_serialize.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/tests/test_serialize.py b/tests/test_serialize.py index 329ff279..520c6517 100644 --- a/tests/test_serialize.py +++ b/tests/test_serialize.py @@ -139,12 +139,45 @@ def _test_save_and_load_optimizer(parallel_context: ParallelContext, test_contex else: assert not match, "Newly initialised optimizer should not match." - load_optimizer(optimizer=new_optimizer, parallel_context=parallel_context, root_folder=store_folder) + load_optimizer( + optimizer=new_optimizer, parallel_context=parallel_context, root_folder=store_folder, map_location=None + ) # Assert the optimizer states are exactly the same after loading. 
match, msg = is_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) assert match, msg + # Test loading optimizer states to CPU + cpu_optimizer = NamedOptimizer( + named_params_or_groups=model.named_parameters(), + optimizer_builder=lambda params: torch.optim.AdamW(params), + ) + + # Load optimizer states to CPU + load_optimizer( + optimizer=cpu_optimizer, parallel_context=parallel_context, root_folder=store_folder, map_location="cpu" + ) + + # Get state dicts + gpu_state = optimizer.state_dict() + cpu_state = cpu_optimizer.state_dict() + + # Check that states match except for device + for param_id in gpu_state["state"]: + for key, gpu_value in gpu_state["state"][param_id].items(): + cpu_value = cpu_state["state"][param_id][key] + if isinstance(gpu_value, torch.Tensor): + assert torch.equal(gpu_value.cpu(), cpu_value), f"Values don't match for param {param_id}, key {key}" + if key != "step": # Skip device checks for 'step' key + assert ( + cpu_value.device.type == "cpu" + ), f"CPU optimizer state should be on CPU for param {param_id}, key {key}" + assert ( + gpu_value.device.type == "cuda" + ), f"GPU optimizer state should be on CUDA for param {param_id}, key {key}" + else: + assert gpu_value == cpu_value, f"Non-tensor values don't match for param {param_id}, key {key}" + parallel_context.destroy() From b10e4a3b2fc93298afcebb38cacde22270ceb7a8 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Fri, 22 Nov 2024 10:58:16 +0000 Subject: [PATCH 11/17] skip pp case for now --- tests/test_serialize.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_serialize.py b/tests/test_serialize.py index 520c6517..6442f6b5 100644 --- a/tests/test_serialize.py +++ b/tests/test_serialize.py @@ -94,7 +94,8 @@ def _test_save_and_load_model(parallel_context: ParallelContext, test_context: T @rerun_if_address_is_in_use() def test_save_and_load_optimizer(tp: int, dp: int, pp: int): test_context = TestContext() - # We use DP=2 as we're interested in testing that one + if pp > 1: + pytest.skip("Pipeline parallelism not supported for this test yet") init_distributed(tp=tp, dp=dp, pp=pp)(_test_save_and_load_optimizer)(test_context=test_context) @@ -138,7 +139,6 @@ def _test_save_and_load_optimizer(parallel_context: ParallelContext, test_contex pass else: assert not match, "Newly initialised optimizer should not match." 
- load_optimizer( optimizer=new_optimizer, parallel_context=parallel_context, root_folder=store_folder, map_location=None ) From 8e35c4c5193d0bcfb960f511cd795028173193fb Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Fri, 22 Nov 2024 10:59:17 +0000 Subject: [PATCH 12/17] add small tests --- src/nanotron/optim/named_optimizer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/nanotron/optim/named_optimizer.py b/src/nanotron/optim/named_optimizer.py index 07363caa..d51de433 100644 --- a/src/nanotron/optim/named_optimizer.py +++ b/src/nanotron/optim/named_optimizer.py @@ -62,9 +62,11 @@ def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, to assert set(self.id_to_name.values()) == set( state_dict["names"].values() ), f"Elements don't match:\n - Elements in `self.id_to_name` that aren't in the other one: {set(self.id_to_name.values()) - set(state_dict['names'].values())}\n - Elements in `state_dict[\"names\"]` that aren't in the other one: {set(state_dict['names'].values()) - set(self.id_to_name.values())}" - + assert len(state_dict["state"]) == len( + state_dict["names"] + ), f"Number of params in loaded state dict ({len(state_dict['state'])}) doesn't match number of names ({len(state_dict['names'])})" + assert len(state_dict["state"]) > 0, "Loading empty state dict" OPTIMIZER_STATE_KEYS = sorted(state_dict["state"][0].keys() - {"step"}) - assert len(state_dict["state"]) == len(state_dict["names"]) for key in OPTIMIZER_STATE_KEYS: for k, state in state_dict["state"].items(): assert ( From 261e0435f7336c7d3db7f4154e28c918ba0f7c7c Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Fri, 22 Nov 2024 15:10:46 +0000 Subject: [PATCH 13/17] update test_serialize --- tests/test_serialize.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_serialize.py b/tests/test_serialize.py index 6442f6b5..5234710e 100644 --- a/tests/test_serialize.py +++ b/tests/test_serialize.py @@ -139,9 +139,7 @@ def _test_save_and_load_optimizer(parallel_context: ParallelContext, test_contex pass else: assert not match, "Newly initialised optimizer should not match." - load_optimizer( - optimizer=new_optimizer, parallel_context=parallel_context, root_folder=store_folder, map_location=None - ) + load_optimizer(optimizer=new_optimizer, parallel_context=parallel_context, root_folder=store_folder) # Assert the optimizer states are exactly the same after loading. 
match, msg = is_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) From 9d4a7befbf240a04669f962df505e886dcd2c640 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Fri, 22 Nov 2024 15:10:52 +0000 Subject: [PATCH 14/17] update trainer --- src/nanotron/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 20ff2ffe..d79188bf 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -197,6 +197,7 @@ def __init__( root_folder=self.init_checkpoint_path, param_shard_metadata=self.param_shard_metadata, model=self.unwrapped_model, + map_location="cpu", ) # Init learning rate scheduler From a87b83efee47e0c873e11d1033de05b1482de431 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Fri, 22 Nov 2024 15:10:57 +0000 Subject: [PATCH 15/17] update optimizer --- src/nanotron/serialize/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/serialize/optimizer.py b/src/nanotron/serialize/optimizer.py index ca52d2df..0e1856d4 100644 --- a/src/nanotron/serialize/optimizer.py +++ b/src/nanotron/serialize/optimizer.py @@ -147,7 +147,7 @@ def load_optimizer( optimizer: optim.BaseOptimizer, parallel_context: ParallelContext, root_folder: Path, - map_location: Optional[str] = "cpu", + map_location: Optional[str] = None, param_shard_metadata: Tuple[Tuple[int, int], TensorMetadata] = None, # (pp_rank, tp_rank) -> TensorMetadata model: Optional[nn.Module] = None, ): From ca0ffa6cc498c2608656a38b97f78b062db3cc69 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Fri, 22 Nov 2024 15:11:10 +0000 Subject: [PATCH 16/17] remove unused func --- src/nanotron/serialize/main.py | 44 ++-------------------------------- 1 file changed, 2 insertions(+), 42 deletions(-) diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index 7991cbd4..e87d8dbb 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -5,7 +5,6 @@ import torch from datasets.download.streaming_download_manager import xPath from torch import nn -from torch.nn.parallel import DistributedDataParallel from torch.optim.lr_scheduler import LambdaLR from nanotron import distributed as dist @@ -21,14 +20,12 @@ assert_tensor_synced_across_pg, check_optim_state_in_sync, ) -from nanotron.serialize.metadata import CheckpointMetadata, TrainingMetadata, load_meta, save_meta +from nanotron.serialize.metadata import TrainingMetadata, save_meta from nanotron.serialize.optimizer import ( - load_lr_scheduler, - load_optimizer, save_lr_scheduler, save_optimizer, ) -from nanotron.serialize.weights import load_weights, save_weights +from nanotron.serialize.weights import save_weights """ We're going to use safetensors. The reason is that loading segments is going to be much easier @@ -206,43 +203,6 @@ def save( dist.barrier(parallel_context.world_pg) -def load( - model: nn.Module, - optimizer: optim.BaseOptimizer, - lr_scheduler, - parallel_context: ParallelContext, - root_folder: Path, -) -> CheckpointMetadata: - """ - Load checkpoint, raise if checkpoint is assumed corrupted. Inplace updates `model` and `optimizer` to have the newest parameters. 
- TODO @thomasw21: Make this topology agnostic - - :param filepath: Path - :return: - """ - checkpoint_metadata = load_meta(parallel_context=parallel_context, root_folder=root_folder) - load_weights(model=model, parallel_context=parallel_context, root_folder=root_folder) - - # SANITY CHECK: assert that optimizer's named_params still point to model's params (check only the first one) - if isinstance(optimizer, optim.ZeroDistributedOptimizer): - if ( - len(optimizer.zero_named_param_groups) > 0 - and len(optimizer.zero_named_param_groups[0]["named_params"]) > 0 - ): - optim_model_param_name, optim_model_param = optimizer.zero_named_param_groups[0]["named_params"][0] - if isinstance(model, DistributedDataParallel): - optim_model_param_name = f"module.{optim_model_param_name}" - param = next(p for n, p in model.named_parameters() if n == optim_model_param_name) - assert param.data_ptr() == optim_model_param.data_ptr() - - load_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=root_folder) - load_lr_scheduler( - lr_scheduler=lr_scheduler, - root_folder=root_folder, - ) - return checkpoint_metadata - - def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Optional[Path]: """Parse checkpoint path from config and download checkpoint from S3 if needed. From f00a380e37d87baa3d863410e191969ad5f2a85a Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Fri, 22 Nov 2024 15:13:20 +0000 Subject: [PATCH 17/17] . --- src/nanotron/serialize/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/nanotron/serialize/__init__.py b/src/nanotron/serialize/__init__.py index ae6ef264..7fc7b0a9 100644 --- a/src/nanotron/serialize/__init__.py +++ b/src/nanotron/serialize/__init__.py @@ -2,3 +2,4 @@ from nanotron.serialize.main import * from nanotron.serialize.optimizer import * from nanotron.serialize.random import * +from nanotron.serialize.weights import *
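
A minimal plain-PyTorch sketch of the flow this series implements, for reference: restore the
checkpointed optimizer states onto the CPU, run the first forward/backward, then move the states
to the GPU only right before the first optimizer step, so loading never adds to the GPU memory
peak. The helpers below (load_optim_state_on_device, optim_state_to_device) are illustrative
stand-ins for the series' custom load_state_dict(..., map_location=...) and state_dict_to_device,
not nanotron's actual API, and they assume a single process with an unchanged param-group layout
(the real patches additionally handle nanotron's named/sharded optimizers).

import torch


def load_optim_state_on_device(optimizer, saved_state_dict, device):
    # Simplified stand-in for the patched load_state_dict(map_location=...): rebuild the
    # per-parameter state on `device` instead of on the parameters' (CUDA) device.
    id_map = {
        old_id: param
        for old_group, group in zip(saved_state_dict["param_groups"], optimizer.param_groups)
        for old_id, param in zip(old_group["params"], group["params"])
    }
    for old_id, per_param_state in saved_state_dict["state"].items():
        optimizer.state[id_map[old_id]] = {
            k: v.to(device) if isinstance(v, torch.Tensor) and k != "step" else v
            for k, v in per_param_state.items()
        }


def optim_state_to_device(optimizer, device):
    # Same role as state_dict_to_device() in the series: move the live states in place,
    # leaving the scalar "step" where it is.
    for per_param_state in optimizer.state.values():
        for k, v in per_param_state.items():
            if isinstance(v, torch.Tensor) and k != "step":
                per_param_state[k] = v.to(device)


model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# One step so the optimizer has exp_avg / exp_avg_sq states worth checkpointing.
model(torch.randn(8, 1024, device="cuda")).sum().backward()
optimizer.step()
optimizer.zero_grad()

# Stand-in for reading a checkpoint with map_location="cpu": every state tensor lives on CPU.
checkpoint = optimizer.state_dict()
for per_param_state in checkpoint["state"].values():
    for k, v in per_param_state.items():
        if isinstance(v, torch.Tensor):
            per_param_state[k] = v.cpu()

# Resume: load onto CPU, take the first forward/backward (the memory peak), and only then
# move the optimizer states to the GPU for the first optimizer step.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
load_optim_state_on_device(optimizer, checkpoint, torch.device("cpu"))

model(torch.randn(8, 1024, device="cuda")).sum().backward()
optim_state_to_device(optimizer, torch.device("cuda"))
optimizer.step()
optimizer.zero_grad()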