Deal with shared memory scenarios (#2136)
* Deal with duplicates

* refactor

* Keep false for save

* Clean

* Better test for logs
muellerzr authored Nov 10, 2023
1 parent 8256a9c commit fc0a43c
Showing 4 changed files with 70 additions and 32 deletions.
33 changes: 2 additions & 31 deletions src/accelerate/accelerator.py
@@ -14,7 +14,6 @@

 from __future__ import annotations

-import collections
 import contextlib
 import functools
 import json
@@ -64,6 +63,7 @@
     RNGType,
     TorchDynamoPlugin,
     check_os_kernel,
+    clean_state_dict_for_safetensors,
     compare_versions,
     convert_model,
     convert_outputs_to_fp32,
@@ -73,7 +73,6 @@
     get_mixed_precision_context_manager,
     get_pretty_name,
     has_transformer_engine_layers,
-    id_tensor_storage,
     is_bf16_available,
     is_deepspeed_available,
     is_fp8_available,
@@ -2583,35 +2582,7 @@ def save_model(
         state_dict = self.get_state_dict(model)

         if safe_serialization:
-            # Safetensors does not allow tensor aliasing.
-            # We're going to remove aliases before saving
-            ptrs = collections.defaultdict(list)
-            # when bnb serialization is used the weights in the state dict can be strings
-            for name, tensor in state_dict.items():
-                if not isinstance(tensor, str):
-                    ptrs[id_tensor_storage(tensor)].append(name)
-
-            # These are all the pointers of shared tensors.
-            shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
-            warn_names = set()
-            for names in shared_ptrs.values():
-                # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
-                # If the link between tensors was done at runtime then `from_pretrained` will not get
-                # the key back leading to random tensor. A proper warning will be shown
-                # during reload (if applicable), but since the file is not necessarily compatible with
-                # the config, better show a proper warning.
-                found = 0
-                for name in names:
-                    if name in state_dict:
-                        found += 1
-                        if found > 1:
-                            del state_dict[name]
-                            warn_names.add(name)
-            if len(warn_names) > 0:
-                logger.warning(
-                    f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
-                )
-
+            state_dict = clean_state_dict_for_safetensors(state_dict)
         weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME

         # Shard the model if it is too big.
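
For context on the block removed above: safetensors forbids tensor aliasing, so any state dict in which two keys are backed by the same storage (tied weights, for example) must be deduplicated before saving. A minimal illustration of the failure mode — not part of this commit, with made-up key names, assuming a recent safetensors release:

    import torch
    from safetensors.torch import save_file

    weight = torch.randn(4, 4)
    # Two keys point at the same underlying storage, as with tied embeddings
    state_dict = {"encoder.weight": weight, "decoder.weight": weight}
    # Raises an error complaining that some tensors share memory
    save_file(state_dict, "model.safetensors")

The commit moves the dedup logic out of `Accelerator.save_model` into a reusable helper, shown in `src/accelerate/utils/other.py` below.
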
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -168,6 +168,7 @@
 from .memory import find_executable_batch_size, release_memory
 from .other import (
     check_os_kernel,
+    clean_state_dict_for_safetensors,
     clear_environment,
     convert_bytes,
     extract_model_from_parallel,
47 changes: 46 additions & 1 deletion src/accelerate/utils/other.py
@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import collections
 import os
 import platform
 import re
 import socket
 from contextlib import contextmanager
 from functools import partial
 from types import MethodType
+from typing import OrderedDict

 import torch
 from packaging.version import Version
@@ -30,6 +32,7 @@
 from .constants import FSDP_PYTORCH_VERSION
 from .dataclasses import DistributedType
 from .imports import is_deepspeed_available, is_torch_distributed_available, is_tpu_available
+from .modeling import id_tensor_storage
 from .transformer_engine import convert_model
 from .versions import is_torch_version

@@ -115,6 +118,41 @@ def wait_for_everyone():
     PartialState().wait_for_everyone()


+def clean_state_dict_for_safetensors(state_dict: dict):
+    """
+    Cleans the state dictionary from a model and removes tensor aliasing if present.
+
+    Args:
+        state_dict (`dict`):
+            The state dictionary from a model
+    """
+    ptrs = collections.defaultdict(list)
+    # When bnb serialization is used, weights in the state dict can be strings
+    for name, tensor in state_dict.items():
+        if not isinstance(tensor, str):
+            ptrs[id_tensor_storage(tensor)].append(name)
+
+    # These are all pointers of tensors with shared memory
+    shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+    warn_names = set()
+    for names in shared_ptrs.values():
+        # When not all duplicates have been cleaned, we still remove those keys but put a clear warning.
+        # If the link between tensors was done at runtime then `from_pretrained` will not get
+        # the key back, leading to a random tensor. A proper warning will be shown
+        # during reload (if applicable), but since the file is not necessarily compatible with
+        # the config, better show a proper warning.
+        found_names = [name for name in names if name in state_dict]
+        warn_names.update(found_names[1:])
+        for name in found_names[1:]:
+            del state_dict[name]
+    if len(warn_names) > 0:
+        logger.warning(
+            f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
+        )
+    state_dict = {k: v.contiguous() for k, v in state_dict.items()}
+    return state_dict
+
+
 def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False):
     """
     Save the data to disk. Use in place of `torch.save()`.
@@ -129,7 +167,14 @@ def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False):
         safe_serialization (`bool`, *optional*, defaults to `False`):
             Whether to save `obj` using `safetensors` or the traditional PyTorch way (that uses `pickle`).
     """
-    save_func = torch.save if not safe_serialization else partial(safe_save_file, metadata={"format": "pt"})
+    # Check if it's a model and remove duplicates
+    if safe_serialization:
+        save_func = partial(safe_save_file, metadata={"format": "pt"})
+        if isinstance(obj, OrderedDict):
+            obj = clean_state_dict_for_safetensors(obj)
+    else:
+        save_func = torch.save
+
     if PartialState().distributed_type == DistributedType.TPU:
         xm.save(obj, f)
     elif PartialState().is_main_process and not save_on_each_node:
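
A quick sketch of how the new helper behaves, with made-up key names (the first key found for a given storage is kept; later aliases are dropped and a warning is logged):

    import torch
    from accelerate.utils import clean_state_dict_for_safetensors

    linear = torch.nn.Linear(4, 4)
    # "tied.weight" aliases "base.weight": both reference the same storage
    state_dict = {"base.weight": linear.weight, "tied.weight": linear.weight, "base.bias": linear.bias}
    cleaned = clean_state_dict_for_safetensors(state_dict)
    # "tied.weight" is removed; remaining tensors are made contiguous for safetensors
    print(sorted(cleaned))  # ['base.bias', 'base.weight']

Note that `save` now routes plain `OrderedDict` objects, which `Module.state_dict()` returns, through this helper automatically whenever `safe_serialization=True`, so callers get the dedup for free.
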
21 changes: 21 additions & 0 deletions tests/test_utils.py
@@ -14,12 +14,14 @@

 import os
 import pickle
+import tempfile
 import unittest
 import warnings
 from collections import UserDict, namedtuple
 from unittest.mock import Mock, patch

 import torch
+from torch import nn

 from accelerate.state import PartialState
 from accelerate.test_utils.testing import require_cuda, require_torch_min_version
@@ -32,6 +34,7 @@
     listify,
     patch_environment,
     recursively_apply,
+    save,
     send_to_device,
 )

@@ -205,3 +208,21 @@ def test_check_os_kernel_warning_when_release_lt_min(self):
         self.assertEqual(ctx.records[0].levelname, "WARNING")
         self.assertIn("5.4.0", ctx.records[0].msg)
         self.assertIn("5.5.0", ctx.records[0].msg)
+
+    def test_save_safetensor_shared_memory(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = nn.Linear(100, 100)
+                self.b = self.a
+
+            def forward(self, x):
+                return self.b(self.a(x))
+
+        model = Model()
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            save_path = os.path.join(tmp_dir, "model.safetensors")
+            with self.assertLogs(level="WARNING") as log:
+                save(model.state_dict(), save_path, safe_serialization=True)
+            self.assertEqual(len(log.records), 1)
+            self.assertIn("Removed shared tensor", log.output[0])

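The new test exercises the whole path end to end: `self.b = self.a` ties the two submodules, so the state dict carries aliased tensors, and `save(..., safe_serialization=True)` must emit exactly one "Removed shared tensor" warning. To run just this test locally, something like the following should work from the repository root, assuming pytest is installed:

    pytest tests/test_utils.py -k test_save_safetensor_shared_memory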