Make safetensors the default (#2120)
* Make safetensors default

* Rm location

* Actually flip flags

* Tests + update checkpointing

* Add to setup

* Start of tests with both safetensors and without

* Update tests to use both

* Remove from load state

* Explicit tip

* With suggestions

* Simplify, don't abstract. Need to bring back to deepspeed however

* Refactor to use consts

* Keep how it was

* Typo fix
muellerzr authored Nov 8, 2023
1 parent 76de60d commit e638b1e
Showing 14 changed files with 115 additions and 108 deletions.
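In practice, the commit flips the serialization default across the public saving APIs. A minimal sketch of the new behavior (a hypothetical single-process script; the directory name is illustrative):

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 2)  # stand-in for any trained module

# safe_serialization now defaults to True, so this writes
# checkpoint/model.safetensors instead of checkpoint/pytorch_model.bin.
accelerator.save_model(model, "checkpoint")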
2 changes: 1 addition & 1 deletion setup.py
@@ -54,7 +54,7 @@
         ]
     },
     python_requires=">=3.8.0",
-    install_requires=["numpy>=1.17", "packaging>=20.0", "psutil", "pyyaml", "torch>=1.10.0", "huggingface_hub"],
+    install_requires=["numpy>=1.17", "packaging>=20.0", "psutil", "pyyaml", "torch>=1.10.0", "huggingface_hub", "safetensors>=0.3.1"],
     extras_require=extras,
     classifiers=[
         "Development Status :: 5 - Production/Stable",
13 changes: 6 additions & 7 deletions src/accelerate/accelerator.py
@@ -80,7 +80,6 @@
     is_ipex_available,
     is_megatron_lm_available,
     is_npu_available,
-    is_safetensors_available,
     is_torch_version,
     is_tpu_available,
     is_xpu_available,
@@ -2536,7 +2535,7 @@ def save_model(
         model: torch.nn.Module,
         save_directory: Union[str, os.PathLike],
         max_shard_size: Union[int, str] = "10GB",
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
     ):
         """
         Save a model so that it can be re-loaded using load_checkpoint_in_model
@@ -2557,7 +2556,7 @@
                 </Tip>

-            safe_serialization (`bool`, *optional*, defaults to `False`):
+            safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).

        Example:
@@ -2571,9 +2570,6 @@
        ```
        """

-        if safe_serialization and not is_safetensors_available():
-            raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
-
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return
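Since `safetensors` is now a hard dependency (see the setup.py change above), the import guard can simply be dropped; callers who still need pickle output opt out explicitly. A sketch under the same assumptions as above:

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 2)

# Passing safe_serialization=False restores the legacy pickle-based output.
accelerator.save_model(model, "ckpt_legacy", safe_serialization=False)  # -> pytorch_model.bin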
@@ -2690,7 +2686,7 @@ def register_save_state_pre_hook(self, hook: Callable[..., None]) -> hooks.Remov
         self._save_model_state_pre_hook[handle.id] = hook
         return handle

-    def save_state(self, output_dir: str = None, **save_model_func_kwargs):
+    def save_state(self, output_dir: str = None, safe_serialization: bool = True, **save_model_func_kwargs):
         """
         Saves the current states of the model, optimizer, scaler, RNG generators, and registered objects to a folder.
@@ -2711,6 +2707,8 @@ def save_state(self, output_dir: str = None, **save_model_func_kwargs):
         Args:
             output_dir (`str` or `os.PathLike`):
                 The name of the folder to save all relevant weights and states.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
             save_model_func_kwargs (`dict`, *optional*):
                 Additional keyword arguments for saving model which can be passed to the underlying save function, such
                 as optional arguments for DeepSpeed's `save_checkpoint` function.
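The same flag now threads through `save_state` down to `save_accelerator_state`; a sketch, assuming the model and optimizer were registered via `accelerator.prepare`:

from accelerate import Accelerator

accelerator = Accelerator()
# model, optimizer = accelerator.prepare(model, optimizer)

accelerator.save_state("checkpoints/step_100")  # model weights -> model.safetensors
accelerator.save_state("checkpoints/legacy", safe_serialization=False)  # -> pytorch_model.bin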
@@ -2815,6 +2813,7 @@ def _inner(folder):
             self.state.process_index,
             self.scaler,
             save_on_each_node=self.project_configuration.save_on_each_node,
+            safe_serialization=safe_serialization,
         )
         for i, obj in enumerate(self._custom_objects):
             save_custom_state(obj, output_dir, i, save_on_each_node=self.project_configuration.save_on_each_node)
64 changes: 44 additions & 20 deletions src/accelerate/checkpointing.py
@@ -12,22 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os
 import random
+from pathlib import Path
 from typing import List

 import numpy as np
 import torch
+from safetensors.torch import load_file
 from torch.cuda.amp import GradScaler

 from .utils import (
     MODEL_NAME,
     OPTIMIZER_NAME,
     RNG_STATE_NAME,
+    SAFE_MODEL_NAME,
+    SAFE_WEIGHTS_NAME,
     SAMPLER_NAME,
     SCALER_NAME,
     SCHEDULER_NAME,
+    WEIGHTS_NAME,
     get_pretty_name,
     is_tpu_available,
     is_xpu_available,
@@ -54,10 +57,18 @@ def save_accelerator_state(
     process_index: int,
     scaler: GradScaler = None,
     save_on_each_node: bool = False,
+    safe_serialization: bool = True,
 ):
     """
     Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory.

+    <Tip>
+
+    If `safe_serialization` is `True`, models will be saved with `safetensors` while the rest are saved using native
+    `pickle`.
+
+    </Tip>
+
     Args:
         output_dir (`str` or `os.PathLike`):
             The name of the folder to save all relevant weights and states.
@@ -75,45 +86,50 @@ def save_accelerator_state(
             An optional gradient scaler instance to save
         save_on_each_node (`bool`, *optional*):
             Whether to save on every node, or only the main node.
+        safe_serialization (`bool`, *optional*, defaults to `True`):
+            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
     """
+    output_dir = Path(output_dir)
     # Model states
     for i, state in enumerate(model_states):
-        weights_name = f"{MODEL_NAME}.bin" if i == 0 else f"{MODEL_NAME}_{i}.bin"
-        output_model_file = os.path.join(output_dir, weights_name)
-        save(state, output_model_file, save_on_each_node=save_on_each_node)
+        weights_name = WEIGHTS_NAME if not safe_serialization else SAFE_WEIGHTS_NAME
+        if i > 0:
+            weights_name = weights_name.replace(".", f"_{i}.")
+        output_model_file = output_dir.joinpath(weights_name)
+        save(state, output_model_file, save_on_each_node=save_on_each_node, safe_serialization=safe_serialization)
         logger.info(f"Model weights saved in {output_model_file}")
     # Optimizer states
     for i, opt in enumerate(optimizers):
         state = opt.state_dict()
         optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
-        output_optimizer_file = os.path.join(output_dir, optimizer_name)
-        save(state, output_optimizer_file, save_on_each_node=save_on_each_node)
+        output_optimizer_file = output_dir.joinpath(optimizer_name)
+        save(state, output_optimizer_file, save_on_each_node=save_on_each_node, safe_serialization=False)
         logger.info(f"Optimizer state saved in {output_optimizer_file}")
     # Scheduler states
     for i, scheduler in enumerate(schedulers):
         state = scheduler.state_dict()
         scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
-        output_scheduler_file = os.path.join(output_dir, scheduler_name)
-        save(state, output_scheduler_file, save_on_each_node=save_on_each_node)
+        output_scheduler_file = output_dir.joinpath(scheduler_name)
+        save(state, output_scheduler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
         logger.info(f"Scheduler state saved in {output_scheduler_file}")
     # DataLoader states
     for i, dataloader in enumerate(dataloaders):
         sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
-        output_sampler_file = os.path.join(output_dir, sampler_name)
+        output_sampler_file = output_dir.joinpath(sampler_name)
         # Only save if we have our custom sampler
         from .data_loader import IterableDatasetShard, SeedableRandomSampler

         if isinstance(dataloader.dataset, IterableDatasetShard):
             sampler = dataloader.sampler.sampler

             if isinstance(sampler, SeedableRandomSampler):
-                save(sampler, output_sampler_file, save_on_each_node=save_on_each_node)
+                save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
                 logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")

     # GradScaler state
     if scaler is not None:
         state = scaler.state_dict()
-        output_scaler_file = os.path.join(output_dir, SCALER_NAME)
+        output_scaler_file = output_dir.joinpath(SCALER_NAME)
         torch.save(state, output_scaler_file)
         logger.info(f"Gradient scaler state saved in {output_scaler_file}")
     # Random number generator states
@@ -128,7 +144,7 @@ def save_accelerator_state(
         states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
     if is_tpu_available():
         states["xm_seed"] = xm.get_rng_state()
-    output_states_file = os.path.join(output_dir, states_name)
+    output_states_file = output_dir.joinpath(states_name)
     torch.save(states, output_states_file)
     logger.info(f"Random states saved in {output_states_file}")
     return output_dir
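Note that only the model states change format; optimizer, scheduler, sampler, scaler, and RNG states are still pickled regardless of the flag. A sketch of the expected folder layout under the new defaults (single process, one model; a second model's name would follow the `replace(".", f"_{i}.")` logic above):

from accelerate import Accelerator

accelerator = Accelerator()
# ... after accelerator.prepare(model, optimizer, scheduler, dataloader):
accelerator.save_state("checkpoint_dir")

# Expected contents of checkpoint_dir/:
#   model.safetensors      <- model weights (safetensors)
#   optimizer.bin          <- always pickled (safe_serialization=False)
#   scheduler.bin          <- always pickled
#   random_states_0.pkl    <- RNG states for process 0
# A second model would be written to model_1.safetensors.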
@@ -174,31 +190,39 @@ def load_accelerator_state(
         map_location = "cpu"
     elif map_location == "on_device":
         map_location = PartialState().device
+
+    input_dir = Path(input_dir)
     # Model states
     for i, model in enumerate(models):
-        weights_name = f"{MODEL_NAME}.bin" if i == 0 else f"{MODEL_NAME}_{i}.bin"
-        input_model_file = os.path.join(input_dir, weights_name)
-        models[i].load_state_dict(torch.load(input_model_file, map_location=map_location), **load_model_func_kwargs)
+        ending = f"_{i}" if i > 0 else ""
+        input_model_file = input_dir.joinpath(f"{SAFE_MODEL_NAME}{ending}.safetensors")
+        if input_model_file.exists():
+            state_dict = load_file(input_model_file, device=str(map_location))
+        else:
+            # Load with torch
+            input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
+            state_dict = torch.load(input_model_file, map_location=map_location)
+        models[i].load_state_dict(state_dict, **load_model_func_kwargs)
     logger.info("All model weights loaded successfully")

     # Optimizer states
     for i, opt in enumerate(optimizers):
         optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
-        input_optimizer_file = os.path.join(input_dir, optimizer_name)
+        input_optimizer_file = input_dir.joinpath(optimizer_name)
         optimizer_state = torch.load(input_optimizer_file, map_location=map_location)
         optimizers[i].load_state_dict(optimizer_state)
     logger.info("All optimizer states loaded successfully")

     # Scheduler states
     for i, scheduler in enumerate(schedulers):
         scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
-        input_scheduler_file = os.path.join(input_dir, scheduler_name)
+        input_scheduler_file = input_dir.joinpath(scheduler_name)
         scheduler.load_state_dict(torch.load(input_scheduler_file))
     logger.info("All scheduler states loaded successfully")

     for i, dataloader in enumerate(dataloaders):
         sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
-        input_sampler_file = os.path.join(input_dir, sampler_name)
+        input_sampler_file = input_dir.joinpath(sampler_name)
         # Only load if we have our custom sampler
         from .data_loader import IterableDatasetShard, SeedableRandomSampler

@@ -211,13 +235,13 @@ def load_accelerator_state(

     # GradScaler state
     if scaler is not None:
-        input_scaler_file = os.path.join(input_dir, SCALER_NAME)
+        input_scaler_file = input_dir.joinpath(SCALER_NAME)
         scaler.load_state_dict(torch.load(input_scaler_file))
         logger.info("GradScaler state loaded successfully")

     # Random states
     try:
-        states = torch.load(os.path.join(input_dir, f"{RNG_STATE_NAME}_{process_index}.pkl"))
+        states = torch.load(input_dir.joinpath(f"{RNG_STATE_NAME}_{process_index}.pkl"))
         random.setstate(states["random_state"])
         np.random.set_state(states["numpy_random_seed"])
         torch.set_rng_state(states["torch_manual_seed"])
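Because the loader probes for the safetensors file first and falls back to the pickled weights, checkpoints written before this change keep loading. A simplified, hypothetical restatement of that fallback (not part of accelerate's API):

from pathlib import Path

import torch
from safetensors.torch import load_file


def load_model_state(input_dir: str, map_location: str = "cpu") -> dict:
    """Prefer model.safetensors; fall back to the legacy pytorch_model.bin."""
    input_dir = Path(input_dir)
    safe_file = input_dir.joinpath("model.safetensors")
    if safe_file.exists():
        return load_file(safe_file, device=map_location)
    return torch.load(input_dir.joinpath("pytorch_model.bin"), map_location=map_location)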
1 change: 0 additions & 1 deletion src/accelerate/test_utils/__init__.py
@@ -9,7 +9,6 @@
     require_mps,
     require_multi_gpu,
     require_multi_xpu,
-    require_safetensors,
     require_single_gpu,
     require_single_xpu,
     require_torch_min_version,
9 changes: 0 additions & 9 deletions src/accelerate/test_utils/testing.py
@@ -37,7 +37,6 @@
     is_deepspeed_available,
     is_mps_available,
     is_pandas_available,
-    is_safetensors_available,
     is_tensorboard_available,
     is_timm_available,
     is_torch_version,
@@ -179,14 +178,6 @@ def require_multi_xpu(test_case):
     return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)


-def require_safetensors(test_case):
-    """
-    Decorator marking a test that requires safetensors installed. These tests are skipped when safetensors isn't
-    installed
-    """
-    return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case)
-
-
 def require_deepspeed(test_case):
     """
     Decorator marking a test that requires DeepSpeed installed. These tests are skipped when DeepSpeed isn't installed
2 changes: 1 addition & 1 deletion src/accelerate/utils/__init__.py
@@ -2,6 +2,7 @@
     MODEL_NAME,
     OPTIMIZER_NAME,
     RNG_STATE_NAME,
+    SAFE_MODEL_NAME,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
     SAMPLER_NAME,
@@ -59,7 +60,6 @@
     is_npu_available,
     is_pandas_available,
     is_rich_available,
-    is_safetensors_available,
     is_sagemaker_available,
     is_tensorboard_available,
     is_timm_available,
9 changes: 5 additions & 4 deletions src/accelerate/utils/constants.py
@@ -17,14 +17,15 @@

 SCALER_NAME = "scaler.pt"
 MODEL_NAME = "pytorch_model"
+SAFE_MODEL_NAME = "model"
 RNG_STATE_NAME = "random_states"
 OPTIMIZER_NAME = "optimizer"
 SCHEDULER_NAME = "scheduler"
 SAMPLER_NAME = "sampler"
-WEIGHTS_NAME = "pytorch_model.bin"
-WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
-SAFE_WEIGHTS_NAME = "model.safetensors"
-SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
+WEIGHTS_NAME = f"{MODEL_NAME}.bin"
+WEIGHTS_INDEX_NAME = f"{WEIGHTS_NAME}.index.json"
+SAFE_WEIGHTS_NAME = f"{SAFE_MODEL_NAME}.safetensors"
+SAFE_WEIGHTS_INDEX_NAME = f"{SAFE_WEIGHTS_NAME}.index.json"
 SAGEMAKER_PYTORCH_VERSION = "1.10.2"
 SAGEMAKER_PYTHON_VERSION = "py38"
 SAGEMAKER_TRANSFORMERS_VERSION = "4.17.0"
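The refactor is behavior-preserving: the f-strings evaluate to the same literals the file previously hard-coded, which is easy to sanity-check:

from accelerate.utils.constants import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
)

assert WEIGHTS_NAME == "pytorch_model.bin"
assert WEIGHTS_INDEX_NAME == "pytorch_model.bin.index.json"
assert SAFE_WEIGHTS_NAME == "model.safetensors"
assert SAFE_WEIGHTS_INDEX_NAME == "model.safetensors.index.json"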
4 changes: 0 additions & 4 deletions src/accelerate/utils/imports.py
@@ -151,10 +151,6 @@ def is_megatron_lm_available():
     return False


-def is_safetensors_available():
-    return _is_package_available("safetensors")
-
-
 def is_transformers_available():
     return _is_package_available("transformers")

12 changes: 4 additions & 8 deletions src/accelerate/utils/modeling.py
@@ -30,7 +30,7 @@
 from ..state import AcceleratorState
 from .constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from .dataclasses import AutocastKwargs, CustomDtype, DistributedType
-from .imports import is_mps_available, is_npu_available, is_safetensors_available, is_xpu_available
+from .imports import is_mps_available, is_npu_available, is_xpu_available
 from .offload import load_offloaded_weight, offload_weight, save_offload_index
 from .tqdm import is_tqdm_available, tqdm

@@ -39,9 +39,9 @@
     import torch_npu  # noqa: F401


-if is_safetensors_available():
-    from safetensors import safe_open
-    from safetensors.torch import load_file as safe_load_file
+from safetensors import safe_open
+from safetensors.torch import load_file as safe_load_file


 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"

@@ -1156,10 +1156,6 @@ def load_state_dict(checkpoint_file, device_map=None):
         name, once a given module name is inside, every submodule of it will be sent to the same device.
     """
     if checkpoint_file.endswith(".safetensors"):
-        if not is_safetensors_available():
-            raise ImportError(
-                f"To load {checkpoint_file}, the `safetensors` library is necessary `pip install safetensors`."
-            )
         with safe_open(checkpoint_file, framework="pt") as f:
             metadata = f.metadata()
             weight_names = f.keys()
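With the guard gone, `safe_open` is used unconditionally. A sketch of inspecting a checkpoint the way `load_state_dict` does (the filename is illustrative and assumed to exist):

from safetensors import safe_open

with safe_open("model.safetensors", framework="pt") as f:
    metadata = f.metadata()        # e.g. {"format": "pt"}, or None
    weight_names = list(f.keys())  # tensor names, without loading data
    first = f.get_tensor(weight_names[0])  # load a single tensor lazily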
8 changes: 1 addition & 7 deletions src/accelerate/utils/offload.py
@@ -19,8 +19,7 @@

 import numpy as np
 import torch
-
-from .imports import is_safetensors_available
+from safetensors import safe_open


 def offload_weight(weight, weight_name, offload_folder, index=None):
@@ -165,11 +164,6 @@ def __getitem__(self, key: str):
             return self.state_dict[key]
         weight_info = self.index[key]
         if weight_info.get("safetensors_file") is not None:
-            if not is_safetensors_available():
-                raise ImportError("These offloaded weights require the use of safetensors: `pip install safetensors`.")
-
-            from safetensors import safe_open
-
             device = "cpu" if self.device is None else self.device
             with safe_open(weight_info["safetensors_file"], framework="pt", device=device) as f:
                 tensor = f.get_tensor(weight_info.get("weight_name", key))
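The offload path relies on the same unconditional import to read one tensor at a time onto the target device; a sketch (file path and tensor name are hypothetical):

from safetensors import safe_open

with safe_open("offload/model.safetensors", framework="pt", device="cpu") as f:
    tensor = f.get_tensor("layer.0.weight")  # reads only this tensor from disk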