diff --git a/setup.py b/setup.py
index 7e67a75a9a6..f6eefda0dea 100644
--- a/setup.py
+++ b/setup.py
@@ -54,7 +54,7 @@
         ]
     },
     python_requires=">=3.8.0",
-    install_requires=["numpy>=1.17", "packaging>=20.0", "psutil", "pyyaml", "torch>=1.10.0", "huggingface_hub"],
+    install_requires=["numpy>=1.17", "packaging>=20.0", "psutil", "pyyaml", "torch>=1.10.0", "huggingface_hub", "safetensors>=0.3.1"],
     extras_require=extras,
     classifiers=[
         "Development Status :: 5 - Production/Stable",
diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index cc34cd5cb1c..68a033cea36 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -80,7 +80,6 @@
     is_ipex_available,
     is_megatron_lm_available,
     is_npu_available,
-    is_safetensors_available,
     is_torch_version,
     is_tpu_available,
     is_xpu_available,
@@ -2536,7 +2535,7 @@ def save_model(
         model: torch.nn.Module,
         save_directory: Union[str, os.PathLike],
         max_shard_size: Union[int, str] = "10GB",
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
     ):
         """
         Save a model so that it can be re-loaded using load_checkpoint_in_model
@@ -2557,7 +2556,7 @@ def save_model(
-            safe_serialization (`bool`, *optional*, defaults to `False`):
+            safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).

         Example:

         ```
@@ -2571,9 +2570,6 @@ def save_model(
         ```
         """

-        if safe_serialization and not is_safetensors_available():
-            raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
-
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
@@ -2690,7 +2686,7 @@ def register_save_state_pre_hook(self, hook: Callable[..., None]) -> hooks.Remov
         self._save_model_state_pre_hook[handle.id] = hook
         return handle

-    def save_state(self, output_dir: str = None, **save_model_func_kwargs):
+    def save_state(self, output_dir: str = None, safe_serialization: bool = True, **save_model_func_kwargs):
         """
         Saves the current states of the model, optimizer, scaler, RNG generators, and registered objects to a folder.
@@ -2711,6 +2707,8 @@ def save_state(self, output_dir: str = None, **save_model_func_kwargs):
         Args:
             output_dir (`str` or `os.PathLike`):
                 The name of the folder to save all relevant weights and states.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
             save_model_func_kwargs (`dict`, *optional*):
                 Additional keyword arguments for saving model which can be passed to the underlying save function, such
                 as optional arguments for DeepSpeed's `save_checkpoint` function.
@@ -2815,6 +2813,7 @@ def _inner(folder):
             self.state.process_index,
             self.scaler,
             save_on_each_node=self.project_configuration.save_on_each_node,
+            safe_serialization=safe_serialization,
         )
         for i, obj in enumerate(self._custom_objects):
             save_custom_state(obj, output_dir, i, save_on_each_node=self.project_configuration.save_on_each_node)
diff --git a/src/accelerate/checkpointing.py b/src/accelerate/checkpointing.py
index 7928adce33e..11d30d9fef1 100644
--- a/src/accelerate/checkpointing.py
+++ b/src/accelerate/checkpointing.py
@@ -12,22 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os
 import random
 from pathlib import Path
 from typing import List

 import numpy as np
 import torch
+from safetensors.torch import load_file
 from torch.cuda.amp import GradScaler

 from .utils import (
     MODEL_NAME,
     OPTIMIZER_NAME,
     RNG_STATE_NAME,
+    SAFE_MODEL_NAME,
+    SAFE_WEIGHTS_NAME,
     SAMPLER_NAME,
     SCALER_NAME,
     SCHEDULER_NAME,
+    WEIGHTS_NAME,
     get_pretty_name,
     is_tpu_available,
     is_xpu_available,
@@ -54,10 +57,18 @@ def save_accelerator_state(
     process_index: int,
     scaler: GradScaler = None,
     save_on_each_node: bool = False,
+    safe_serialization: bool = True,
 ):
     """
     Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory.

+    <Tip>
+
+    If `safe_serialization` is `True`, models will be saved with `safetensors` while the rest are saved using native
+    `pickle`.
+
+    </Tip>
+
     Args:
         output_dir (`str` or `os.PathLike`):
             The name of the folder to save all relevant weights and states.
@@ -75,31 +86,36 @@ def save_accelerator_state(
             An optional gradient scaler instance to save
         save_on_each_node (`bool`, *optional*):
             Whether to save on every node, or only the main node.
+        safe_serialization (`bool`, *optional*, defaults to `True`):
+            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
     """
+    output_dir = Path(output_dir)
     # Model states
     for i, state in enumerate(model_states):
-        weights_name = f"{MODEL_NAME}.bin" if i == 0 else f"{MODEL_NAME}_{i}.bin"
-        output_model_file = os.path.join(output_dir, weights_name)
-        save(state, output_model_file, save_on_each_node=save_on_each_node)
+        weights_name = WEIGHTS_NAME if not safe_serialization else SAFE_WEIGHTS_NAME
+        if i > 0:
+            weights_name = weights_name.replace(".", f"_{i}.")
+        output_model_file = output_dir.joinpath(weights_name)
+        save(state, output_model_file, save_on_each_node=save_on_each_node, safe_serialization=safe_serialization)
         logger.info(f"Model weights saved in {output_model_file}")
     # Optimizer states
     for i, opt in enumerate(optimizers):
         state = opt.state_dict()
         optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
-        output_optimizer_file = os.path.join(output_dir, optimizer_name)
-        save(state, output_optimizer_file, save_on_each_node=save_on_each_node)
+        output_optimizer_file = output_dir.joinpath(optimizer_name)
+        save(state, output_optimizer_file, save_on_each_node=save_on_each_node, safe_serialization=False)
         logger.info(f"Optimizer state saved in {output_optimizer_file}")
     # Scheduler states
     for i, scheduler in enumerate(schedulers):
         state = scheduler.state_dict()
         scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
-        output_scheduler_file = os.path.join(output_dir, scheduler_name)
-        save(state, output_scheduler_file, save_on_each_node=save_on_each_node)
+        output_scheduler_file = output_dir.joinpath(scheduler_name)
+        save(state, output_scheduler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
         logger.info(f"Scheduler state saved in {output_scheduler_file}")
     # DataLoader states
     for i, dataloader in enumerate(dataloaders):
         sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
-        output_sampler_file = os.path.join(output_dir, sampler_name)
+        output_sampler_file = output_dir.joinpath(sampler_name)
         # Only save if we have our custom sampler
         from .data_loader import IterableDatasetShard, SeedableRandomSampler
@@ -107,13 +123,13 @@ def save_accelerator_state(
             sampler = dataloader.sampler.sampler

         if isinstance(sampler, SeedableRandomSampler):
-            save(sampler, output_sampler_file, save_on_each_node=save_on_each_node)
+            save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
         logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")

     # GradScaler state
     if scaler is not None:
         state = scaler.state_dict()
-        output_scaler_file = os.path.join(output_dir, SCALER_NAME)
+        output_scaler_file = output_dir.joinpath(SCALER_NAME)
         torch.save(state, output_scaler_file)
         logger.info(f"Gradient scaler state saved in {output_scaler_file}")
     # Random number generator states
@@ -128,7 +144,7 @@ def save_accelerator_state(
     states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
     if is_tpu_available():
         states["xm_seed"] = xm.get_rng_state()
-    output_states_file = os.path.join(output_dir, states_name)
+    output_states_file = output_dir.joinpath(states_name)
     torch.save(states, output_states_file)
     logger.info(f"Random states saved in {output_states_file}")
     return output_dir
@@ -174,17 +190,25 @@ def load_accelerator_state(
         map_location = "cpu"
     elif map_location == "on_device":
         map_location = PartialState().device
+
+    input_dir = Path(input_dir)
     # Model states
     for i, model in enumerate(models):
-        weights_name = f"{MODEL_NAME}.bin" if i == 0 else f"{MODEL_NAME}_{i}.bin"
-        input_model_file = os.path.join(input_dir, weights_name)
-        models[i].load_state_dict(torch.load(input_model_file, map_location=map_location), **load_model_func_kwargs)
+        ending = f"_{i}" if i > 0 else ""
+        input_model_file = input_dir.joinpath(f"{SAFE_MODEL_NAME}{ending}.safetensors")
+        if input_model_file.exists():
+            state_dict = load_file(input_model_file, device=str(map_location))
+        else:
+            # Load with torch
+            input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
+            state_dict = torch.load(input_model_file, map_location=map_location)
+        models[i].load_state_dict(state_dict, **load_model_func_kwargs)
     logger.info("All model weights loaded successfully")

     # Optimizer states
     for i, opt in enumerate(optimizers):
         optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
-        input_optimizer_file = os.path.join(input_dir, optimizer_name)
+        input_optimizer_file = input_dir.joinpath(optimizer_name)
         optimizer_state = torch.load(input_optimizer_file, map_location=map_location)
         optimizers[i].load_state_dict(optimizer_state)
         logger.info("All optimizer states loaded successfully")
@@ -192,13 +216,13 @@ def load_accelerator_state(
     # Scheduler states
     for i, scheduler in enumerate(schedulers):
         scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
-        input_scheduler_file = os.path.join(input_dir, scheduler_name)
+        input_scheduler_file = input_dir.joinpath(scheduler_name)
         scheduler.load_state_dict(torch.load(input_scheduler_file))
         logger.info("All scheduler states loaded successfully")

     for i, dataloader in enumerate(dataloaders):
         sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
-        input_sampler_file = os.path.join(input_dir, sampler_name)
+        input_sampler_file = input_dir.joinpath(sampler_name)
         # Only load if we have our custom sampler
         from .data_loader import IterableDatasetShard, SeedableRandomSampler
@@ -211,13 +235,13 @@ def load_accelerator_state(

     # GradScaler state
     if scaler is not None:
-        input_scaler_file = os.path.join(input_dir, SCALER_NAME)
+        input_scaler_file = input_dir.joinpath(SCALER_NAME)
         scaler.load_state_dict(torch.load(input_scaler_file))
         logger.info("GradScaler state loaded successfully")

     # Random states
     try:
-        states = torch.load(os.path.join(input_dir, f"{RNG_STATE_NAME}_{process_index}.pkl"))
+        states = torch.load(input_dir.joinpath(f"{RNG_STATE_NAME}_{process_index}.pkl"))
         random.setstate(states["random_state"])
         np.random.set_state(states["numpy_random_seed"])
         torch.set_rng_state(states["torch_manual_seed"])
diff --git a/src/accelerate/test_utils/__init__.py b/src/accelerate/test_utils/__init__.py
index f716bf6d25d..b81cea46cc0 100644
--- a/src/accelerate/test_utils/__init__.py
+++ b/src/accelerate/test_utils/__init__.py
@@ -9,7 +9,6 @@
     require_mps,
     require_multi_gpu,
     require_multi_xpu,
-    require_safetensors,
     require_single_gpu,
     require_single_xpu,
     require_torch_min_version,
diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py
index bb887285954..d6d1e2f2f0a 100644
--- a/src/accelerate/test_utils/testing.py
+++ b/src/accelerate/test_utils/testing.py
@@ -37,7 +37,6 @@
     is_deepspeed_available,
     is_mps_available,
     is_pandas_available,
-    is_safetensors_available,
     is_tensorboard_available,
     is_timm_available,
     is_torch_version,
@@ -179,14 +178,6 @@ def require_multi_xpu(test_case):
     return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)


-def require_safetensors(test_case):
-    """
-    Decorator marking a test that requires safetensors installed. These tests are skipped when safetensors isn't
-    installed
-    """
-    return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case)
-
-
 def require_deepspeed(test_case):
     """
     Decorator marking a test that requires DeepSpeed installed. These tests are skipped when DeepSpeed isn't installed
diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py
index 317204fc64c..96e3fe61035 100644
--- a/src/accelerate/utils/__init__.py
+++ b/src/accelerate/utils/__init__.py
@@ -2,6 +2,7 @@
     MODEL_NAME,
     OPTIMIZER_NAME,
     RNG_STATE_NAME,
+    SAFE_MODEL_NAME,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
     SAMPLER_NAME,
@@ -59,7 +60,6 @@
     is_npu_available,
     is_pandas_available,
     is_rich_available,
-    is_safetensors_available,
     is_sagemaker_available,
     is_tensorboard_available,
     is_timm_available,
diff --git a/src/accelerate/utils/constants.py b/src/accelerate/utils/constants.py
index 638a8ea4529..843eb5756af 100644
--- a/src/accelerate/utils/constants.py
+++ b/src/accelerate/utils/constants.py
@@ -17,14 +17,15 @@

 SCALER_NAME = "scaler.pt"
 MODEL_NAME = "pytorch_model"
+SAFE_MODEL_NAME = "model"
 RNG_STATE_NAME = "random_states"
 OPTIMIZER_NAME = "optimizer"
 SCHEDULER_NAME = "scheduler"
 SAMPLER_NAME = "sampler"
-WEIGHTS_NAME = "pytorch_model.bin"
-WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
-SAFE_WEIGHTS_NAME = "model.safetensors"
-SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
+WEIGHTS_NAME = f"{MODEL_NAME}.bin"
+WEIGHTS_INDEX_NAME = f"{WEIGHTS_NAME}.index.json"
+SAFE_WEIGHTS_NAME = f"{SAFE_MODEL_NAME}.safetensors"
+SAFE_WEIGHTS_INDEX_NAME = f"{SAFE_WEIGHTS_NAME}.index.json"
 SAGEMAKER_PYTORCH_VERSION = "1.10.2"
 SAGEMAKER_PYTHON_VERSION = "py38"
 SAGEMAKER_TRANSFORMERS_VERSION = "4.17.0"
diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py
index 3ff167aecf6..9a60233c96c 100644
--- a/src/accelerate/utils/imports.py
+++ b/src/accelerate/utils/imports.py
@@ -151,10 +151,6 @@ def is_megatron_lm_available():
         return False


-def is_safetensors_available():
-    return _is_package_available("safetensors")
-
-
 def is_transformers_available():
     return _is_package_available("transformers")

diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index 2f5ac451f6f..058b4c855b1 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -30,7 +30,7 @@
 from ..state import AcceleratorState
 from .constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from .dataclasses import AutocastKwargs, CustomDtype, DistributedType
-from .imports import is_mps_available, is_npu_available, is_safetensors_available, is_xpu_available
+from .imports import is_mps_available, is_npu_available, is_xpu_available
 from .offload import load_offloaded_weight, offload_weight, save_offload_index
 from .tqdm import is_tqdm_available, tqdm

@@ -39,9 +39,9 @@
     import torch_npu  # noqa: F401


-if is_safetensors_available():
-    from safetensors import safe_open
-    from safetensors.torch import load_file as safe_load_file
+from safetensors import safe_open
+from safetensors.torch import load_file as safe_load_file
+

 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"

@@ -1156,10 +1156,6 @@ def load_state_dict(checkpoint_file, device_map=None):
         name, once a given module name is inside, every submodule of it will be sent to the same device.
     """
     if checkpoint_file.endswith(".safetensors"):
-        if not is_safetensors_available():
-            raise ImportError(
-                f"To load {checkpoint_file}, the `safetensors` library is necessary `pip install safetensors`."
-            )
         with safe_open(checkpoint_file, framework="pt") as f:
             metadata = f.metadata()
             weight_names = f.keys()
diff --git a/src/accelerate/utils/offload.py b/src/accelerate/utils/offload.py
index ca6efc080dc..e6c28c02001 100644
--- a/src/accelerate/utils/offload.py
+++ b/src/accelerate/utils/offload.py
@@ -19,8 +19,7 @@

 import numpy as np
 import torch
-
-from .imports import is_safetensors_available
+from safetensors import safe_open


 def offload_weight(weight, weight_name, offload_folder, index=None):
@@ -165,11 +164,6 @@ def __getitem__(self, key: str):
             return self.state_dict[key]
         weight_info = self.index[key]
         if weight_info.get("safetensors_file") is not None:
-            if not is_safetensors_available():
-                raise ImportError("These offloaded weights require the use of safetensors: `pip install safetensors`.")
-
-            from safetensors import safe_open
-
             device = "cpu" if self.device is None else self.device
             with safe_open(weight_info["safetensors_file"], framework="pt", device=device) as f:
                 tensor = f.get_tensor(weight_info.get("weight_name", key))
diff --git a/src/accelerate/utils/other.py b/src/accelerate/utils/other.py
index a1e2c3bb2ce..285dd0a5ad9 100644
--- a/src/accelerate/utils/other.py
+++ b/src/accelerate/utils/other.py
@@ -22,13 +22,14 @@

 import torch
 from packaging.version import Version
+from safetensors.torch import save_file as safe_save_file

 from ..commands.config.default import write_basic_config  # noqa: F401
 from ..logging import get_logger
 from ..state import PartialState
 from .constants import FSDP_PYTORCH_VERSION
 from .dataclasses import DistributedType
-from .imports import is_deepspeed_available, is_safetensors_available, is_torch_distributed_available, is_tpu_available
+from .imports import is_deepspeed_available, is_torch_distributed_available, is_tpu_available
 from .transformer_engine import convert_model
 from .versions import is_torch_version

@@ -39,9 +40,6 @@
 if is_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm

-if is_safetensors_available():
-    from safetensors.torch import save_file as safe_save_file
-

 def is_compiled_module(module):
     """
@@ -117,7 +115,7 @@ def wait_for_everyone():
     PartialState().wait_for_everyone()


-def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False):
+def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = True):
     """
     Save the data to disk. Use in place of `torch.save()`.

@@ -128,8 +126,8 @@ def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = Fal
             The file (or file-like object) to use to save the data
         save_on_each_node (`bool`, *optional*, defaults to `False`):
             Whether to only save on the global main process
-        safe_serialization (`bool`, *optional*, defaults to `False`):
-            Whether to save `obj` using `safetensors`
+        safe_serialization (`bool`, *optional*, defaults to `True`):
+            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
     """
     save_func = torch.save if not safe_serialization else partial(safe_save_file, metadata={"format": "pt"})
     if PartialState().distributed_type == DistributedType.TPU:
diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py
index 610995cc527..316a9bb41de 100644
--- a/tests/test_accelerator.py
+++ b/tests/test_accelerator.py
@@ -5,12 +5,13 @@
 from unittest.mock import patch

 import torch
+from parameterized import parameterized
 from torch.utils.data import DataLoader, TensorDataset

 from accelerate import DistributedType, infer_auto_device_map, init_empty_weights
 from accelerate.accelerator import Accelerator
 from accelerate.state import GradientState, PartialState
-from accelerate.test_utils import require_bnb, require_multi_gpu, require_safetensors, slow
+from accelerate.test_utils import require_bnb, require_multi_gpu, slow
 from accelerate.test_utils.testing import AccelerateTestCase, require_cuda
 from accelerate.utils import patch_environment
 from accelerate.utils.modeling import load_checkpoint_in_model
@@ -35,6 +36,13 @@ def load_random_weights(model):
     model.load_state_dict(state)


+def parameterized_custom_name_func(func, param_num, param):
+    # customize the test name generator function as we want both params to appear in the sub-test
+    # name, as by default it shows only the first param
+    param_based_name = "use_safetensors" if param.args[0] is True else "use_pytorch"
+    return f"{func.__name__}_{param_based_name}"
+
+
 class AcceleratorTester(AccelerateTestCase):
     @require_cuda
     def test_accelerator_can_be_reinstantiated(self):
@@ -97,7 +105,8 @@ def noop(*args, **kwargs):
             accelerator = Accelerator()
             self.assertEqual(str(accelerator.state.device), "cuda:64")

-    def test_save_load_model(self):
+    @parameterized.expand([True, False], name_func=parameterized_custom_name_func)
+    def test_save_load_model(self, use_safetensors):
         accelerator = Accelerator()
         model, optimizer, scheduler, train_dl, valid_dl = create_components()
         accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl)
@@ -105,7 +114,7 @@ def test_save_load_model(self):

         model_signature = get_signature(model)
         with tempfile.TemporaryDirectory() as tmpdirname:
-            accelerator.save_state(tmpdirname)
+            accelerator.save_state(tmpdirname, safe_serialization=use_safetensors)

             # make sure random weights don't match
             load_random_weights(model)
@@ -115,31 +124,20 @@ def test_save_load_model(self):
             accelerator.load_state(tmpdirname)
             self.assertTrue(abs(model_signature - get_signature(model)) < 1e-3)

-    def test_save_model_pytorch(self):
-        accelerator = Accelerator()
-        model = torch.nn.Linear(10, 10)
-
-        model_signature = get_signature(model)
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            accelerator.save_model(model, tmpdirname, safe_serialization=False)
-            # make sure loaded weights match
-            load_checkpoint_in_model(model, tmpdirname)
-            self.assertTrue(abs(model_signature - get_signature(model)) < 1e-3)
-
-    @require_safetensors
-    def test_save_model_safetensors(self):
+    @parameterized.expand([True, False], name_func=parameterized_custom_name_func)
+    def test_save_model(self, use_safetensors):
         accelerator = Accelerator()
         model = torch.nn.Linear(10, 10)

         model_signature = get_signature(model)
         with tempfile.TemporaryDirectory() as tmpdirname:
-            accelerator.save_model(model, tmpdirname, safe_serialization=True)
-
+            accelerator.save_model(model, tmpdirname, safe_serialization=use_safetensors)
             # make sure loaded weights match
             load_checkpoint_in_model(model, tmpdirname)
             self.assertTrue(abs(model_signature - get_signature(model)) < 1e-3)

-    def test_save_load_model_with_hooks(self):
+    @parameterized.expand([True, False], name_func=parameterized_custom_name_func)
+    def test_save_load_model_with_hooks(self, use_safetensors):
         accelerator = Accelerator()
         model, optimizer, scheduler, train_dl, valid_dl = create_components()
         accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl)
@@ -164,7 +162,7 @@ def load_config(models, input_dir):
         load_hook = accelerator.register_load_state_pre_hook(load_config)

         with tempfile.TemporaryDirectory() as tmpdirname:
-            accelerator.save_state(tmpdirname)
+            accelerator.save_state(tmpdirname, safe_serialization=use_safetensors)

             # make sure random weights don't match with hooks
             load_random_weights(model)
@@ -185,7 +183,7 @@ def load_config(models, input_dir):
         load_hook.remove()

         with tempfile.TemporaryDirectory() as tmpdirname:
-            accelerator.save_state(tmpdirname)
+            accelerator.save_state(tmpdirname, safe_serialization=use_safetensors)

             # make sure random weights don't match with hooks removed
             load_random_weights(model)
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 2b750c2050c..7f7bf4c613a 100644
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -20,9 +20,10 @@

 import torch
 import torch.nn as nn
+from safetensors.torch import save_file

 from accelerate import init_empty_weights
-from accelerate.test_utils import require_cuda, require_huggingface_suite, require_multi_gpu, require_safetensors
+from accelerate.test_utils import require_cuda, require_huggingface_suite, require_multi_gpu
 from accelerate.utils.modeling import (
     check_device_map,
     clean_device_map,
@@ -552,10 +553,7 @@ def test_get_balanced_memory(self):
         self.assertDictEqual({0: 0, "cpu": 100}, max_memory)

     @require_cuda
-    @require_safetensors
     def test_load_state_dict(self):
-        from safetensors.torch import save_file
-
         state_dict = {k: torch.randn(4, 5) for k in ["a", "b", "c"]}

         device_maps = [{"a": "cpu", "b": 0, "c": "disk"}, {"a": 0, "b": 0, "c": "disk"}, {"a": 0, "b": 0, "c": 0}]
diff --git a/tests/test_state_checkpointing.py b/tests/test_state_checkpointing.py
index 87bfff150d3..052831d8af3 100644
--- a/tests/test_state_checkpointing.py
+++ b/tests/test_state_checkpointing.py
@@ -24,6 +24,7 @@

 import pytest
 import torch
+from parameterized import parameterized_class
 from torch import nn
 from torch.utils.data import DataLoader, TensorDataset

@@ -80,6 +81,14 @@ def forward(self, x):
         return x * self.a + self.b


+def parameterized_custom_name_func(func, param_num, param):
+    # customize the test name generator function as we want both params to appear in the sub-test
+    # name, as by default it shows only the first param
+    param_based_name = "use_safetensors" if param["use_safetensors"] is True else "use_pytorch"
+    return f"{func.__name__}_{param_based_name}"
f"{func.__name__}_{param_based_name}" + + +@parameterized_class(("use_safetensors",), [[True], [False]], class_name_func=parameterized_custom_name_func) class CheckpointTest(unittest.TestCase): def test_with_save_limit(self): with tempfile.TemporaryDirectory() as tmpdir: @@ -94,10 +103,10 @@ def test_with_save_limit(self): model, optimizer, train_dataloader, valid_dataloader ) # Save initial - accelerator.save_state() + accelerator.save_state(safe_serialization=self.use_safetensors) # Save second state - accelerator.save_state() + accelerator.save_state(safe_serialization=self.use_safetensors) self.assertEqual(len(os.listdir(accelerator.project_dir)), 1) def test_can_resume_training_with_folder(self): @@ -113,7 +122,7 @@ def test_can_resume_training_with_folder(self): ) # Save initial initial = os.path.join(tmpdir, "initial") - accelerator.save_state(initial) + accelerator.save_state(initial, safe_serialization=self.use_safetensors) (a, b) = model.a.item(), model.b.item() opt_state = optimizer.state_dict() ground_truth_rands = train(3, model, train_dataloader, optimizer, accelerator) @@ -139,7 +148,7 @@ def test_can_resume_training_with_folder(self): test_rands = train(2, model, train_dataloader, optimizer, accelerator) # Save everything checkpoint = os.path.join(tmpdir, "checkpoint") - accelerator.save_state(checkpoint) + accelerator.save_state(checkpoint, safe_serialization=self.use_safetensors) # Load everything back in and make sure all states work accelerator.load_state(checkpoint) @@ -165,7 +174,7 @@ def test_can_resume_training(self): model, optimizer, train_dataloader, valid_dataloader ) # Save initial - accelerator.save_state() + accelerator.save_state(safe_serialization=self.use_safetensors) (a, b) = model.a.item(), model.b.item() opt_state = optimizer.state_dict() ground_truth_rands = train(3, model, train_dataloader, optimizer, accelerator) @@ -191,7 +200,7 @@ def test_can_resume_training(self): test_rands = train(2, model, train_dataloader, optimizer, accelerator) # Save everything - accelerator.save_state() + accelerator.save_state(safe_serialization=self.use_safetensors) # Load everything back in and make sure all states work accelerator.load_state(os.path.join(tmpdir, "checkpoints", "checkpoint_1")) @@ -230,7 +239,7 @@ def temporary_relative_directory(): model, optimizer, train_dataloader, valid_dataloader ) # Save initial - accelerator.save_state() + accelerator.save_state(safe_serialization=self.use_safetensors) (a, b) = model.a.item(), model.b.item() opt_state = optimizer.state_dict() ground_truth_rands = train(3, model, train_dataloader, optimizer, accelerator) @@ -256,7 +265,7 @@ def temporary_relative_directory(): test_rands = train(2, model, train_dataloader, optimizer, accelerator) # Save everything - accelerator.save_state() + accelerator.save_state(safe_serialization=self.use_safetensors) # Load everything back in and make sure all states work accelerator.load_state(os.path.join(tmpdir, "checkpoints", "checkpoint_1")) @@ -296,7 +305,7 @@ def test_with_scheduler(self): model, optimizer, train_dataloader, valid_dataloader, scheduler ) # Save initial - accelerator.save_state() + accelerator.save_state(safe_serialization=self.use_safetensors) scheduler_state = scheduler.state_dict() train(3, model, train_dataloader, optimizer, accelerator, scheduler) self.assertNotEqual(scheduler_state, scheduler.state_dict()) @@ -319,11 +328,11 @@ def test_automatic_loading(self): model, optimizer, train_dataloader, valid_dataloader, scheduler ) # Save initial - 
accelerator.save_state() + accelerator.save_state(safe_serialization=self.use_safetensors) train(2, model, train_dataloader, optimizer, accelerator, scheduler) (a2, b2) = model.a.item(), model.b.item() # Save a first time - accelerator.save_state() + accelerator.save_state(safe_serialization=self.use_safetensors) train(1, model, train_dataloader, optimizer, accelerator, scheduler) (a3, b3) = model.a.item(), model.b.item() @@ -344,7 +353,7 @@ def test_checkpoint_deletion(self): model = accelerator.prepare(model) # Save 3 states: for _ in range(11): - accelerator.save_state() + accelerator.save_state(safe_serialization=self.use_safetensors) self.assertTrue(not os.path.exists(os.path.join(tmpdir, "checkpoints", "checkpoint_0"))) self.assertTrue(os.path.exists(os.path.join(tmpdir, "checkpoints", "checkpoint_9"))) self.assertTrue(os.path.exists(os.path.join(tmpdir, "checkpoints", "checkpoint_10"))) @@ -352,10 +361,14 @@ def test_checkpoint_deletion(self): @require_cuda def test_map_location(self): cmd = ["torchrun", f"--nproc_per_node={torch.cuda.device_count()}", inspect.getfile(self.__class__)] - execute_subprocess_async(cmd, env=os.environ.copy()) + env = os.environ.copy() + env["USE_SAFETENSORS"] = str(self.use_safetensors) + env["OMP_NUM_THREADS"] = "1" + execute_subprocess_async(cmd, env=env) if __name__ == "__main__": + use_safetensors = os.environ.get("USE_SAFETENSORS", "False") == "True" savedir = "/tmp/accelerate/state_checkpointing" model = DummyModel() optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3) @@ -380,7 +393,7 @@ def test_map_location(self): assert param_device.type == accelerator.device.type model = model.cpu() accelerator.wait_for_everyone() - accelerator.save_state() + accelerator.save_state(safe_serialization=use_safetensors) accelerator.wait_for_everyone() # Check CPU state
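
Usage note (not part of the patch): the sketch below illustrates the user-facing effect of this change, assuming this revision of `accelerate` is installed. `Accelerator.save_state` and `Accelerator.save_model` now default to `safe_serialization=True`, so model weights are written as `model.safetensors` (per `SAFE_WEIGHTS_NAME`) while optimizer, scheduler, sampler, and RNG states keep their pickle-based files; passing `safe_serialization=False` falls back to `pytorch_model.bin`. The directory listings in the comments are expectations derived from the constants above, not captured output.

```python
# Minimal sketch of the new default checkpoint layout; assumes this revision of accelerate.
import os
import tempfile

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = accelerator.prepare(model, optimizer)

with tempfile.TemporaryDirectory() as ckpt_dir:
    # Default is now safe_serialization=True, so weights are written as model.safetensors
    accelerator.save_state(ckpt_dir)
    print(sorted(os.listdir(ckpt_dir)))  # expected: 'model.safetensors', 'optimizer.bin', 'random_states_0.pkl'
    accelerator.load_state(ckpt_dir)  # load_accelerator_state prefers the safetensors file when it exists

with tempfile.TemporaryDirectory() as ckpt_dir:
    # Opt back into pickle-based serialization for the model weights
    accelerator.save_state(ckpt_dir, safe_serialization=False)
    print(sorted(os.listdir(ckpt_dir)))  # expected: 'pytorch_model.bin' instead of 'model.safetensors'
```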