diff --git a/src/nanotron/optim/base.py b/src/nanotron/optim/base.py index fb77f124..44066d25 100644 --- a/src/nanotron/optim/base.py +++ b/src/nanotron/optim/base.py @@ -1,7 +1,28 @@ from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Set, TypeVar +from collections import defaultdict +from copy import deepcopy +from itertools import chain +from typing import ( + Any, + Callable, + DefaultDict, + Dict, + Hashable, + Iterable, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) import torch +from typing_extensions import TypeAlias + +Args: TypeAlias = Tuple[Any, ...] +Kwargs: TypeAlias = Dict[str, Any] +StateDict: TypeAlias = Dict[str, Any] class BaseOptimizer(ABC): @@ -34,7 +55,7 @@ def state_dict(self) -> dict: ... @abstractmethod - def load_state_dict(self, state_dict: dict) -> None: + def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None: ... @abstractmethod @@ -46,3 +67,112 @@ def inherit_from(self, cls) -> bool: Optimizer = TypeVar("Optimizer", BaseOptimizer, torch.optim.Optimizer) + + +# Modified from torch.optim.Optimizer._process_value_according_to_param_policy +@staticmethod +def _process_value_according_to_param_policy( + param: torch.Tensor, + value: torch.Tensor, + param_id: int, + param_groups: List[Dict[Any, Any]], + map_location: Optional[Union[str, torch.device]], + key: Hashable = None, +) -> torch.Tensor: + # If map_location is specified, use it instead of param.device + target_device = map_location if map_location is not None else param.device + + fused = False + capturable = False + assert param_groups is not None + for pg in param_groups: + if param_id in pg["params"]: + fused = pg["fused"] if "fused" in pg else False + capturable = pg["capturable"] if "capturable" in pg else False + break + + if key == "step": + if capturable or fused: + return value.to(dtype=torch.float32, device=target_device) + else: + return value + else: + if param.is_floating_point(): + return value.to(dtype=param.dtype, device=target_device) + else: + return value.to(device=target_device) + + +# Modified from torch.optim.Optimizer.load_state_dict +@torch._disable_dynamo +def custom_load_state_dict(self, state_dict: StateDict, map_location: Union[str, torch.device]) -> None: + r"""Loads the optimizer state. + + Args: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. + map_location (str or torch.device, optional): Device where to load the optimizer states. + If None, states will be loaded to the same device as their corresponding parameters. 
+ Default: None + """ + + # shallow copy, to be consistent with module API + state_dict = state_dict.copy() + + for pre_hook in self._optimizer_load_state_dict_pre_hooks.values(): + hook_result = pre_hook(self, state_dict) + if hook_result is not None: + state_dict = hook_result + + # Validate the state_dict + groups = self.param_groups + + # Deepcopy as we write into saved_groups later to update state + saved_groups = deepcopy(state_dict["param_groups"]) + + if len(groups) != len(saved_groups): + raise ValueError("loaded state dict has a different number of " "parameter groups") + param_lens = (len(g["params"]) for g in groups) + saved_lens = (len(g["params"]) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError( + "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group" + ) + + # Update the state + id_map = dict( + zip(chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups)) + ) + + def _cast(param, value, param_id=None, param_groups=None, key=None): + r"""Make a deep copy of value, casting all tensors to device of param.""" + if isinstance(value, torch.Tensor): + return _process_value_according_to_param_policy(param, value, param_id, param_groups, map_location, key) + elif isinstance(value, dict): + return {k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()} + elif isinstance(value, Iterable): + return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) + else: + return value + + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). 
+ state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict) + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"]) + else: + state[k] = v + + # Update parameter groups, setting their 'params' value + def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]: + new_group["params"] = group["params"] + return new_group + + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({"state": state, "param_groups": param_groups}) + + for post_hook in self._optimizer_load_state_dict_post_hooks.values(): + post_hook(self) diff --git a/src/nanotron/optim/inherit_from_other_optimizer.py b/src/nanotron/optim/inherit_from_other_optimizer.py index 2ddd36d0..d67cecdf 100644 --- a/src/nanotron/optim/inherit_from_other_optimizer.py +++ b/src/nanotron/optim/inherit_from_other_optimizer.py @@ -1,16 +1,29 @@ from functools import cache -from typing import Callable, Dict, Optional, Set +from typing import Callable, Dict, Optional, Set, Union import torch -from nanotron.optim.base import BaseOptimizer, Optimizer +from nanotron.optim.base import BaseOptimizer, Optimizer, custom_load_state_dict class InheritFromOtherOptimizer(BaseOptimizer): def __init__(self, optimizer: Optimizer, id_to_name: Dict[int, str]): - self.optimizer: Optimizer = optimizer self.id_to_name = id_to_name + # if self.optimizer is from torch we replace load_state_dict with the one from torch + if isinstance(optimizer, torch.optim.Optimizer): + # Replace the load_state_dict method with our custom implementation that enables CPU offload + original_load_state_dict = optimizer.load_state_dict + optimizer.load_state_dict = ( + lambda state_dict, map_location=None: custom_load_state_dict( + optimizer, state_dict, map_location=map_location + ) + if map_location is not None + else original_load_state_dict(state_dict) + ) + + self.optimizer: Optimizer = optimizer + def __getstate__(self): return self.optimizer.__getstate__() @@ -33,8 +46,8 @@ def state_dict_additional_keys(self) -> Set[str]: def state_dict(self) -> dict: return self.optimizer.state_dict() - def load_state_dict(self, state_dict: dict) -> None: - return self.optimizer.load_state_dict(state_dict) + def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None: + return self.optimizer.load_state_dict(state_dict, map_location=map_location) def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: return self.optimizer.step(closure=closure) diff --git a/src/nanotron/optim/named_optimizer.py b/src/nanotron/optim/named_optimizer.py index 5f11710a..d51de433 100644 --- a/src/nanotron/optim/named_optimizer.py +++ b/src/nanotron/optim/named_optimizer.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, Tuple, Union +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union import torch @@ -58,17 +58,19 @@ def state_dict(self) -> dict: } return optim_state_dict - def load_state_dict(self, state_dict: dict) -> None: + def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None: assert set(self.id_to_name.values()) == set( state_dict["names"].values() ), f"Elements don't match:\n - Elements in `self.id_to_name` that aren't in the other one: {set(self.id_to_name.values()) - set(state_dict['names'].values())}\n - Elements in `state_dict[\"names\"]` that aren't in 
the other one: {set(state_dict['names'].values()) - set(self.id_to_name.values())}" - + assert len(state_dict["state"]) == len( + state_dict["names"] + ), f"Number of params in loaded state dict ({len(state_dict['state'])}) doesn't match number of names ({len(state_dict['names'])})" + assert len(state_dict["state"]) > 0, "Loading empty state dict" OPTIMIZER_STATE_KEYS = sorted(state_dict["state"][0].keys() - {"step"}) - assert len(state_dict["state"]) == len(state_dict["names"]) for key in OPTIMIZER_STATE_KEYS: for k, state in state_dict["state"].items(): assert ( key in state ), f"Key {key} not found in state dict: {state} which corresponds to param_name: {state_dict['names'][k]}" - return super().load_state_dict(state_dict) + return super().load_state_dict(state_dict, map_location=map_location) diff --git a/src/nanotron/optim/optimizer_from_gradient_accumulator.py b/src/nanotron/optim/optimizer_from_gradient_accumulator.py index 01be7cb5..06426b0b 100644 --- a/src/nanotron/optim/optimizer_from_gradient_accumulator.py +++ b/src/nanotron/optim/optimizer_from_gradient_accumulator.py @@ -67,7 +67,7 @@ def state_dict(self) -> dict: state_dict["gradient_accumulator"] = self.gradient_accumulator.state_dict() return state_dict - def load_state_dict(self, state_dict: dict) -> None: + def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None: gradient_accumulator_state_dict = state_dict.pop("gradient_accumulator") - super().load_state_dict(state_dict) + super().load_state_dict(state_dict, map_location=map_location) self.gradient_accumulator.load_state_dict(gradient_accumulator_state_dict) diff --git a/src/nanotron/sanity_checks.py b/src/nanotron/sanity_checks.py index 650607da..56ef1e2e 100644 --- a/src/nanotron/sanity_checks.py +++ b/src/nanotron/sanity_checks.py @@ -164,6 +164,7 @@ def before_optim_step_sanity_checks( parallel_context: ParallelContext, unwrapped_model: NanotronModel, grad_accumulator: GradientAccumulator, + optimizer: optim.BaseOptimizer, ) -> None: if not config.general.ignore_sanity_checks: # SANITY CHECK: Test tied weights gradients are synchronized @@ -232,6 +233,9 @@ def before_optim_step_sanity_checks( msg=lambda err: f"[Before optimizer step] Tied weights {name} are not synchronized. 
{err}", ) + # SANITY CHECK: Check that optimizer states are synchronized across DP + check_optim_state_in_sync(optimizer.state_dict(), parallel_context.dp_pg) + # SANITY CHECK: run model specific sanity checks unwrapped_model.before_optim_step_sanity_checks() @@ -259,12 +263,11 @@ def after_optim_step_sanity_checks( unwrapped_model.after_optim_step_sanity_checks() -def check_optim_state_in_sync(optimizer: optim.BaseOptimizer, pg: dist.ProcessGroup): - for _, optim_state in sorted(optimizer.state_dict()["state"].items(), key=lambda x: x[0]): +def check_optim_state_in_sync(optim_state_dict: dict, pg: dist.ProcessGroup): + for _, optim_state in sorted(optim_state_dict["state"].items(), key=lambda x: x[0]): for name, tensor in optim_state.items(): if name == "step": - tensor = tensor.to("cuda") - + continue assert_tensor_synced_across_pg( tensor=tensor, pg=pg, msg=lambda err: f"{name} are not synced across DP {err}" ) diff --git a/src/nanotron/serialize/__init__.py b/src/nanotron/serialize/__init__.py index ae6ef264..7fc7b0a9 100644 --- a/src/nanotron/serialize/__init__.py +++ b/src/nanotron/serialize/__init__.py @@ -2,3 +2,4 @@ from nanotron.serialize.main import * from nanotron.serialize.optimizer import * from nanotron.serialize.random import * +from nanotron.serialize.weights import * diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index 7991cbd4..e87d8dbb 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -5,7 +5,6 @@ import torch from datasets.download.streaming_download_manager import xPath from torch import nn -from torch.nn.parallel import DistributedDataParallel from torch.optim.lr_scheduler import LambdaLR from nanotron import distributed as dist @@ -21,14 +20,12 @@ assert_tensor_synced_across_pg, check_optim_state_in_sync, ) -from nanotron.serialize.metadata import CheckpointMetadata, TrainingMetadata, load_meta, save_meta +from nanotron.serialize.metadata import TrainingMetadata, save_meta from nanotron.serialize.optimizer import ( - load_lr_scheduler, - load_optimizer, save_lr_scheduler, save_optimizer, ) -from nanotron.serialize.weights import load_weights, save_weights +from nanotron.serialize.weights import save_weights """ We're going to use safetensors. The reason is that loading segments is going to be much easier @@ -206,43 +203,6 @@ def save( dist.barrier(parallel_context.world_pg) -def load( - model: nn.Module, - optimizer: optim.BaseOptimizer, - lr_scheduler, - parallel_context: ParallelContext, - root_folder: Path, -) -> CheckpointMetadata: - """ - Load checkpoint, raise if checkpoint is assumed corrupted. Inplace updates `model` and `optimizer` to have the newest parameters. 
- TODO @thomasw21: Make this topology agnostic - - :param filepath: Path - :return: - """ - checkpoint_metadata = load_meta(parallel_context=parallel_context, root_folder=root_folder) - load_weights(model=model, parallel_context=parallel_context, root_folder=root_folder) - - # SANITY CHECK: assert that optimizer's named_params still point to model's params (check only the first one) - if isinstance(optimizer, optim.ZeroDistributedOptimizer): - if ( - len(optimizer.zero_named_param_groups) > 0 - and len(optimizer.zero_named_param_groups[0]["named_params"]) > 0 - ): - optim_model_param_name, optim_model_param = optimizer.zero_named_param_groups[0]["named_params"][0] - if isinstance(model, DistributedDataParallel): - optim_model_param_name = f"module.{optim_model_param_name}" - param = next(p for n, p in model.named_parameters() if n == optim_model_param_name) - assert param.data_ptr() == optim_model_param.data_ptr() - - load_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=root_folder) - load_lr_scheduler( - lr_scheduler=lr_scheduler, - root_folder=root_folder, - ) - return checkpoint_metadata - - def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Optional[Path]: """Parse checkpoint path from config and download checkpoint from S3 if needed. diff --git a/src/nanotron/serialize/optimizer.py b/src/nanotron/serialize/optimizer.py index c9d2da6b..0e1856d4 100644 --- a/src/nanotron/serialize/optimizer.py +++ b/src/nanotron/serialize/optimizer.py @@ -2,7 +2,7 @@ import warnings from collections import defaultdict from pathlib import Path -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple import torch from torch import nn @@ -19,7 +19,6 @@ ) from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter -from nanotron.sanity_checks import check_optim_state_in_sync from nanotron.serialize.metadata import TensorMetadata from nanotron.serialize.utils import ObjectType, merge_and_shard_tp_tensors @@ -125,6 +124,24 @@ def save_lr_scheduler( ) +# Helper functions to move optimizer states +@torch.no_grad() +def state_dict_to_device(state_dict: Dict, device: str) -> Dict: + assert ( + state_dict["state"][0]["exp_avg"].device.type == "cpu" + ), "Optimizer states should be on CPU to avoid extra memory usage when loading from checkpoint" + torch.cuda.empty_cache() + + for _, optim_state in sorted(state_dict["state"].items(), key=lambda x: x[0]): + for name, tensor in optim_state.items(): + optim_state[name] = tensor.to(device) + + assert ( + state_dict["state"][0]["exp_avg"].device.type == "cuda" + ), "Optimizer states should be on GPU because model is on GPU" + torch.cuda.empty_cache() + + @torch.no_grad() def load_optimizer( optimizer: optim.BaseOptimizer, @@ -135,8 +152,6 @@ def load_optimizer( model: Optional[nn.Module] = None, ): root_folder = root_folder / "optimizer" - # `load_state_dict` copies the state dict which can be very large in case of Zero-0 so we load to cpu and then move to the right device - map_location = "cpu" if not optimizer.inherit_from(optim.ZeroDistributedOptimizer) else map_location ckp_optimizer_config_path = root_folder / "optimizer_config.json" with open(ckp_optimizer_config_path, "r") as file: ckp_optimizer_config = json.load(file) @@ -149,9 +164,10 @@ def load_optimizer( if int(ckp_tp_size) != int(parallel_context.tp_pg.size()) or int(ckp_pp_size) != int( parallel_context.pp_pg.size() ): - warnings.warn( - "You are resuming in a different PP size, so 
optimizer states need to be checked. Feel free to open a PR if you work on this!" - ) + if int(ckp_pp_size) != int(parallel_context.pp_pg.size()): + warnings.warn( + "You are resuming in a different PP size, so optimizer states need to be checked. Feel free to open a PR if you work on this!" + ) assert ( param_shard_metadata is not None ), f"You have to pass how the original parameters are sharded in order to resume in a different tensor parallel size, ckp_tp_size: {ckp_tp_size}, current tp_size: {parallel_context.tp_pg.size()}" @@ -241,8 +257,10 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) - # TODO(xrsrke): free the memory of the shards that isn't # corresponding to the current rank # TODO: maybe better to allocate memory for all states at once - buffer = torch.zeros_like(param, device="cuda", dtype=OPTIMIZER_STATE_DTYPE) - unsharded_buffer = torch.empty(new_unshared_shape, device="cuda", dtype=OPTIMIZER_STATE_DTYPE) + buffer = torch.zeros_like(param, device=map_location, dtype=OPTIMIZER_STATE_DTYPE) + unsharded_buffer = torch.empty( + new_unshared_shape, device=map_location, dtype=OPTIMIZER_STATE_DTYPE + ) for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items(): old_optim_state_index = find_optim_index_from_param_name( @@ -333,10 +351,7 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) - ) state_dict["state"][param_index][state_name] = sliced_tensor - optimizer.load_state_dict(state_dict) - - if not optimizer.inherit_from(optim.ZeroDistributedOptimizer): - check_optim_state_in_sync(optimizer, parallel_context.dp_pg) + optimizer.load_state_dict(state_dict, map_location=map_location) def load_lr_scheduler( diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 356c3910..d79188bf 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -1,4 +1,5 @@ import datetime +import gc import json import os import shutil @@ -93,7 +94,7 @@ save_random_states, ) from nanotron.serialize.metadata import DataStageMetadata, TrainingMetadata -from nanotron.serialize.optimizer import load_optimizer +from nanotron.serialize.optimizer import load_optimizer, state_dict_to_device logger = logging.get_logger(__name__) @@ -196,6 +197,7 @@ def __init__( root_folder=self.init_checkpoint_path, param_shard_metadata=self.param_shard_metadata, model=self.unwrapped_model, + map_location="cpu", ) # Init learning rate scheduler @@ -432,10 +434,15 @@ def train( # Fix the root_model self.unwrapped_model.module_id_to_prefix[id(self.unwrapped_model)] = "" + self.initial_iter_step = self.metadata.last_train_step + 1 + self.last_iter_step = self.config.tokens.train_steps + prof = get_profiler(config=self.config) + # free memory + gc.collect() torch.cuda.empty_cache() with prof: - for self.iteration_step in range(self.metadata.last_train_step + 1, self.config.tokens.train_steps + 1): + for self.iteration_step in range(self.initial_iter_step, self.last_iter_step + 1): if isinstance(prof, torch.profiler.profile): prof.step() @@ -474,7 +481,7 @@ def training_step( self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.lr_scheduler ) - if self.iteration_step < 5: + if self.iteration_step < self.initial_iter_step + 5: log_memory(logger=logger) outputs = self.pipeline_engine.train_batch_iter( @@ -485,7 +492,7 @@ def training_step( grad_accumulator=self.grad_accumulator, ) - if self.iteration_step < 5: + if self.iteration_step < self.initial_iter_step + 5: log_memory(logger=logger) 
after_tbi_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator) @@ -531,10 +538,6 @@ def training_step( max_norm=self.config.optimizer.clip_grad, ) - before_optim_step_sanity_checks( - self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator - ) - # Compute DP average loss and overlap with optimizer step if isinstance(outputs[0]["loss"], torch.Tensor): # This is an average on only one data rank. @@ -547,6 +550,14 @@ def training_step( loss_avg = None handle = None + # Move optimizer states back to GPU before optimizer step + if self.init_checkpoint_path is not None and self.iteration_step == self.initial_iter_step: + state_dict_to_device(self.optimizer.state_dict(), "cuda") + + before_optim_step_sanity_checks( + self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer + ) + # Apply gradient self.optimizer.step() self.optimizer.zero_grad() diff --git a/tests/test_serialize.py b/tests/test_serialize.py index 329ff279..5234710e 100644 --- a/tests/test_serialize.py +++ b/tests/test_serialize.py @@ -94,7 +94,8 @@ def _test_save_and_load_model(parallel_context: ParallelContext, test_context: T @rerun_if_address_is_in_use() def test_save_and_load_optimizer(tp: int, dp: int, pp: int): test_context = TestContext() - # We use DP=2 as we're interested in testing that one + if pp > 1: + pytest.skip("Pipeline parallelism not supported for this test yet") init_distributed(tp=tp, dp=dp, pp=pp)(_test_save_and_load_optimizer)(test_context=test_context) @@ -138,13 +139,43 @@ def _test_save_and_load_optimizer(parallel_context: ParallelContext, test_contex pass else: assert not match, "Newly initialised optimizer should not match." - load_optimizer(optimizer=new_optimizer, parallel_context=parallel_context, root_folder=store_folder) # Assert the optimizer states are exactly the same after loading. match, msg = is_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) assert match, msg + # Test loading optimizer states to CPU + cpu_optimizer = NamedOptimizer( + named_params_or_groups=model.named_parameters(), + optimizer_builder=lambda params: torch.optim.AdamW(params), + ) + + # Load optimizer states to CPU + load_optimizer( + optimizer=cpu_optimizer, parallel_context=parallel_context, root_folder=store_folder, map_location="cpu" + ) + + # Get state dicts + gpu_state = optimizer.state_dict() + cpu_state = cpu_optimizer.state_dict() + + # Check that states match except for device + for param_id in gpu_state["state"]: + for key, gpu_value in gpu_state["state"][param_id].items(): + cpu_value = cpu_state["state"][param_id][key] + if isinstance(gpu_value, torch.Tensor): + assert torch.equal(gpu_value.cpu(), cpu_value), f"Values don't match for param {param_id}, key {key}" + if key != "step": # Skip device checks for 'step' key + assert ( + cpu_value.device.type == "cpu" + ), f"CPU optimizer state should be on CPU for param {param_id}, key {key}" + assert ( + gpu_value.device.type == "cuda" + ), f"GPU optimizer state should be on CUDA for param {param_id}, key {key}" + else: + assert gpu_value == cpu_value, f"Non-tensor values don't match for param {param_id}, key {key}" + parallel_context.destroy()
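
For reference, the effect of the new `map_location` argument can be exercised directly against a plain `torch.optim.Optimizer` via `custom_load_state_dict`, which is what the `InheritFromOtherOptimizer` wrapper above dispatches to. The following is a minimal sketch, not part of the change set; it assumes a CUDA device, a recent PyTorch 2.x release (for the private load-state-dict hook registries the helper touches) and Python >= 3.10 (so the module-level `@staticmethod` helper remains callable).

```python
import torch

from nanotron.optim.base import custom_load_state_dict

# Tiny model plus one step so AdamW materializes exp_avg / exp_avg_sq.
model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(2, 8, device="cuda")).sum().backward()
optimizer.step()

# Reload the just-saved states onto CPU instead of the parameters' device.
custom_load_state_dict(optimizer, optimizer.state_dict(), map_location="cpu")

reloaded = optimizer.state_dict()["state"][0]
assert reloaded["exp_avg"].device.type == "cpu"        # states offloaded to CPU
assert next(model.parameters()).device.type == "cuda"  # params untouched on GPU
```

Any subsequent `optimizer.step()` would first need the states moved back next to the parameters, which is what `state_dict_to_device` (used in the next sketch) handles.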
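
Putting the pieces together, the resume-from-checkpoint flow enabled by this change is: load the optimizer states onto CPU (avoiding a transient GPU memory spike while the checkpoint is read and re-sharded), then move them back to the GPU once, right before the first optimizer step of the resumed run. Below is a condensed sketch of that flow, assuming an initialized `parallel_context`, a built `model` on GPU, and the `param_shard_metadata` / `checkpoint_path` objects the trainer already holds; it only illustrates the ordering, not the full trainer logic.

```python
import torch

from nanotron.optim.named_optimizer import NamedOptimizer
from nanotron.serialize.optimizer import load_optimizer, state_dict_to_device

optimizer = NamedOptimizer(
    named_params_or_groups=model.named_parameters(),
    optimizer_builder=lambda params: torch.optim.AdamW(params),
)

# 1) Restore optimizer states to CPU instead of the parameters' device.
load_optimizer(
    optimizer=optimizer,
    parallel_context=parallel_context,
    root_folder=checkpoint_path,
    param_shard_metadata=param_shard_metadata,
    model=model,
    map_location="cpu",
)

# ... forward/backward of the first resumed iteration ...

# 2) Move the states back to the GPU once, just before the first step
#    (mirrors the `initial_iter_step` check in `training_step` above).
state_dict_to_device(optimizer.state_dict(), "cuda")
optimizer.step()
```

Because the packed state dict references the live per-parameter state dicts, `state_dict_to_device` moves the optimizer's actual states in place; later steps then run entirely on GPU with no further copies.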