From e71d2570fa1facf5c01fc8009ee5381d123cea6d Mon Sep 17 00:00:00 2001 From: Min Xu <24926999+min-xu-ai@users.noreply.github.com> Date: Sat, 24 Sep 2022 12:14:32 -0700 Subject: [PATCH] [cleanup] remove ssd offload to simplify the FSDP code (#1080) * simlificed the readme * clean up ssd offload * try to fix readthedocs Co-authored-by: Min Xu --- .readthedocs.yaml | 29 + README.md | 160 +-- benchmarks/fsdp.py | 13 +- docs/source/conf.py | 4 +- fairscale/experimental/nn/ssd_offload.py | 957 ------------------ fairscale/nn/data_parallel/__init__.py | 1 - .../fully_sharded_data_parallel.py | 171 +--- fairscale/nn/misc/flatten_params_wrapper.py | 39 +- tests/ci_test_list_2.txt | 2 - tests/experimental/nn/test_ssd_offload.py | 439 -------- tests/nn/data_parallel/test_fsdp_offload.py | 517 ---------- 11 files changed, 66 insertions(+), 2266 deletions(-) create mode 100644 .readthedocs.yaml delete mode 100644 fairscale/experimental/nn/ssd_offload.py delete mode 100644 tests/experimental/nn/test_ssd_offload.py delete mode 100644 tests/nn/data_parallel/test_fsdp_offload.py diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 000000000..579620101 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,29 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# We need python > 3.8 due to a dependency on numpy. +build: + os: ubuntu-20.04 + tools: + python: "3.9" + # You can also specify other tool versions: + # nodejs: "16" + # rust: "1.55" + # golang: "1.17" + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs/source/conf.py + +# If using Sphinx, optionally build your docs in additional formats such as PDF +# formats: +# - pdf + +# Optionally declare the Python requirements required to build your docs +python: + install: + - requirements: docs/requirements.txt diff --git a/README.md b/README.md index 821c88810..49d3ed386 100644 --- a/README.md +++ b/README.md @@ -25,23 +25,6 @@ FairScale was designed with the following values in mind: [![Explain Like I’m 5: FairScale](https://img.youtube.com/vi/oDt7ebOwWIc/0.jpg)](https://www.youtube.com/watch?v=oDt7ebOwWIc) -## What's New: - -* March 2022 [fairscale 0.4.6 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.6). -* We have support for CosFace's LMCL in MEVO. This is a loss function that is suitable for large number of prediction target classes. -* January 2022 [fairscale 0.4.5 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.5). -* We have experimental support for layer wise gradient scaling. -* We enabled reduce_scatter operation overlapping in FSDP backward propagation. -* December 2021 [fairscale 0.4.4 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.4). -* FairScale is tested with the following PyTorch versions (with CUDA 11.2): 1.8.1, 1.10.0 and 1.11.0.dev20211101+cu111. -* November 2021 [fairscale 0.4.3 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.3). -* We have experimental support for offloading params to disk when using the FSDP API for evaluation workloads. -* We have an experimental layer that fuses multiple layers together to support large vocab size trainings. -* November 2021 [fairscale 0.4.2 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.2). -* We have a new experimental API called the LayerwiseMemoryTracker to help track, visualize and suggest fixes for memory issues occurring during the forward/backward pass of your models. -* Introducing SlowMoDistributedDataParallel API, a distributed training wrapper that is useful on clusters with slow network interconnects (e.g. Ethernet). -* September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming). - ## Installation To install FairScale, please see the following [instructions](https://github.com/facebookresearch/fairscale/blob/main/docs/source/installation_instructions.rst). @@ -50,134 +33,26 @@ You should be able to install a package with pip or conda, or build directly fro ## Getting Started The full [documentation](https://fairscale.readthedocs.io/) contains instructions for getting started, deep dives and tutorials about the various FairScale APIs. -## Examples - -Here are a few sample snippets from a subset of FairScale offerings: - -### Pipe - -Run a 4-layer model on 2 GPUs. The first two layers run on cuda:0 and the next two layers run on cuda:1. - -```python -import torch - -import fairscale - -model = torch.nn.Sequential(a, b, c, d) -model = fairscale.nn.Pipe(model, balance=[2, 2], devices=[0, 1], chunks=8) -``` - -### Optimizer state sharding (ZeRO) -See a more complete example [here](https://github.com/facebookresearch/fairscale/blob/main/benchmarks/oss.py), but a minimal example could look like the following : - -```python -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from fairscale.optim.oss import OSS -from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP - -def train( - rank: int, - world_size: int, - epochs: int): - - # DDP init example - dist.init_process_group(backend='nccl', init_method="tcp://localhost:29501", rank=rank, world_size=world_size) - - # Problem statement - model = myAwesomeModel().to(rank) - dataloader = mySuperFastDataloader() - loss_fn = myVeryRelevantLoss() - base_optimizer = torch.optim.SGD # pick any pytorch compliant optimizer here - base_optimizer_arguments = {} # pass any optimizer specific arguments here, or directly below when instantiating OSS - - # Wrap the optimizer in its state sharding brethren - optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments) - - # Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks - model = ShardedDDP(model, optimizer) - - # Any relevant training loop, nothing specific to OSS. For example: - model.train() - for e in range(epochs): - for batch in dataloader: - # Train - model.zero_grad() - outputs = model(batch["inputs"]) - loss = loss_fn(outputs, batch["label"]) - loss.backward() - optimizer.step() - - dist.destroy_process_group() - -if __name__ == "__main__": - # Supposing that WORLD_SIZE and EPOCHS are somehow defined somewhere - mp.spawn( - train, - args=( - WORLD_SIZE, - EPOCHS, - ), - nprocs=WORLD_SIZE, - join=True, - ) -``` - -### AdaScale SGD - -AdaScale can be used to wrap a SGD optimizer and to be used in DDP (Distributed Data Parallel) -training or non-DDP with gradient accumulation. The benefit is to re-use the same LR -schedule from a baseline batch size when effective batch size is bigger. - -Note that AdaScale does _not_ help increase per-GPU batch size. - -```python -from torch.optim import SGD -from torch.optim.lr_scheduler import LambdaLR # or your scheduler -from fairscale.optim import AdaScale - -... -optim = AdaScale(SGD(model.parameters(), lr=0.1)) -scheduler = LambdaLR(optim, ...) -... -# Note: the train loop should be with DDP or with gradient accumulation. -last_epoch = 0 -step = 0 -done = False -while not done: - for sample in dataset: - ... - step += optim.gain() - optim.step() - epoch = step // len(dataset) - if last_epoch != epoch: - scheduler.step() - last_epoch = epoch - if epoch > max_epoch: - done = True -``` +## FSDP -Primary goal is to allow scaling to bigger batch sizes without losing model accuracy. -(However, training time might be longer comparing to without AdaScale.) - -At a high level, we want ML researchers to: - * go parallel more easily (i.e. no need to find new learning rate schedules) - * not worrying about losing accuracy - * potentially higher GPU efficiency (fewer steps, less networking overhead, etc.) +FullyShardedDataParallel (FSDP) is the recommended method for scaling to large NN models. +This library has been [upstreamed to PyTorch](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/). +The version of FSDP here is for historical references as well as for experimenting with +new and crazy ideas in research of scaling techniques. Please see the following blog +for [how to use FairScale FSDP and how does it work](https://engineering.fb.com/2021/07/15/open-source/fsdp/). ## Testing We use circleci to test FairScale with the following PyTorch versions (with CUDA 11.2): -* the latest stable release (1.10.0) -* the latest LTS release (1.8.1) -* a recent nightly release (1.11.0.dev20211101+cu111) +* the latest stable release (e.g. 1.10.0) +* the latest LTS release (e.g. 1.8.1) +* a recent nightly release (e.g. 1.11.0.dev20211101+cu111) Please create an [issue](https://github.com/facebookresearch/fairscale/issues) if you are having trouble with installation. ## Contributors -We welcome outside contributions! Please see the [CONTRIBUTING](CONTRIBUTING.md) instructions for how you can contribute to FairScale. +We welcome contributions! Please see the [CONTRIBUTING](CONTRIBUTING.md) instructions for how you can contribute to FairScale. ## License @@ -198,22 +73,9 @@ If you use FairScale in your publication, please cite it by using the following ```BibTeX @Misc{FairScale2021, - author = {Mandeep Baines and Shruti Bhosale and Vittorio Caggiano and Naman Goyal and Siddharth Goyal and Myle Ott and Benjamin Lefaudeux and Vitaliy Liptchinsky and Mike Rabbat and Sam Sheiffer and Anjali Sridhar and Min Xu}, + author = {FairScale authors}, title = {FairScale: A general purpose modular PyTorch library for high performance and large scale training}, howpublished = {\url{https://github.com/facebookresearch/fairscale}}, year = {2021} } ``` - -## FAQ -1. If you experience an error indicating a default branch does not exist, it probably due to the latest update, switching the default branch from "master" to "main" -``` -error: pathspec 'non-existing-branch' did not match any file(s) known to git -``` -Please run the following commands to update to the main branch. -``` -git branch -m master main -git fetch origin -git branch -u origin/main main -git remote set-head origin -a -``` diff --git a/benchmarks/fsdp.py b/benchmarks/fsdp.py index d1ce87f90..4a372e50d 100644 --- a/benchmarks/fsdp.py +++ b/benchmarks/fsdp.py @@ -25,7 +25,6 @@ from benchmarks.golden_configs.lm_wikitext2 import FSDP as lm_wikitext2 from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP -from fairscale.nn.data_parallel import OffloadConfig RPC_PORT = 29501 @@ -95,10 +94,7 @@ def get_lm_model(args, device, config): nhid = config["nhid"] ndecoder = config["num_decoder_layers"] - if args.ssd_offload: - return transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder) - else: - return transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device) + return transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device) def get_tensors_by_size_bucket(): @@ -200,7 +196,7 @@ def get_batch(source): if i > 0: total_tokens += source.numel() - if args.benchmark_eval or args.ssd_offload: + if args.benchmark_eval: input = source.cuda() target = target.cuda() output = model(input) @@ -250,7 +246,6 @@ def get_number_of_words(data): def benchmark_language_model(model_config, model, benchmark_config, model_specs, args): - # TODO(anj): Uncomment and add a check for regression once we have a couple of runs. golden_config = get_golden_config(args.model_name, args) epoch = benchmark_config["epochs"] start_time = time.time() @@ -358,8 +353,6 @@ def benchmark_fsdp(rank, args, world_size): model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs) model = model_config["model"] config = {} - if args.ssd_offload: - config["offload_config"] = OffloadConfig(offload_type="ssd_offload") if args.full_fp16: config["compute_dtype"] = torch.float16 @@ -386,7 +379,6 @@ def benchmark_fsdp(rank, args, world_size): parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.") parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.") parser.add_argument( - # TODO(anj-s): In the process of adding more models and hence the requirement for a flag. "--model_name", default="lm", help="Language Model(LM) used to benchmark FSDP.", @@ -394,7 +386,6 @@ def benchmark_fsdp(rank, args, world_size): parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information") parser.add_argument("--enable_auto_wrap", action="store_true", default=False, help="Use auto_wrap with FSDP") parser.add_argument("--benchmark_eval", action="store_true", default=False, help="Benchmark evaluation workflow.") -parser.add_argument("--ssd_offload", action="store_true", default=False, help="Benchmark ssd_offload workflow.") parser.add_argument("--full_fp16", action="store_true", default=False, help="Benchmark in full fp16 mode.") if __name__ == "__main__": diff --git a/docs/source/conf.py b/docs/source/conf.py index 34a46b03a..b511981be 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -30,7 +30,7 @@ # -- Project information ----------------------------------------------------- project = "FairScale" -copyright = "2020-2021, Facebook/Meta AI Research" +copyright = "2020-2022, Facebook/Meta AI Research" author = "Facebook/Meta AI Research" # -- General configuration --------------------------------------------------- @@ -68,7 +68,7 @@ autodoc_member_order = "bysource" intersphinx_mapping = { - "python": ("https://docs.python.org/3.6", None), + "python": ("https://docs.python.org/3.8", None), "numpy": ("https://numpy.org/doc/stable/", None), "torch": ("https://pytorch.org/docs/stable/", None), } diff --git a/fairscale/experimental/nn/ssd_offload.py b/fairscale/experimental/nn/ssd_offload.py deleted file mode 100644 index d80ff25a8..000000000 --- a/fairscale/experimental/nn/ssd_offload.py +++ /dev/null @@ -1,957 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -from enum import Enum, auto -from functools import reduce -import io -import os -import pickle -from types import TracebackType -from typing import IO, Any, BinaryIO, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union - -import numpy as np -import torch -from torch.serialization import DEFAULT_PROTOCOL as DEFAULT_PROTOCOL - -from fairscale.internal import torch_version - -try: - from torch.utils._pytree import tree_map -except ImportError: - # The PyTorch version(<1.9) we test with does not support the tree_map API. - pass - -if torch_version() < (1, 12, 0): - raise ImportError( - f"ssd_offload only works on torch versions 1.12.0 and beyond, but torch version is: {torch.__version__}" - ) - -DEFAULT_CHUNK_SIZE = 2048 * 2048 - - -def _get_num_chunks(input_tensor: torch.Tensor, chunk_size_bytes: int = DEFAULT_CHUNK_SIZE) -> int: - """Returns the number of chunks that the given tensor can be divided into.""" - size_in_bytes = input_tensor.nelement() * input_tensor.element_size() - num_chunks = (size_in_bytes + (chunk_size_bytes - 1)) // chunk_size_bytes - return num_chunks - - -def _tensor_to_bytes_chunks( - input_tensor: torch.Tensor, chunk_idx: int, chunk_size_bytes: int = DEFAULT_CHUNK_SIZE -) -> bytes: - """Converts the given tensor into a chunked array containing chunk_size_bytes.""" - size_in_bytes = input_tensor.nelement() * input_tensor.element_size() - assert chunk_idx < _get_num_chunks(input_tensor, chunk_size_bytes) - input_tensor_np = input_tensor.detach().numpy().view(np.uint8).reshape(-1) - chunk_start = chunk_idx * chunk_size_bytes - chunk_end = min(size_in_bytes, chunk_start + chunk_size_bytes) - return input_tensor_np[chunk_start:chunk_end].tobytes() - - -def write(input_tensor: torch.Tensor, filename: str, file_offset_bytes: int = 0) -> None: - """Populates the file with the data stored in the given tensor.""" - num_chunks = _get_num_chunks(input_tensor) - file_flags = "r+b" if os.path.exists(filename) else "wb" - with open(filename, file_flags) as f: - f.seek(file_offset_bytes) - for i in range(num_chunks): - f.write(_tensor_to_bytes_chunks(input_tensor, i)) - - -def read(input_tensor: torch.Tensor, filename: str, file_offset_bytes: int = 0) -> None: - """Populates the given tensor with the data stored in a file.""" - size_in_bytes = input_tensor.nelement() * input_tensor.element_size() - chunk_size_bytes = DEFAULT_CHUNK_SIZE - num_chunks = _get_num_chunks(input_tensor) - input_tensor_np = input_tensor.detach().numpy() - input_tensor_mv = memoryview(input_tensor_np.view(dtype=np.uint8).reshape(-1)) - with io.open(filename, "rb") as f: - f.seek(file_offset_bytes) - for i in range(num_chunks): - chunk_start = i * chunk_size_bytes - chunk_end = min(size_in_bytes, chunk_start + chunk_size_bytes) - data_read = f.readinto(input_tensor_mv[chunk_start:chunk_end]) - if data_read != chunk_end - chunk_start: - raise RuntimeError( - f"Attempted to read {chunk_end - chunk_start} more bytes from {filename}, but only read: {data_read} bytes. Total Bytes read = {chunk_start + data_read}, total bytes expected: {size_in_bytes}" - ) - - -class StorageState(Enum): - """ - Simple enum to indicate whether the tensor handle is pointing - to data on disk or memory. This is useful for asserting on - whether the tensor is available for operations or if it needs - to be moved from disk to CPU or device. - """ - - UNALLOCATED = auto() - ON_DISK = auto() - ON_CPU_CLEAN = auto() - ON_CPU_DIRTY = auto() - - -class SsdTensorHandle(torch.Tensor): - """ - This class extends from torch.Tensor and represents a Tensor which is backed by SSD storage. - The SsdTensorHandle object can point to a file or a tensor and there are corresponding functions to read - data into the tensor that is an attribute of the SsdTensorHandle object or write the tensor to file. At any - point in time the Tensor may be in memory or on disk. - - Class Variables: - override_directory_path: This variable is used by CheckpointPathContextManager to modify the path to any - SsdTensorHandles that are saved to a checkpoint via pickling (e.g. torch.save) - - Args: - shape torch.Size: Shape of the tensor that is represented by the handle. - dtype: torch.dtype: Dtype of the tensor that is represented by the handle. - requires_grad: bool: Property of the tensor that is represeneted by the handle. - - Returns: - A SSDTensorHandle object representing a Tensor. - """ - - override_directory_path: Optional[str] = None - - @staticmethod - def __new__( - cls: Type[SsdTensorHandle], - shape: torch.Size, - dtype: torch.dtype, - requires_grad: bool = False, - device: torch.device = torch.device("cpu"), - flush_on_dirty: bool = True, - allow_unsafe_changes: bool = False, - ) -> SsdTensorHandle: - r = super(SsdTensorHandle, cls)._make_wrapper_subclass(cls, shape, dtype=dtype, requires_grad=requires_grad, device=device) # type: ignore - return r - - def __init__( - self, - shape: torch.Size, - dtype: torch.dtype, - requires_grad: bool, - device: torch.device = torch.device("cpu"), - flush_on_dirty: bool = True, - allow_unsafe_changes: bool = False, - ) -> None: - self._unpickle_f: Optional[Union[BinaryIO, IO[bytes]]] = None - - self._shape = shape - if len(shape) == 0: - self._numel = 0 - else: - self._numel = reduce((lambda x, y: x * y), shape) - self._dtype = dtype - # valid if offloaded to file - self.filename = "" - self.offset = -1 - # valid if loaded to memory - self.tensor: Optional[torch.Tensor] = None - self.storage_state = StorageState.UNALLOCATED - self.flush_on_dirty = flush_on_dirty - self.allow_unsafe_changes = allow_unsafe_changes - - def mark_dirty(self) -> None: - assert self.tensor is not None - assert self.storage_state in [StorageState.ON_CPU_CLEAN, StorageState.ON_CPU_DIRTY] - self.storage_state = StorageState.ON_CPU_DIRTY - # hack to force write on mark_dirty - if self.flush_on_dirty: - self.to_file() - - @classmethod - def from_file( - cls, shape: torch.Size, dtype: torch.dtype, filename: str, offset: int = 0, requires_grad: bool = False - ) -> SsdTensorHandle: - """Returns a new SsdTensorHandle from a file.""" - handle = cls(shape=shape, dtype=dtype, requires_grad=requires_grad) - handle.point_to_file(filename, offset=offset) - return handle - - @classmethod - def from_tensor(cls: Type[SsdTensorHandle], tensor: torch.Tensor) -> SsdTensorHandle: - """Returns a new SsdTensorHandle from a tensor.""" - handle = cls(shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad, device=tensor.device) - handle.point_to_tensor(tensor) - return handle - - def is_available(self) -> bool: - return self.tensor is not None - - def get_tensor(self) -> torch.Tensor: - assert self.tensor is not None - return self.tensor - - def set_file_params(self, filename: str, offset: int) -> None: - self.filename = filename - self.offset = offset - - def point_to_file(self, filename: str, offset: int) -> None: - self.set_file_params(filename, offset) - self.tensor = None - self.storage_state = StorageState.ON_DISK - - def point_to_tensor(self, tensor: torch.Tensor) -> None: - assert self.tensor is None - if not self.allow_unsafe_changes: - assert self._shape == tensor.shape - assert self._dtype == tensor.dtype - self.tensor = tensor - self.storage_state = StorageState.ON_CPU_DIRTY - - # if resizing a handle that is part of an ssd buffer, care must be taken that the new size - # doesn't conflict with adjacent handles! - def point_to_resized_tensor(self, tensor: torch.Tensor) -> None: - assert self._dtype == tensor.dtype - self._shape = tensor.shape - self.tensor = tensor - - def to_tensor(self) -> torch.Tensor: - """Returns the tensor represented by the SsdTensorHandle object. - - If the tensor is on disk, it is copied into the tensor attribute and returned. - """ - if self.tensor is not None: - return self.tensor - else: - if self.device != torch.device("cpu"): - raise RuntimeError( - f"to_tensor called on an SsdTensorHandle when the tensor has been offloaded to disk. self.device = {self.device}, it should be {torch.device('cpu')}. Some unexpected .data override has occured!!" - ) - result_tensor = torch.empty(size=self.shape, dtype=self.dtype, requires_grad=self.requires_grad) - self.copy_into_tensor(result_tensor) - self.tensor = result_tensor - self.storage_state = StorageState.ON_CPU_CLEAN - return self.tensor - - def to_file(self, permit_when_tensor_none: bool = False, release_tensor_after_write: bool = True) -> None: - """Saves the tensor to disk and releases memory if specified.""" - assert self.tensor is not None or permit_when_tensor_none - - # if it's available in Memory but not modified, no need to write-back - if self.tensor is not None: - if self.storage_state is StorageState.ON_CPU_DIRTY: - if self.device != torch.device("cpu"): - raise RuntimeError( - f"to_file called on an SsdTensorHandle when self.device = {self.device}, it should be {torch.device('cpu')}. Some unexpected .data override has occured!!" - ) - write(self.tensor, self.filename, self.offset * self.tensor.element_size()) - if release_tensor_after_write: - self.tensor = None - self.storage_state = StorageState.ON_DISK - else: - self.storage_state = StorageState.ON_CPU_CLEAN - - def copy_into_tensor(self, tensor: torch.Tensor) -> None: - """Copies SsdTensorHandle's data into the given tensor. - - If the tensor is in memory, this function copies the data - into the passed in tensor. Otherwise, it reads from file into tensor, - using the read() function. - This does not modify modify self.tensor unlike the to_tensor() - function. This can be useful for calls like named_parameters() when - the tensor is already offloaded to disk. - """ - # ideally this should be checked but .data shenanigans forces it to - # be disabled due to the way FSDP shards parameters - # assert self._shape == tensor.shape - assert self._dtype == tensor.dtype - if self.tensor is not None: - tensor.copy_(self.tensor) - else: - read(tensor, self.filename, self.offset * tensor.element_size()) - - __torch_function__ = torch._C._disabled_torch_function_impl - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore - """Intercepts all operations performed on this handle object. - - Before any operation, the tensor attribute is unwrapped from the handle - and used in the operation. We maintain a refernce to the tensor and its current - versions to track if modifications have been made. If we detect changes to the - tensor, we write it to the file maintained by the Handle. - """ - func_name = func.overloadpacket.__name__ - ssd_tensor_handles = [] - - def unwrap(e: Any) -> torch.Tensor: - if isinstance(e, SsdTensorHandle): - t = e.to_tensor() - ssd_tensor_handles.append(e) - return t - else: - return e - - r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) - - for e in ssd_tensor_handles: - inplace_is_this_tensor = ( - (func_name.endswith("_") and not func_name.endswith("__")) or func_name.startswith("__i") - ) and e is args[0] - out_is_this_tensor = False if "out" not in kwargs else e is kwargs["out"] - if inplace_is_this_tensor or out_is_this_tensor: - e.mark_dirty() - return r - - def __setattr__(self, name: str, value: Any) -> None: - if name == "data": - assert isinstance(value, torch.Tensor) - if not self.allow_unsafe_changes: - # Respect .data changes, and the user better know what they are doing! - if self.storage_state == StorageState.ON_CPU_DIRTY: - raise RuntimeError( - "Attempting to override tensor when the existing tensor is dirty, this is an error!" - ) - if value.shape != self.shape: - raise RuntimeError( - f"Attempting to override tensor metadata using .data to change shape of tensor. Orig shape: {self.shape} New shape: {value.shape}" - ) - if value.requires_grad != self.requires_grad: - raise RuntimeError( - f"Attempting to override tensor metadata using .data to change requires_grad. Orig value: {self.requires_grad} New value: {value.requires_grad}" - ) - self.tensor = value - super(SsdTensorHandle, self).__setattr__(name, value) - - @classmethod - def __unpickle__( - cls: Type[SsdTensorHandle], shape: torch.Size, dtype: torch.dtype, requires_grad: bool, filename: str - ) -> SsdTensorHandle: - result = cls(shape, dtype, requires_grad) - result.point_to_file(filename, 0) - result._unpickle_f = io.open(result.filename, "wb") - return result - - def __reduce_ex__(self, proto: int) -> Tuple[Any, Any, Any, Any]: - byte_iter = None - filename = self.filename - if self.override_directory_path is not None: - head, tail = os.path.split(self.filename) - filename = os.path.join(self.override_directory_path, tail) - if self.is_available(): - byte_iter = iter(TensorChunkingIterator(self.tensor)) # ignore: type - else: - byte_iter = iter( - FileChunkingIterator(self.filename, expected_size_bytes=self.numel() * self.element_size()) - ) - return ( - self.__unpickle__, # Callable - # Args to the callable above - (self._shape, self._dtype, self.requires_grad, filename), - None, - byte_iter, - ) - - def append(self, item: bytes) -> None: - assert self._unpickle_f - self._unpickle_f.write(item) - - def extend(self, items: List[bytes]) -> None: - for i in items: - self.append(i) - - -class CheckpointPathContextManager: - """ - This Context allows the user to override the directory path when pickling an SsdTensorHandle Object. - It is needed because the filename which the SsdTensorHandle points to (and is used when unpickling) - is already baked into the pickled data. - - Consider the following example code - ssd_handle = SsdTensorHandle.from_tensor(ref_tensor) - ssd_handle.set_file_params('/home/user/handle.bin', 0) - torch.save(ssd_handle, '/home/user/checkpoint.pkl') - ssd_handle += 1 - ssd_handle.to_file() - ssd_handle2 = torch.load('/home/user/checkpoint.pkl') - - print(f"handles are equal: {torch.equals(ssd_handle, ssd_handle2)}") - - One would expect this to print False, however unintuitively it will print True. - ssd_handle.filename and ssd_handle2.filename are equal. This means that - when we execute torch.load, we read from the .pkl file and write the result into - /home/user/handle.bin, clobbering the updated result from `ssd_handle += 1` - - We want to give the user the possibility of not clobbering the data using this - Context Manager. - - ssd_handle = SsdTensorHandle.from_tensor(ref_tensor) - ssd_handle.set_file_params('/home/user/handle.bin', 0) - with CheckpointPathContextManager(override_path='/home/user/checkpoint_data/'): - torch.save(ssd_handle, '/home/user/checkpoint.pkl') - ssd_handle += 1 - ssd_handle.to_file() - ssd_handle2 = torch.load('/home/user/checkpoint.pkl') - - print(f"handles are equal: {torch.equals(ssd_handle, ssd_handle2)}") - - This code results with ssd_handle.filename = '/home/user/handle.bin' and ssd_handle2.filename = - `/home/user/checkpoint_data/handle.bin'. Therefore the torch.load won't clobber ssd_handle, and - the printed result is False. - - """ - - def __init__(self, override_path: str) -> None: - self.old_path = SsdTensorHandle.override_directory_path - self.override_path = override_path - - def __enter__(self) -> None: - SsdTensorHandle.override_directory_path = self.override_path - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - exec_traceback: Optional[TracebackType], - ) -> None: - SsdTensorHandle.override_directory_path = self.old_path - - -# Classes supporting torch.save/load -class TorchSaver: - def __init__(self) -> None: - self.pickle_module = DisableMemoizationPicklerModule - - def save( - self, obj: Any, f: Union[str, os.PathLike, BinaryIO, IO[bytes]], pickle_protocol: int = DEFAULT_PROTOCOL - ) -> None: - torch.serialization.save( - obj, f, self.pickle_module, pickle_protocol=pickle_protocol, _use_new_zipfile_serialization=False - ) - - -class SsdParameter(SsdTensorHandle, torch.nn.Parameter): - @classmethod - def from_tensor(cls: Type[SsdParameter], tensor: SsdTensorHandle) -> SsdParameter: # type: ignore - r = cls(tensor.shape, tensor.dtype, tensor.requires_grad, device=tensor.device) - r.point_to_tensor(tensor) - return r - - @staticmethod - def __new__( - cls: Type[SsdParameter], - shape: torch.Size, - dtype: torch.dtype, - requires_grad: bool = True, - device: torch.device = torch.device("cpu"), - ) -> SsdParameter: - r = super(SsdParameter, cls).__new__(cls, shape=shape, dtype=dtype, requires_grad=requires_grad, device=device) - return r # type: ignore - - def __init__( - self, - shape: torch.Size, - dtype: torch.dtype, - requires_grad: bool = True, - device: torch.device = torch.device("cpu"), - ) -> None: - super(SsdParameter, self).__init__(shape=shape, dtype=dtype, requires_grad=requires_grad, device=device) - - -class SsdFlatParameter(SsdParameter): - """A parameter that is initialized from a list of parameters and can be - turned into a list of views as needed. - - This class should eventually be moved to fairscale/nn/misc/flatten_params_wrapper.py - """ - - def __new__( - cls: Type[SsdFlatParameter], - shapes: Sequence[torch.Size], - dtype: torch.dtype, - requires_grad: bool = True, - device: torch.device = torch.device("cpu"), - ) -> SsdFlatParameter: - """Make an object using the parent's __new__ function.""" - - # A empty of non-list input doesn't make sense. - if not isinstance(shapes, (list, tuple)) or len(shapes) == 0: - raise ValueError("An non-empty list or tuple argument is needed") - - size = sum([np.prod(s) for s in shapes]) - r = super(SsdFlatParameter, cls).__new__( - cls, torch.Size((size,)), dtype=dtype, requires_grad=requires_grad, device=device - ) - return r # type: ignore - - def __init__( - self, - shapes: Sequence[torch.Size], - dtype: torch.dtype, - requires_grad: bool = True, - device: torch.device = torch.device("cpu"), - ): - """Initialize the _param_numels and _param_shapes lists.""" - self._param_shapes = shapes - self._param_numels = [np.prod(s) for s in shapes] - total_numels = sum(self._param_numels) - assert ( - self.numel() <= total_numels - ), f"Something wrong with __new__ method, {self.numel()} vs. {sum(self._param_numels)}" - - self.views: List[SsdFlatParameterView] = [] - # These are set by FPW class below, not by this class itself. - self._param_infos: List[Tuple[str, torch.nn.Module, str]] = [] - self._shared_param_infos: List[Tuple[str, str, torch.nn.Module, str, torch.nn.Module, str]] = [] - - super(SsdFlatParameter, self).__init__( - shape=torch.Size((total_numels,)), dtype=dtype, requires_grad=requires_grad - ) - - def __setattr__(self, name: str, value: Any) -> None: - super(SsdFlatParameter, self).__setattr__(name, value) - if name == "data": - # if .data has changed, we need to totally destroy any existing views because things - # like device might have changed. It won't destroy any pointers to those views outside - # of here, however resetting self.views will trigger the old view's assertion in - # __torch_dispatch__ that it is the current view of it's parent object - self.views = [] - self._refresh_views() - - def _invalidate_views(self) -> None: - for v in self.views: - v.tensor = None - - @torch.enable_grad() - def _refresh_views(self) -> None: - if self._shape != self.shape: - self.views = [] - return - if len(self.views) == 0: - self.views = [s.view(v) for s, v in zip(self.split(self._param_numels), self._param_shapes)] # type: ignore - else: - for v, t, s in zip(self.views, self.tensor.split(self._param_numels), self._param_shapes): - v.tensor = t.view(s) - - def get_param_views(self, external_data: Optional[torch.Tensor] = None) -> Iterator[torch.Tensor]: - """Return a generator of views that map to the original parameters.""" - # Note, self.data could be sharded, so its numel is <= to the sum. - """ - assert self.data.numel() <= sum( - self._param_numels - ), f"Incorrect internal state {self.data.numel()} vs. {sum(self._param_numels)}" - """ - if external_data is not None: - if external_data.numel() != sum(self._param_numels): - raise ValueError( - f"Incorrect numel of supplied data: got {external_data.numel()} but expected {sum(self._param_numels)}" - ) - return (t.view(s) for (t, s) in zip(external_data.split(self._param_numels), self._param_shapes)) - else: - # this needs to return SsdFlatParameterViews - if not self.is_available(): - self.to_tensor() - - if len(self.views) == 0: - raise RuntimeError( - "Trying to call get_param_views when self.views is empty, this means that .data games have been played and the current .data shape doesn't match the constructed shape." - ) - return (v for v in self.views) - - def metadata(self) -> Tuple[List[str], Sequence[torch.Size], List[int]]: - """Return tuple of (names, shapes, numels) metadata for this flat parameter.""" - names = [".".join([m, n]) if m else n for (m, _, n) in self._param_infos] - return names, self._param_shapes, self._param_numels - - @classmethod - def from_tensors( - cls: Type[SsdFlatParameter], - tensors: Sequence[torch.Tensor], - direct_to_file: bool = False, - filename: str = "", - offset: int = 0, - ) -> "SsdFlatParameter": - """Returns a new SsdFlatParameter from a sequence of tensors.""" - assert ( - len(tensors) > 0 - ), "SsdFlatParameter.from_tensors must be called with at least one tensor in the tensors argument" - - # Flattening involves (1) making a tensor flat (i.e. single dimensional) and (2) making a module - # heirarchy flat (using a single tensor to replace a tree of tensors). Therefore, - # adding back nesting and heirarchy is counter-productive. If nesting is encountered - # in the future, the reasonable thing to do is likely for the top level SsdFlatParameter to - # absorb the nested one and keep the result flat, free from hierarchy. - if any(isinstance(t, SsdFlatParameter) for t in tensors): - raise ValueError("Nesting SsdFlatParameter is not supported") - - requires_grad = tensors[0].requires_grad - dtype = tensors[0].dtype - device = tensors[0].device - for t in tensors: - if t.requires_grad != requires_grad: - raise RuntimeError("Not all tensors have identical requires_grad option") - if t.dtype != dtype: - raise RuntimeError("Not all tensors have identical dtype option") - if t.device != device: - raise RuntimeError("Not all tensors have identical device option") - handle = cls( - shapes=[t.size() for t in tensors], - dtype=tensors[0].dtype, - requires_grad=tensors[0].requires_grad, - device=device, - ) - handle.set_file_params(filename, offset) - if direct_to_file: - assert filename != "" - offset = offset - for t in tensors: - write(t, handle.filename, offset) - offset += t.numel() * t.element_size() - - handle.storage_state = StorageState.ON_DISK - else: - tensor = torch.cat( - [t.reshape(-1) if isinstance(t, torch.nn.Parameter) else t.reshape(-1) for t in tensors], - 0, - ).detach() - tensor.requires_grad_() - handle.point_to_tensor(tensor) - return handle - - __torch_function__ = torch._C._disabled_torch_function_impl - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore - func_name = func.overloadpacket.__name__ - r = super(SsdFlatParameter, cls).__torch_dispatch__(func, types, args, kwargs) # type: ignore - if func_name.startswith("split"): - assert isinstance(args[0], SsdFlatParameter) - parent = args[0] - return [SsdFlatParameterView(parent, t, idx) for idx, t in enumerate(r)] - else: - return r - - # need to subclass these methods to support Views - def point_to_tensor(self, tensor: torch.Tensor) -> None: - super(SsdFlatParameter, self).point_to_tensor(tensor) - self._refresh_views() - - def point_to_file(self, filename: str, offset: int) -> None: - super(SsdFlatParameter, self).point_to_file(filename, offset) - self._invalidate_views() - - def to_tensor(self) -> torch.Tensor: - call_refresh_views = False - if self.tensor is None: - call_refresh_views = True - result = super(SsdFlatParameter, self).to_tensor() - if call_refresh_views: - self._refresh_views() - return result - - def to_file(self, permit_when_tensor_none: bool = False, release_tensor_after_write: bool = True) -> None: - super(SsdFlatParameter, self).to_file(permit_when_tensor_none, release_tensor_after_write) - self._invalidate_views() - - @classmethod - def __unpickle_SFP__( - cls: Type[SsdFlatParameter], - shapes: Sequence[torch.Size], - dtype: torch.dtype, - requires_grad: bool, - filename: str, - ) -> SsdFlatParameter: - result = cls(shapes, dtype, requires_grad) - result.point_to_file(filename, 0) - result._unpickle_f = io.open(result.filename, "wb") - return result - - def __reduce_ex__(self, proto: int) -> Tuple[Any, Any, Any, Any]: - byte_iter = None - filename = self.filename - if self.override_directory_path is not None: - head, tail = os.path.split(self.filename) - filename = os.path.join(self.override_directory_path, tail) - if self.is_available(): - byte_iter = iter(TensorChunkingIterator(self.tensor)) - else: - byte_iter = iter( - FileChunkingIterator(self.filename, expected_size_bytes=self.numel() * self.element_size()) - ) - return ( - self.__unpickle_SFP__, # Callable - # Args to the callable above - (self._param_shapes, self._dtype, self.requires_grad, filename), - None, - byte_iter, - ) - - -class SsdFlatParameterView(torch.Tensor): - """ - Represents a view into an SsdFlatParameter. It is needed due to FSDP's usage of flattening parameters. - """ - - def __new__( - cls: Type[SsdFlatParameterView], parent: SsdFlatParameter, tensor: torch.Tensor, id: int - ) -> SsdFlatParameterView: - r = super(SsdFlatParameterView, cls)._make_wrapper_subclass(cls, tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad, device=tensor.device) # type: ignore - return r - - def __init__(self: SsdFlatParameterView, parent: SsdFlatParameter, tensor: torch.Tensor, id: int) -> None: - self.parent = parent - self.tensor: Optional[torch.Tensor] = tensor - self.id = id - - __torch_function__ = torch._C._disabled_torch_function_impl - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore - """Intercepts all operations performed on this handle object. - - Before any operation, the tensor attribute is unwrapped from the handle - and used in the operation. We maintain a refernce to the tensor and its current - versions to track if modifications have been made. If we detect changes to the - tensor, we write it to the file maintained by the Handle. - """ - func_name = func.overloadpacket.__name__ - ssd_tensor_handles = [] - - def unwrap(e: Any) -> torch.Tensor: - if isinstance(e, SsdFlatParameterView): - if not e.parent.is_available(): - e.parent.to_tensor() - # first condition is to take care of the case where we are first constructing e.parent.views as a list comprehension which hasn't - # completed yet - if len(e.parent.views) != 0 and e is not e.parent.views[e.id]: - raise RuntimeError( - "This view should no longer be used as the parent object has had it's .data overwritten (e.parent.views[e.id])!!!" - ) - # e.parent will ensure that e.tensor is valid and points to tensor view - t = e.tensor - ssd_tensor_handles.append(e) - return t - else: - return e - - r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) - - for e in ssd_tensor_handles: - inplace_is_this_tensor = ( - (func_name.endswith("_") and not func_name.endswith("__")) or func_name.startswith("__i") - ) and e is args[0] - out_is_this_tensor = False if "out" not in kwargs else e is kwargs["out"] - if inplace_is_this_tensor or out_is_this_tensor: - e.parent.mark_dirty() - - if func_name.startswith("view"): - assert isinstance(args[0], SsdFlatParameterView) - flat_view = args[0] - return SsdFlatParameterView(flat_view.parent, r, flat_view.id) - return r - - -# ################################### -# ### BEGIN OVERRIDE_PROPERTY FNs ### -# ################################### -# This code is taken mostly from pytorch core parameterization -# pytorch/torch/nn/utils/parametrize.py - - -def _inject_new_class(module: torch.nn.Module) -> None: - r"""Sets up a module to be parametrized. - - This works by substituting the class of the module by a class - that extends it to be able to inject a property - - Args: - module (nn.Module): module into which to inject the property - """ - cls = module.__class__ - - def getstate(self): # type: ignore - raise RuntimeError( - "Serialization of parametrized modules is only " - "supported through state_dict(). See:\n" - "https://pytorch.org/tutorials/beginner/saving_loading_models.html" - "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training" - ) - - param_cls = type( - f"Parametrized{cls.__name__}", - (cls,), - { - "__getstate__": getstate, - }, - ) - - module.__class__ = param_cls - module.override_properties: Dict[str, Callable[[], torch.Tensor]] = {} # type: ignore - # setattr(module, "override_properties", {}) - - -def _inject_property(module: torch.nn.Module, property_name: str) -> None: - r"""Injects a property into module[property_name]. - - It assumes that the class in the module has already been modified from its - original one using _inject_new_class and that the tensor under :attr:`property_name` - has already been moved out - - Args: - module (nn.Module): module into which to inject the property - property_name (str): name of the name of the property to create - """ - - def get_parametrized(self: torch.nn.Module) -> torch.Tensor: - prop: Callable[[], torch.Tensor] = self.override_properties[property_name] # type: ignore - # If caching is not active, this function just evaluates the parameterization - return prop() - - def set_original(self: torch.nn.Module, value: Callable[[], torch.Tensor]) -> None: - self.override_properties[property_name] = value # type: ignore - - def del_fn(self: torch.nn.Module) -> None: - _remove_property(self, property_name) - - setattr(module.__class__, property_name, property(get_parametrized, set_original, del_fn)) - - -def _register_property(module: torch.nn.Module, property_name: str, property_value: Callable[[], torch.Tensor]) -> None: - has_injected_class = hasattr(module, "override_properties") - if not has_injected_class: - _inject_new_class(module) - if hasattr(module, property_name): - delattr(module, property_name) - module.override_properties[property_name] = property_value # type: ignore - _inject_property(module, property_name) - - -def _remove_property(module: torch.nn.Module, property_name: str, new_property_value: Optional[Any] = None) -> None: - delattr(module.__class__, property_name) - del module.override_properties[property_name] # type: ignore - - # Roll back the parametrized class if no other buffer or parameter - # is currently parametrized in this class - if len(module.override_properties) == 0: # type: ignore - delattr(module, "override_properties") - # Restore class - orig_cls = module.__class__.__bases__[0] - module.__class__ = orig_cls - if new_property_value is not None: - setattr(module.__class__, property_name, new_property_value) - - -# ################################# -# ### END OVERRIDE_PROPERTY FNs ### -# ################################# - - -class SsdFlatParameterViewProperty: - """ - Allows for a mutable view to replace a layer's trainable parameters. - This is needed since FSDP is changing .data under the covers, - SsdFlatParameter cannot just rely on this since each view (of type SsdFlatParameterView) has - an internal representation. So every time we access a view, we need to - make sure we get the up-to-date version, and not the original version - when flattening the parameters. - """ - - def __init__(self, parent: SsdFlatParameter, view_id: int) -> None: - super().__init__() - self.parent = parent - self.view_id = view_id - - def __call__(self) -> SsdFlatParameterView: - return self.parent.views[self.view_id] - - -class SsdFlatParameterViewParameterization(torch.nn.Module): - def __init__(self, parent: SsdFlatParameter, view_id: int) -> None: - super().__init__() - self.parent = parent - self.view_id = view_id - - def forward(self, *args: Any, **kwargs: Any) -> SsdFlatParameterView: - return self.parent.views[self.view_id] - - -class DisableMemoizationPicklerModule: - @classmethod - def Pickler(cls, data_buf: io.BytesIO, protocol: int) -> pickle.Pickler: - p = pickle.Pickler(data_buf, protocol) - p.fast = True - return p - - @classmethod - def dump(cls, obj: Any, f: io.BytesIO, protocol: int) -> None: - pickle.dump(obj, f, protocol) - - -class TensorChunkingIterator: - """ - chunk_size_bytes determines how large each chunk that we break the tensor - into. It is important to consider limiting the size because by when - python unpickles an object, by default it will read up to 1000 list - elements at a time. So memory usage while unpickling will be on the - order of O(min(file_size, 1000 * chunk_size_bytes)). - """ - - def __init__(self, tensor: torch.Tensor, chunk_size_bytes: int = DEFAULT_CHUNK_SIZE) -> None: - - self.tensor = tensor - self.chunk_size_bytes = chunk_size_bytes - - def __iter__(self) -> Iterator[bytes]: - - self.num_chunks = _get_num_chunks(self.tensor, self.chunk_size_bytes) - self.num_chunks_read = 0 - return self - - def __next__(self) -> bytes: - if self.num_chunks_read >= self.num_chunks: - raise StopIteration - next_chunk = _tensor_to_bytes_chunks( - self.tensor, chunk_idx=self.num_chunks_read, chunk_size_bytes=self.chunk_size_bytes - ) - - self.num_chunks_read += 1 - - return next_chunk - - -class FileChunkingIterator: - """ - chunk_size_bytes determines how large each chunk that we break the file - into. It is important to consider limiting the size because by when - python unpickles an object, by default it will read up to 1000 list - elements at a time. So memory usage while unpickling will be on the - order of O(min(file_size, 1000 * chunk_size_bytes)). - """ - - def __init__( - self, filename: str, expected_size_bytes: int = -1, chunk_size_bytes: int = DEFAULT_CHUNK_SIZE - ) -> None: - self.filename = filename - self.file: Optional[Union[BinaryIO, IO[bytes]]] = None - self.chunk_size_bytes = chunk_size_bytes - self.expected_size_bytes = expected_size_bytes - - def __iter__(self) -> Iterator[bytes]: - - if self.expected_size_bytes != -1: - file_size = os.stat(self.filename).st_size - assert ( - file_size == self.expected_size_bytes - ), f"FileChunkingIterator Failed, expecting file to be of size: {self.expected_size_bytes} but got {file_size}" - self.file = io.open(self.filename, "rb", buffering=0) - self.num_chunks_read = 0 - return self - - def __next__(self) -> bytes: - assert self.file - next_chunk = self.file.read(self.chunk_size_bytes) - - if len(next_chunk) == 0: - raise StopIteration - self.num_chunks_read += 1 - - return next_chunk - - -torch_saver = TorchSaver() diff --git a/fairscale/nn/data_parallel/__init__.py b/fairscale/nn/data_parallel/__init__.py index b8e930047..4b79662af 100644 --- a/fairscale/nn/data_parallel/__init__.py +++ b/fairscale/nn/data_parallel/__init__.py @@ -9,7 +9,6 @@ from .fully_sharded_data_parallel import ( FullyShardedDataParallel, - OffloadConfig, TrainingState, auto_wrap_bn, get_fsdp_instances, diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index f30dba800..ccb036539 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -5,13 +5,11 @@ import contextlib import copy -from dataclasses import dataclass from enum import Enum, auto import functools import logging from math import inf import os -import tempfile import time import traceback import typing @@ -69,15 +67,6 @@ else: enable_nccl_base_collectives = True -try: - import fairscale.experimental.nn.ssd_offload as ssd_offload - - import_ssd_offload = True -except ImportError: - # The latest nightly PyTorch version required - import_ssd_offload = False - pass - class TrainingState(Enum): """ @@ -107,19 +96,6 @@ class TrainingState(Enum): SUMMON_FULL_PARAMS = auto() -# Data classes containing FSDP parameter constructs - -# Offload config for specifying SSD options (initially at least) -@dataclass -class OffloadConfig: - """Class for specifying all arguments related to offloading parameters.""" - - # Offload type: currently only supports: "ssd_offload" - offload_type: Optional[str] = None - # Path to the directory for storing parameters offloaded to disk. - dir: Optional[str] = None - - class FullyShardedDataParallel(nn.Module): """ A wrapper for sharding Module parameters across data parallel workers. This @@ -302,10 +278,6 @@ class FullyShardedDataParallel(nn.Module): cpu_offload (bool, Optional): if ``True``, offload params to CPU. Note: This arg will be deprecated in favor of *``move_params_to_cpu``* in an upcoming release. - offload_config (OffloadConfig): - The `OffloadConfig` object is used to specify the type of offload (i.e SSD, CPU) and - other required knobs when offloading parameters from GPU. Currently the OffloadConfig - only supports specifying SSD offload as an option. Note: This is an experimental feature. state_dict_on_rank_0_only (bool): When set to ``True``, ``model.state_dict()`` will only returns full state dict on rank 0 and return empty dict non-rank 0, which allow FullyShardedDataParallel to @@ -342,7 +314,6 @@ def __init__( force_input_to_fp32: bool = False, verbose: bool = False, cpu_offload: bool = False, - offload_config: Optional[OffloadConfig] = None, state_dict_on_rank_0_only: bool = False, gradient_predivide_factor: Optional[float] = None, allow_reset_parameters: bool = False, @@ -414,12 +385,6 @@ def __init__( self.force_input_to_fp32 = force_input_to_fp32 self.verbose = verbose self.state_dict_on_rank_0_only = state_dict_on_rank_0_only - # Experimental feature for now. Use at your own risk. - self.ssd_offload = True if offload_config and offload_config.offload_type == "ssd_offload" else False - if self.ssd_offload and not import_ssd_offload: - raise ImportError( - f"Trying to enable ssd_offload when it was not successfully imported (likely due to old torch version, current = {torch.__version__})" - ) self.gradient_predivide_factor: float = gradient_predivide_factor or self._get_gradient_predivide_factor( self.world_size @@ -433,9 +398,6 @@ def __init__( if self.fp32_reduce_scatter and not self.mixed_precision: raise ValueError("fp32_reduce_scatter requires mixed_precision=True") - if self.ssd_offload and not self.flatten_parameters: - raise ValueError(f"offload type: '{offload_config.offload_type}' requires flatten_parameters=True") - # skip validation if the process group was created above if process_group: validate_process_group(self.compute_device, self.process_group) @@ -456,16 +418,7 @@ def __init__( self._has_params = len(params) > 0 self._has_shared_params = False - # TODO(anj): Should we conditionally do this only if we have params? - # TODO(anj): Figure out if we can allocate the buffer during sharding. self.buffer_size = sum(p.numel() for p in params) - self.ssd_directory = tempfile.gettempdir() - if self.ssd_offload: - assert import_ssd_offload, "We need to import ssd_offload.py to enable the `ssd_offload` feature." - if offload_config and offload_config.dir: - self.ssd_directory = offload_config.dir - self.move_grads_to_cpu = True - self.move_params_to_cpu = True # For now, it is either all flatten or none flatten. This will be extended to # multiple flatten groups in my next PR. @@ -478,9 +431,7 @@ def __init__( param_name_groups = [param_names] del param_names - self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper( - module, param_list=to_be_flatten_params, ssd_offload=self.ssd_offload, ssd_directory=self.ssd_directory - ) + self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=to_be_flatten_params) del module # free original module in case it helps garbage collection # Now, in this FSDP wrapper class, we keep a list of to-be-flatten and not-to-be-flatten @@ -531,8 +482,6 @@ def __init__( # Flag to indicate whether state_dict() should automatically summon the # full params. This defaults to True, but may be set to False if the # user explicitly requests the local state dict via local_state_dict(). - # TODO(anj): This should by default be set to False for ssd_offload=True - # unless we are in the summon_full_params context. self._return_full_state_dict = True init_end = time.time() @@ -544,11 +493,6 @@ def __init__( # This is reset at the end of the backward pass. self._pre_backward_hook_has_run = False - # Free all params at the end of initialization. - if self.ssd_offload: - for m in get_fsdp_instances(self): - m._free_ssd_offload() - def _get_gradient_predivide_factor(self, world_size: int) -> float: factor: int = 1 while world_size % factor == 0 and world_size / factor > factor: @@ -785,10 +729,9 @@ def _shard_parameters_(self) -> None: p._orig_size = p.data.size() if not p._is_sharded: - if not self.ssd_offload: - p._is_sharded = False - self.numel_padded_per_param.append(0) - continue + p._is_sharded = False + self.numel_padded_per_param.append(0) + continue p._is_sharded = True # TODO (Min): broadcast from rank 0 to avoid each rank need to init with the same seed? @@ -797,11 +740,7 @@ def _shard_parameters_(self) -> None: p.data, num_padded = self._get_shard(p.data) self.numel_padded_per_param.append(num_padded) - if self.ssd_offload: - assert isinstance(p, ssd_offload.SsdParameter) - p.to_file() - else: - free_storage_(orig_data) + free_storage_(orig_data) assert len(self.numel_padded_per_param) == len(self.params) @@ -1014,21 +953,11 @@ def _no_return_full_state_dict(self) -> Generator: backup = self._return_full_state_dict self._return_full_state_dict = False - if self.ssd_offload: - # Move params from disk to memory before returning the local state dict. - self._move_params_to_memory() - try: yield finally: self._return_full_state_dict = backup - def _move_params_to_memory(self) -> None: - """Move params from disk to CPU.""" - for p in self.params: - assert isinstance(p, ssd_offload.SsdParameter) - p.to_tensor() - def _load_state_dict( self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True ) -> NamedTuple: @@ -1276,22 +1205,11 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge # Copy any changes made to the full params back into # the corresponding local shards. local_shard, _ = self._get_shard(full_tensor) - if self.ssd_offload: - assert isinstance(p, ssd_offload.SsdParameter) - self._ssd_offload_reset_param_device(p) - p.point_to_tensor(local_shard.view_as(p._fp32_shard).cpu()) - else: - p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard)) + p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard)) if safe_to_free: free_storage_(full_tensor) self.has_full_params = False - if self.ssd_offload: - # Store tensors in the SSD buffer and free param storage. - for p in self.params: - assert isinstance(p, ssd_offload.SsdParameter) - p.to_file() - else: - self._use_fp32_param_shard() + self._use_fp32_param_shard() self.training_state = TrainingState.IDLE def _reset_lazy_init(self) -> None: @@ -1366,11 +1284,6 @@ def _init_param_attributes(self, p: Parameter) -> None: return # A single shard of the parameters in full precision. - # TODO(another-pjohnson) - I believe this will cause memory leakage with ssd - # p.data returns a pointer to a handle, and that handle has it's - # ref count incremented by p._fp32_shard. So this tensor will - # never be freed even if we do p.to_disk(). investigate after - # PR #887 is merged p._fp32_shard = p.data if self.mixed_precision: @@ -1378,14 +1291,11 @@ def _init_param_attributes(self, p: Parameter) -> None: if self.move_params_to_cpu: assert p._fp32_shard.device == torch.device("cpu"), self - # We don't pin memory when using ssd_offload since that results in OOM when - # the memory requirements of a model are larger than host memory. - if not self.ssd_offload: - # If we plan to keep the FP32 parameters on CPU, then pinning - # memory allows us to later use non-blocking transfers when moving - # the FP32 param shard to compute_device. - p._fp32_shard = p._fp32_shard.pin_memory() - p.data = p._fp32_shard + # If we plan to keep the FP32 parameters on CPU, then pinning + # memory allows us to later use non-blocking transfers when moving + # the FP32 param shard to compute_device. + p._fp32_shard = p._fp32_shard.pin_memory() + p.data = p._fp32_shard if self.move_params_to_cpu or self.mixed_precision: @@ -1423,16 +1333,7 @@ def _init_param_attributes(self, p: Parameter) -> None: # pass. In this case, it's important to pre-allocate the CPU grad # shard in pinned memory so that we can do a non-blocking transfer. # This is only needed during training and not evaluation. - if self.ssd_offload: - assert isinstance(p, ssd_offload.SsdParameter) - # Gradients also need to be offloaded to SSD otherwise it can result in - # OOMs when the memory requirements of a model are larger than host memory. - p._cpu_grad = ssd_offload.SsdTensorHandle.from_tensor(torch.zeros_like(p.data, device="cpu")) - p._cpu_grad.allow_unsafe_changes = True - p._cpu_grad.set_file_params(p.filename + "_grad", 0) - p._cpu_grad.to_file() - else: - p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory() + p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory() def _set_is_root(self) -> None: """If ``True``, implies that no other :class:`FullyShardedDataParallel` @@ -1576,17 +1477,8 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: if self.clear_autocast_cache: torch.clear_autocast_cache() - self._free_ssd_offload() - return outputs - @torch.no_grad() - def _free_ssd_offload(self) -> None: - if self.ssd_offload: - for p in self.params: - assert isinstance(p, ssd_offload.SsdParameter) - p.to_file(permit_when_tensor_none=True) - def _register_pre_backward_hooks(self, outputs: Any) -> Any: """Register pre-backward hook to run before the wrapped module's backward. Hooks should be attached to all outputs from the forward. @@ -1990,7 +1882,6 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None: # Update root and nested FSDP's hooks and flags. for m in get_fsdp_instances(self): _finalize_parameters(m) - m._free_ssd_offload() m._pre_backward_hook_has_run = False if any(p.requires_grad for p in m.parameters()): # Check if the module has params and if any of them has @@ -2071,15 +1962,6 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None: # Trim any padding and reshape to match original size. p.data = p.data[: p._orig_size.numel()].view(p._orig_size) - if self.ssd_offload: - for p in self.params: - assert isinstance(p, ssd_offload.SsdParameter) - if not p.is_available(): - self._ssd_offload_reset_param_device(p) - p.to_tensor() - - self.has_full_params = False - if self._has_shared_params: # self.has_full_params flag can be out of sync if a shared param is # sharded by another FSDP instance. An example is that in eval case @@ -2366,25 +2248,13 @@ def consolidate_shard_weights( return consolidated_weights - @torch.no_grad() - def _ssd_offload_reset_param_device(self, param: Parameter) -> None: - assert isinstance(param, ssd_offload.SsdParameter) - if param.device != torch.device("cpu"): - param.data = param._fp32_shard - param.tensor = None - @torch.no_grad() def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None: """Use FP32 shard for a list of params.""" if params is None: params = self.params for p in params: - if import_ssd_offload and self.ssd_offload: - assert isinstance(p, ssd_offload.SsdParameter) - self._ssd_offload_reset_param_device(p) - p.to_tensor() - else: - p.data = p._fp32_shard + p.data = p._fp32_shard @torch.no_grad() def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None: @@ -2395,14 +2265,11 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = No for p in params: assert p._fp16_shard is not None alloc_storage_(p._fp16_shard, size=p._fp32_shard.size()) - if self.ssd_offload: - p._fp16_shard.copy_(p.to(p._fp16_shard.device, non_blocking=True)) - else: - p._fp16_shard.copy_( - # If move_params_to_cpu is True, this will be non-blocking - # because _fp32_shard is pinned, otherwise it's a no-op. - p._fp32_shard.to(p._fp16_shard.device, non_blocking=True) - ) + p._fp16_shard.copy_( + # If move_params_to_cpu is True, this will be non-blocking + # because _fp32_shard is pinned, otherwise it's a no-op. + p._fp32_shard.to(p._fp16_shard.device, non_blocking=True) + ) p.data = p._fp16_shard torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"]) diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py index 065e149c3..69d332491 100644 --- a/fairscale/nn/misc/flatten_params_wrapper.py +++ b/fairscale/nn/misc/flatten_params_wrapper.py @@ -8,7 +8,6 @@ from contextlib import contextmanager from itertools import chain -import tempfile import typing from typing import ( TYPE_CHECKING, @@ -31,19 +30,6 @@ from torch import Tensor import torch.nn as nn -try: - from fairscale.experimental.nn.ssd_offload import ( - SsdFlatParameter, - SsdFlatParameterView, - SsdFlatParameterViewProperty, - _register_property, - ) - - import_ssd_offload = True -except ImportError: - import_ssd_offload = False - pass - from fairscale.internal.state_dict import replace_by_prefix_ if TYPE_CHECKING: @@ -169,15 +155,8 @@ def __init__( module: nn.Module, param_list: ParamGroups = None, flat_param_names: Optional[List[str]] = None, - ssd_offload: bool = False, - ssd_directory: str = "", ): super().__init__() - if ssd_offload and not import_ssd_offload: - raise ImportError( - f"Trying to enable ssd_offload when it was not successfully imported (likely due to old torch version, current = {torch.__version__})" - ) - self.ssd_offload = ssd_offload self._fpw_module = module self.is_flattened = False @@ -239,14 +218,7 @@ def __init__( # Init all flat_params. for new_p_set in self._param_sets: params, param_infos, shared_param_infos = self._init_flatten_params(new_p_set) - if ssd_offload: - assert ssd_directory != "" - (handle, fname) = tempfile.mkstemp(dir=ssd_directory, suffix="ssd_buf_param") - flat_param = SsdFlatParameter.from_tensors(tensors=params) - flat_param.allow_unsafe_changes = True - flat_param.set_file_params(fname, 0) - else: - flat_param = FlatParameter(params, params[0].requires_grad) + flat_param = FlatParameter(params, params[0].requires_grad) flat_param._param_infos = param_infos flat_param._shared_param_infos = shared_param_infos self.flat_params.append(flat_param) @@ -393,13 +365,8 @@ def _unflatten_params_as_views(self) -> None: ps = self.get_param_views() param_views = [] for (_, m, n), p in zip(self._param_infos, ps): - if self.ssd_offload: - assert isinstance(p, SsdFlatParameterView) - _register_property(m, n, SsdFlatParameterViewProperty(p.parent, p.id)) - - else: - setattr(m, n, p) # This will set as plain attr - param_views.append(p) + setattr(m, n, p) # This will set as plain attr + param_views.append(p) # Save param views for easy access if anyone still wants to access # parameters of the module. diff --git a/tests/ci_test_list_2.txt b/tests/ci_test_list_2.txt index 4e0876a5b..a3ccc7b07 100644 --- a/tests/ci_test_list_2.txt +++ b/tests/ci_test_list_2.txt @@ -6,7 +6,6 @@ tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py tests/experimental/nn/test_offload.py tests/experimental/nn/test_auto_shard.py tests/experimental/optim/test_dynamic_loss_scaler.py -tests/experimental/nn/test_ssd_offload.py tests/nn/data_parallel/test_fsdp_shared_weights_mevo.py tests/nn/data_parallel/test_fsdp_shared_weights.py tests/nn/data_parallel/test_fsdp_pre_backward_hook.py @@ -50,5 +49,4 @@ tests/nn/pipe/test_dependency.py tests/nn/pipe/test_stream.py tests/nn/moe/test_moe_layer.py tests/nn/moe/test_top2gating.py -tests/nn/data_parallel/test_fsdp_offload.py tests/nn/data_parallel/test_fsdp_fwd_fwd_bwd_bwd.py diff --git a/tests/experimental/nn/test_ssd_offload.py b/tests/experimental/nn/test_ssd_offload.py deleted file mode 100644 index 2a8bf9ebf..000000000 --- a/tests/experimental/nn/test_ssd_offload.py +++ /dev/null @@ -1,439 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -""" -Testing SsdFlatParameter and SsdTensorHandle modules. -""" - -import filecmp -import functools -import os -import tempfile - -import numpy as np -import pytest -import torch - -pytestmark = pytest.mark.skip(reason="ssd offload to be removed to simplify the code") - -try: - import fairscale.experimental.nn.ssd_offload as so -except ImportError as ie: - # Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release. - pytestmark = pytest.mark.skipif(True, reason=ie.msg) - pass - - -def _init(): - torch.manual_seed(0) - np.random.seed(0) - - -def test_write_read(): - _init() - - with tempfile.NamedTemporaryFile() as f: - ref_tensor = torch.rand(128, dtype=torch.float32) - test_tensor = torch.zeros_like(ref_tensor) - assert not torch.equal(ref_tensor, test_tensor) - so.write(ref_tensor, f.name) - so.read(test_tensor, f.name) - assert torch.equal(ref_tensor, test_tensor) - - -def test_ssd_handle_dispatch_fwd(): - _init() - - with tempfile.NamedTemporaryFile() as f: - orig_tensor = torch.randn(128) - ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor) - ssd_handle.set_file_params(f.name, 0) - ssd_handle.to_file(release_tensor_after_write=True) - - assert torch.equal(ssd_handle.to_tensor(), orig_tensor) - - # This should trigger the torch_dispatch code and write - # back the results to the file - ssd_handle.add_(1) - plus1_tensor = orig_tensor.add(1) - assert torch.equal(ssd_handle.to_tensor(), plus1_tensor) - - -def test_ssd_handle_dispatch_bwd(): - _init() - - with tempfile.NamedTemporaryFile() as f: - orig_tensor = torch.randn((4, 4), requires_grad=True) - orig_copy = orig_tensor.clone().detach().requires_grad_(True) - ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor) - ssd_handle.set_file_params(f.name, 0) - ssd_handle.to_file(release_tensor_after_write=True) - - assert torch.equal(ssd_handle.to_tensor(), orig_tensor) - - y1 = ssd_handle + 1 - y2 = orig_copy + 1 - y1.sum().backward() - y2.sum().backward() - - assert torch.equal(ssd_handle.grad, orig_copy.grad) - - -@pytest.mark.skip("broken at head") -def test_ssd_handle_dispatch_bwd_hook(): - _init() - - def post_backward_hook(name, grad): - print(f"BACKWARD HOOK for tensor {name} CALLED") - - with tempfile.NamedTemporaryFile() as f: - orig_tensor = torch.randn((4, 4), requires_grad=True) - orig_copy = orig_tensor.clone().detach().requires_grad_(True) - ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor) - ssd_handle.set_file_params(f.name, 0) - ssd_handle.to_file(release_tensor_after_write=True) - one = torch.ones(1, requires_grad=True).cuda() - - orig_copy = ssd_handle.data - cuda_copy = ssd_handle.to("cuda").detach().requires_grad_(True) - ssd_handle.data = cuda_copy - - ssd_handle.register_hook(functools.partial(post_backward_hook, "ssd_handle")) - one.register_hook(functools.partial(post_backward_hook, "one")) - - y1 = ssd_handle + one - y1.sum().backward() - - -def test_ssd_handle_train_simple(): - _init() - - with tempfile.NamedTemporaryFile() as f: - orig_tensor = torch.randn((4, 4), requires_grad=True) - - with torch.no_grad(): - orig_copy = torch.empty_like(orig_tensor) - orig_copy.copy_(orig_tensor) - orig_copy.requires_grad = True - - ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor) - ssd_handle.flush_on_dirty = False - ssd_handle.set_file_params(f.name, 0) - ssd_handle.to_file(release_tensor_after_write=True) - - assert torch.equal(ssd_handle.to_tensor(), orig_tensor) - optimizer_ssd = torch.optim.SGD([ssd_handle], lr=0.1) - optimizer_orig = torch.optim.SGD([orig_copy], lr=0.1) - - y1 = ssd_handle + 1 - optimizer_ssd.zero_grad() - y1.sum().backward() - assert ssd_handle.storage_state is so.StorageState.ON_CPU_CLEAN - optimizer_ssd.step() - assert ssd_handle.storage_state is so.StorageState.ON_CPU_DIRTY - - y2 = orig_copy + 1 - optimizer_orig.zero_grad() - y2.sum().backward() - optimizer_orig.step() - - assert torch.equal(ssd_handle.to_tensor(), orig_copy) - - -def test_torch_save_load_ssd_flat_param_on_disk(): - _init() - orig_file = tempfile.NamedTemporaryFile(prefix="tensor") - checkpoint_file = tempfile.NamedTemporaryFile(prefix="checkpoint", suffix=".pt") - checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir") - - # TENSOR_SHAPE = (1024, 1024, 2048) - # use smaller shape for unit tests - TENSOR_SHAPE = (1024, 321) - ref_tensors = [torch.rand(TENSOR_SHAPE, dtype=torch.float32) for i in range(4)] - ssd_handle = so.SsdFlatParameter.from_tensors(ref_tensors, False) - ssd_handle.set_file_params(orig_file.name, 0) - ssd_handle.to_file() - ref_tensors = [] - - # after deleting ref_tensor, memory usage should be very low - # For save it shouldn't be more than 10x so.DEFAULT_CHUNK_SIZE - with so.CheckpointPathContextManager(override_path=checkpoint_load_directory.name): - so.torch_saver.save(ssd_handle, checkpoint_file.name) - # below line saves file to checkpoint_load_directory/orig_file.name - # Memory usage here should be O(1000 * so.DEFAULT_CHUNK_SIZE) - # 1000x because that's how many elements the python unpickler - # will buffer before passing to the SsdTensor - test_ssd_handle = torch.load(checkpoint_file) - head, tail = os.path.split(orig_file.name) - assert filecmp.cmp(orig_file.name, os.path.join(checkpoint_load_directory.name, tail), shallow=False) - - -def test_torch_save_load_ssd_flat_param_on_mem(): - _init() - orig_file = tempfile.NamedTemporaryFile(prefix="tensor") - checkpoint_file = tempfile.NamedTemporaryFile(prefix="checkpoint", suffix=".pt") - checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir") - - # TENSOR_SHAPE = (1024, 1024, 2048) - # use smaller shape for unit tests - TENSOR_SHAPE = (1024, 321) - ref_tensors = [torch.rand(TENSOR_SHAPE, dtype=torch.float32) for i in range(4)] - ssd_handle = so.SsdFlatParameter.from_tensors(ref_tensors, False) - ssd_handle.set_file_params(orig_file.name, 0) - ref_tensors = [] - - # after deleting ref_tensor, memory usage should be very low - # For save it shouldn't be more than 10x so.DEFAULT_CHUNK_SIZE - with so.CheckpointPathContextManager(override_path=checkpoint_load_directory.name): - so.torch_saver.save(ssd_handle, checkpoint_file.name) - # below line saves file to checkpoint_load_directory/orig_file.name - # Memory usage here should be O(1000 * so.DEFAULT_CHUNK_SIZE) - # 1000x because that's how many elements the python unpickler - # will buffer before passing to the SsdTensor - test_ssd_handle = torch.load(checkpoint_file) - assert torch.equal(ssd_handle, test_ssd_handle) - - -def test_ssd_param_train_simple(): - _init() - with tempfile.NamedTemporaryFile() as f: - orig_tensor = torch.randn((4, 4)) - - with torch.no_grad(): - orig_copy = torch.empty_like(orig_tensor) - orig_copy.copy_(orig_tensor) - param = torch.nn.Parameter(orig_copy) - - ssd_param = so.SsdParameter(orig_tensor.shape, orig_tensor.dtype) - ssd_param.point_to_tensor(orig_copy) - ssd_param.flush_on_dirty = False - ssd_param.set_file_params(f.name, 0) - ssd_param.to_file(release_tensor_after_write=True) - - assert torch.equal(ssd_param.to_tensor(), orig_tensor) - optimizer_ssd = torch.optim.SGD([ssd_param], lr=0.1) - optimizer_orig = torch.optim.SGD([param], lr=0.1) - - y1 = ssd_param + 1 - optimizer_ssd.zero_grad() - y1.sum().backward() - # Test to see if Dirty is being calculated correctly when optimizer modifies - # ssd_param - assert ssd_param.storage_state is so.StorageState.ON_CPU_CLEAN - optimizer_ssd.step() - assert ssd_param.storage_state is so.StorageState.ON_CPU_DIRTY - - y2 = param + 1 - optimizer_orig.zero_grad() - y2.sum().backward() - optimizer_orig.step() - - assert torch.equal(ssd_param.to_tensor(), param) - - -def test_ssd_flat_parameter_basic(): - _init() - with tempfile.NamedTemporaryFile() as f: - refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32)) - refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32)) - refc_param = torch.nn.Parameter(torch.rand(128, dtype=torch.float32)) - ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False) - ssd_flat_param.set_file_params(f.name, 0) - - param_views = list(ssd_flat_param.get_param_views()) - - assert refa_param.shape == param_views[0].shape - assert refb_param.shape == param_views[1].shape - assert refc_param.shape == param_views[2].shape - - assert torch.equal(refa_param, param_views[0]) - assert torch.equal(refb_param, param_views[1]) - assert torch.equal(refc_param, param_views[2]) - ssd_flat_param.to_file() - - assert not ssd_flat_param.is_available() - first_value = param_views[0][0][0].item() - assert ssd_flat_param.is_available() - assert first_value == refa_param[0][0].item() - - -def test_ssd_flat_parameter_view_modify(): - _init() - with tempfile.NamedTemporaryFile() as f: - refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False) - refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False) - refc_param = torch.nn.Parameter(torch.rand(128, dtype=torch.float32), requires_grad=False) - ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False) - ssd_flat_param.set_file_params(f.name, 0) - ssd_flat_param.flush_on_dirty = False - - param_views = list(ssd_flat_param.get_param_views()) - - assert ssd_flat_param.storage_state == so.StorageState.ON_CPU_DIRTY - ssd_flat_param.to_file() - assert ssd_flat_param.storage_state == so.StorageState.ON_DISK - assert param_views[0].tensor is None - - param_views[0] += 0.1 - assert ssd_flat_param.storage_state == so.StorageState.ON_CPU_DIRTY - - -@pytest.mark.skip("broken at head") -def test_ssd_flat_parameter_view_bwd(): - _init() - - hooks_called = [] - - def post_backward_hook(name, hooks_called, *grads): - print(f"BACKWARD HOOK for tensor {name} CALLED") - hooks_called.append(name) - - with tempfile.NamedTemporaryFile() as f: - refa_param = ( - torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=True) - .to("cpu") - .detach() - .requires_grad_() - ) - refb_param = ( - torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=True) - .to("cpu") - .detach() - .requires_grad_() - ) - refc_param = ( - torch.nn.Parameter(torch.rand(128, dtype=torch.float32), requires_grad=True) - .to("cpu") - .detach() - .requires_grad_() - ) - ssd_flat_param = so.SsdFlatParameter.from_tensors( - [refa_param, refb_param, refc_param], direct_to_file=True, filename=f.name, offset=0 - ) - orig_copy = ssd_flat_param.data - cuda_copy = ssd_flat_param.to("cuda").detach().requires_grad_() - cpu_copy = ssd_flat_param.to("cpu").detach().requires_grad_() - - p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp. - assert p_tmp.grad_fn is not None - grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object. - grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_orig", hooks_called)) - - ssd_flat_param.data = cuda_copy - one = torch.ones(1, requires_grad=True, device=ssd_flat_param.device) - y1 = ssd_flat_param.views[0] + one - y2 = cuda_copy + 1 - - # ssd_flat_param.to_file() - # ssd_flat_param.data = orig_copy - - p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp. - assert p_tmp.grad_fn is not None - grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object. - grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_cuda", hooks_called)) - ssd_flat_param.views[0].register_hook( - functools.partial(post_backward_hook, "ssd_flat_param.views[0]", hooks_called) - ) - ssd_flat_param.register_hook(functools.partial(post_backward_hook, "ssd_flat_param", hooks_called)) - one.register_hook(functools.partial(post_backward_hook, "one", hooks_called)) - - y1.sum().backward() - y2.sum().backward() - - assert "GradAccumulation_cuda" in hooks_called - assert "ssd_flat_param.views[0]" in hooks_called - assert "ssd_flat_param" in hooks_called - assert "one" in hooks_called - - -@pytest.mark.skip("broken at head") -def test_ssd_flat_parameter_view_bwd_parameterization(): - _init() - - hooks_called = [] - - def post_backward_hook(name, hooks_called, *grads): - print(f"BACKWARD HOOK for tensor {name} CALLED") - hooks_called.append(name) - - with tempfile.NamedTemporaryFile() as f: - layer1 = torch.nn.Linear(32, 4, bias=False) - layer2 = torch.nn.Linear(32, 4, bias=False) - layer3 = torch.nn.Linear(128, 1, bias=False) - ssd_flat_param = so.SsdFlatParameter.from_tensors( - [layer1.weight, layer2.weight, layer3.weight], direct_to_file=False, filename=f.name, offset=0 - ) - torch.nn.utils.parametrize.register_parametrization( - layer1, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 0) - ) - torch.nn.utils.parametrize.register_parametrization( - layer2, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 1) - ) - torch.nn.utils.parametrize.register_parametrization( - layer3, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 2) - ) - - orig_copy = ssd_flat_param.data - cuda_copy = ssd_flat_param.to("cuda").detach().requires_grad_() - cpu_copy = ssd_flat_param.to("cpu").detach().requires_grad_() - - p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp. - assert p_tmp.grad_fn is not None - grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object. - grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_orig", hooks_called)) - - ssd_flat_param.to_file(release_tensor_after_write=False) - ssd_flat_param.data = cuda_copy - one = torch.ones(layer1.weight.shape, requires_grad=True, device=ssd_flat_param.device) - y1 = layer1.forward(one) - y2 = cuda_copy + 1 - - # ssd_flat_param.to_file() - # ssd_flat_param.data = orig_copy - - p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp. - assert p_tmp.grad_fn is not None - grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object. - grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_cuda", hooks_called)) - ssd_flat_param.views[0].register_hook( - functools.partial(post_backward_hook, "ssd_flat_param.views[0]", hooks_called) - ) - ssd_flat_param.register_hook(functools.partial(post_backward_hook, "ssd_flat_param", hooks_called)) - one.register_hook(functools.partial(post_backward_hook, "one", hooks_called)) - - y1.sum().backward() - y2.sum().backward() - - assert "GradAccumulation_cuda" in hooks_called - assert "ssd_flat_param.views[0]" in hooks_called - assert "ssd_flat_param" in hooks_called - assert "one" in hooks_called - - -def test_ssd_flat_parameter_direct_to_file(): - _init() - with tempfile.NamedTemporaryFile() as f: - refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32)) - refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32)) - refc_param = torch.nn.Parameter(torch.rand(128, dtype=torch.float32)) - ssd_flat_param = so.SsdFlatParameter.from_tensors( - [refa_param, refb_param, refc_param], direct_to_file=True, filename=f.name, offset=0 - ) - - param_views = list(ssd_flat_param.get_param_views()) - - assert refa_param.shape == param_views[0].shape - assert refb_param.shape == param_views[1].shape - assert refc_param.shape == param_views[2].shape - - assert torch.equal(refa_param, param_views[0]) - assert torch.equal(refb_param, param_views[1]) - assert torch.equal(refc_param, param_views[2]) - ssd_flat_param.to_file() - - assert not ssd_flat_param.is_available() - first_value = param_views[0][0][0].item() - assert ssd_flat_param.is_available() - assert first_value == refa_param[0][0].item() diff --git a/tests/nn/data_parallel/test_fsdp_offload.py b/tests/nn/data_parallel/test_fsdp_offload.py deleted file mode 100644 index a5b83b0eb..000000000 --- a/tests/nn/data_parallel/test_fsdp_offload.py +++ /dev/null @@ -1,517 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -import functools -import itertools -import sys -import tempfile -import time -import unittest - -from parameterized import parameterized -import pytest -import torch -from torch import nn -import torch.distributed - -pytestmark = pytest.mark.skip(reason="ssd offload to be removed to simplify the code") - -try: - import fairscale.experimental.nn.ssd_offload as so -except ImportError as ie: - # Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release. - pytestmark = pytest.mark.skipif(True, reason=ie.msg) - pass - -from fairscale.fair_dev.testing.testing import dist_init, spawn_for_all_world_sizes -from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper -from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState - -# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4 -# All helper functions called by spawn must be either @classmethod, @staticmethod - - -class DistributedTest(unittest.TestCase): - def setUp(self): - if not torch.cuda.is_available(): - raise unittest.SkipTest("CUDA not available, skipping test") - if sys.platform == "win32": - raise unittest.SkipTest("NCCL doesn't support Windows, skipping test") - if torch.cuda.device_count() < 2: - raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping") - - @staticmethod - def _eval_with_config(model, autocast): - model.eval() - model_device = torch.device("cuda") - with torch.cuda.amp.autocast(enabled=autocast): - # Inputs always cuda regardless of move_grads_cpu, or model.device - input = model.module.get_input(torch.device("cuda")) - output = model(*input) - loss = model.module.get_loss(input, output).to(model_device) - assert loss.dtype == torch.float32 - if isinstance(model, FullyShardedDataParallel): - model.assert_state(TrainingState.IDLE) - return loss.detach() - - @staticmethod - def _eval_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None): - model.eval() - # Inputs always cuda regardless of move_grads_cpu, or model.device - input = model.module.get_input(torch.device("cuda")) - - for _ in range(num_steps): - with torch.cuda.amp.autocast(enabled=autocast): - output = model(*input) - - @classmethod - def _test_identical_outputs_eval( - cls, - model_init_fn, - config, - rank, - group, - num_steps=2, - use_cuda=True, - lr=0.01, - ref_ddp_fn=None, - ): - if config.get("mixed_precision", False): - autocast = True - # Force the compute dtype to be torch.float32 so that we get - # identical results as PyTorch DDP when using autocast. Note that - # this will cause the all-gather to happen in FP32, which is slower - # than necessary in most cases. - config["compute_dtype"] = torch.float32 - else: - autocast = False - - # Establish reference behavior with PyTorch DDP (+ optionally autocast). - model = model_init_fn(group=group, wrapper_config=None).cuda() - if ref_ddp_fn is None: - model = nn.parallel.DistributedDataParallel( - model, device_ids=[rank], output_device=rank, process_group=group - ) - else: - model = ref_ddp_fn(model, group) - ref_loss = cls._eval_with_config(model, autocast) - ref_state_dict = model.module.state_dict() - if config.get("cpu_offload", False): - for k in ref_state_dict.keys(): - ref_state_dict[k] = ref_state_dict[k].cpu() - - # Confirm we get the same behavior using FullyShardedDataParallel. - if config.get("ssd_offload", False): - config["offload_config"] = OffloadConfig(offload_type="ssd_offload") - # ssd offload only supports flatten_params ATM - config["flatten_parameters"] = True - - del config["ssd_offload"] - model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config) - if not model.ssd_offload and not model.move_params_to_cpu: - if use_cuda: - model = model.cuda() - else: - assert next(model.parameters()).device == torch.device("cpu") - shard_loss = cls._eval_with_config(model, autocast) - - try: - torch.testing.assert_allclose(ref_loss, shard_loss) - except (AssertionError, RuntimeError) as e: - raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}") - if config.get("flatten_parameters", True): - metadata = model.local_metadata_dict() - assert isinstance(metadata, dict) - - -keys = ["reshard_after_forward", "mixed_precision", "nested_wrapping"] -CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))] - - -def rename_test(testcase_func, param_num, param): - return "%s_%s" % ( - testcase_func.__name__, - parameterized.to_safe_name(str(param.args)), - ) - - -class TestSsdMemory(DistributedTest): - def test_memory_benchmark(self): - - test_fn = functools.partial(self._test_memory_benchmark, config={}) - spawn_and_init(test_fn) - - @classmethod - def _test_memory_benchmark(self, rank, group, config): - time_keeper = TimeKeeper() - - SIZE = 8 * 8 - time_keeper.print_time("START", 1.0) - a = torch.empty(1) - b = a.cuda() - # wait for cuda to fully load - time.sleep(1) - time_keeper.print_time("INIT_CUDA", 1.0) - model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4) - time_keeper.print_time("CPU_MODEL", 1.0) - - with tempfile.TemporaryDirectory() as current_tempdir: - config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir) - - model = FullyShardedDataParallel(model, **config) - time_keeper.print_time("FSDP_MODEL", 1.0) - - self._eval_for_several_steps(model, 1, autocast=False) - time_keeper.print_time("EVAL") - - -class SimpleLinear(nn.Module): - def __init__(self, group, input_size, output_size, layers=1, **unused_kwargs): - super().__init__() - self.rank = group.rank() - self.world_size = group.size() - self.input_size = input_size - self.output_size = output_size - torch.manual_seed(0) # keep everything deterministic - seq_layers = [] - for i in range(layers): - seq_layers.append(nn.Linear(input_size, output_size, bias=False)) - self.module = nn.Sequential(*seq_layers) - self.bs = 2 - - def get_input(self, device): - torch.manual_seed(1 + self.rank) # keep everything deterministic - src = torch.rand((self.bs, self.input_size), device=device, dtype=torch.float32) - tgt = torch.rand((self.bs, self.input_size), device=device, dtype=torch.float32) - return (src, tgt) - - def forward(self, src_ids, tgt_ids): - param_devices = [p.device for p in self.module.parameters()] - - return self.module(src_ids) - - def get_loss(self, input, output): - _, tgt = input - - return nn.functional.binary_cross_entropy_with_logits(output, tgt) - - def run_backward(self, loss): - loss.backward() - - -KEYS = ["ssd_offload", "flatten_parameters", "mixed_precision", "move_params_to_cpu"] -CONFIG = [[dict(zip(KEYS, config))] for config in itertools.product([True, False], repeat=len(KEYS))] - - -class TimeKeeper: - def __init__(self): - self.start_time = time.time() - - def print_time(self, s: str, wait_time: float = 1.0): - cur_time = time.time() - print(f"@time: {cur_time - self.start_time:0.2f} {s}") - time.sleep(wait_time) - - -class TestModuleProperties(DistributedTest): - @parameterized.expand(CONFIG, name_func=rename_test) - def test_named_parameters(self, config): - - test_fn = functools.partial(self._test_named_params, config=config) - spawn_and_init(test_fn) - - @classmethod - def _test_named_params(self, rank, group, config): - # Get the named parameters before wrapping. - before_wrap_model = TransformerWithSharedParams(group) - before_wrap_params = before_wrap_model.named_parameters() - - with tempfile.TemporaryDirectory() as current_tempdir: - if config["ssd_offload"]: - config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir) - # ssd offload only supports flatten_params ATM - config["flatten_parameters"] = True - del config["ssd_offload"] - - model = FullyShardedDataParallel(before_wrap_model, **config) - print(f"model.ssd_offload {model.ssd_offload}") - if not model.ssd_offload and not model.move_params_to_cpu: - model = model.cuda() - - self._eval_with_config(model, autocast=config["mixed_precision"]) - - # Get the named parameters after wrapping to compare. - after_wrap_params = model.named_parameters() - - if not config.get("flatten_parameters", False): - for before_nm, after_nm in zip(before_wrap_params, after_wrap_params): - assert before_nm[0] == after_nm[0] - else: - named_params_flat = [p for p in after_wrap_params][0][0] - assert "flat_param_0" in named_params_flat - - after_wrap_params = model.named_parameters() - - for before_nm, after_nm_original in zip(before_wrap_params, after_wrap_params): - assert before_nm[0] == after_nm_original[0] - torch.testing.assert_allclose(before_nm[1].shape, after_nm_original[1].shape) - - -class TestSsdLoading(DistributedTest): - @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) - def test_ssd_offloading_eval(self, config): - - test_fn = functools.partial(self._test_ssd_offload_eval, config=config) - spawn_and_init(test_fn) - - @parameterized.expand(CONFIG, name_func=rename_test) - def test_transformer_parameterized(self, config): - - spawn_and_init(functools.partial(self._test_identical_outputs_eval, TransformerWithSharedParams, config)) - - @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) - def test_ssd_offloading_train_flatten_params_wrapper(self, config): - - test_fn = functools.partial(self._test_ssd_offloading_train_flatten_params_wrapper, config=config) - spawn_and_init(test_fn) - - @classmethod - def _test_ssd_offloading_train_flatten_params_wrapper(self, rank, group, config): - SIZE = 16 * 16 - LR = 0.01 - MOMENTUM = 0.1 - model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4) - - with tempfile.TemporaryDirectory() as current_tempdir: - config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir) - config["flatten_parameters"] = True - - nested_wrapping = config["nested_wrapping"] - del config["nested_wrapping"] - - if nested_wrapping: - model = FullyShardedDataParallel( - NestedWrappedModule(group, wrap_everything=True, wrapper_config=config) - ) - else: - model = FullyShardedDataParallel(model, **config) - model_device = torch.device("cuda") - model.train() - optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM) - - checkpoint_file = tempfile.NamedTemporaryFile() - checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir") - - pre_checkpoint_last_output = None - post_checkpoint_last_output = None - - ITERATIONS = 10 - - # Inputs always cuda regardless of move_grads_cpu, or model.device - with torch.cuda.amp.autocast(enabled=config.get("mixed_precision", False)): - for i in range(ITERATIONS): - optim.zero_grad() - input = model.get_input(torch.device("cuda")) - output = model(*input) - pre_checkpoint_last_output = output - """ - param_itr = iter(model.named_parameters()) - p_name, p_val = next(param_itr) - print(f"i={i} pre_checkpoint {p_name} = {p_val[0].item()}") - """ - loss = model.module.get_loss(input, output).to(model_device) - assert loss.dtype == torch.float32 - model.module.run_backward(loss) - optim.step() - if i == 0: - with so.CheckpointPathContextManager(override_path=checkpoint_load_directory.name): - # so.torch_saver.save({"model": model.state_dict(), "optim": optim.state_dict()}, checkpoint_file.name) - torch.save({"model": model.state_dict()}, checkpoint_file.name) - # reset momentum just after checkpoint save - optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM) - - checkpoint = torch.load(checkpoint_file.name) - model.load_state_dict(checkpoint["model"]) - # reset momentum just after checkpoint load - optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM) - # do more iterations after loading checkpoint - for i in range(ITERATIONS - 1): - optim.zero_grad() - input = model.get_input(torch.device("cuda")) - output = model(*input) - post_checkpoint_last_output = output - """ - param_itr = iter(model.named_parameters()) - p_name, p_val = next(param_itr) - print(f"i={i} post_checkpoint {p_name} = {p_val[0].item()}") - """ - loss = model.module.get_loss(input, output).to(model_device) - assert loss.dtype == torch.float32 - - model.module.run_backward(loss) - optim.step() - - # Verify output of checkpoint load + run is equal to original output - assert torch.equal(pre_checkpoint_last_output, post_checkpoint_last_output) - if isinstance(model, FullyShardedDataParallel): - model.assert_state(TrainingState.IDLE) - - @classmethod - def _test_ssd_offload_eval(self, rank, group, config): - model = TransformerWithSharedParams(group) - state_dict = model.state_dict() - - nested_wrapping = config["nested_wrapping"] - del config["nested_wrapping"] - config["flatten_parameters"] = True - - with tempfile.TemporaryDirectory() as current_tempdir: - config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir) - if nested_wrapping: - model = FullyShardedDataParallel( - NestedWrappedModule(group, wrap_everything=True, wrapper_config=config) - ) - else: - model = FullyShardedDataParallel(model, **config) - - self._eval_with_config(model, autocast=config["mixed_precision"]) - - # With SSD offload only local_state_dict will work. We can support global - # state dict if we think it is necessary. - state_dict = model.local_state_dict() - model.load_local_state_dict(state_dict) - - self._eval_with_config(model, config["mixed_precision"]) - - -class TransformerWithSharedParams(nn.Module): - def __init__(self, group, *unused_args, d_vocab=23, d_model=16, add_bn=True, **unused_kwargs): - super().__init__() - self.rank = group.rank() - self.world_size = group.size() - torch.manual_seed(0) # keep everything deterministic - assert d_vocab >= 12 # we use torch.arange(12) as input - self.embed_tokens = nn.Embedding(d_vocab, d_model) - self.transformer = nn.Transformer( - d_model=d_model, - num_encoder_layers=2, - num_decoder_layers=2, - dim_feedforward=8, - dropout=0.1, - ) - self.output_proj = nn.Linear(d_model, d_vocab) - - # share the embedding and output projection weights - self.output_proj.weight = self.embed_tokens.weight - self.register_buffer("vocab_bias", self.embed_tokens.weight.new_ones((d_model,))) - self.register_buffer("long_buffer", torch.zeros_like(self.vocab_bias, dtype=torch.long)) - - self.bs = 2 - self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity() - - def get_input(self, device): - torch.manual_seed(1 + self.rank) # keep everything deterministic - src = torch.arange(12, device=device).view(6, self.bs) # T x B - tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs) # T x B - return (src, tgt) - - def forward(self, src_ids, tgt_ids): - src = self.embed_tokens(src_ids) - src = src + self.vocab_bias + self.long_buffer.type_as(src) - tgt = self.embed_tokens(tgt_ids) - tgt = self.bn(tgt) - x = self.transformer(src, tgt) - return self.output_proj(x) - - def get_loss(self, input, output): - _, tgt = input - return nn.functional.cross_entropy(output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum") - - def run_backward(self, loss): - loss.backward() - - -class NestedWrappedModule(nn.Module): - def __init__(self, group, wrapper_config, wrap_everything=False, checkpoint=False): - super().__init__() - self.rank = group.rank() - self.world_size = group.size() - self.wrapper_config = wrapper_config - - def _maybe_wrap(layer): - if wrapper_config is not None: - return FullyShardedDataParallel(layer, group, **wrapper_config) - return layer - - torch.manual_seed(0) # keep everything deterministic - self.module = nn.Sequential( - nn.Linear(8, 4), - _maybe_wrap( - nn.Sequential( - _maybe_wrap(nn.Linear(4, 16)), - nn.Linear(16, 16), - ) - ), - _maybe_wrap(nn.Linear(16, 4)), - nn.Linear(4, 8), - ) - - # Wrap all modules triggers a corner case where root FSDP doesn't have any params. - # Test it with checkpoint_wrapper as well to validate final backward callback - # is queued correctly when root FSDP does not have any params and every layer is - # wrapped as FSDP(checkpoint(module)). - if wrap_everything: - if checkpoint: - self.module = nn.Sequential( - _maybe_wrap(checkpoint_wrapper(nn.Linear(8, 4))), - _maybe_wrap(checkpoint_wrapper(nn.Linear(4, 16))), - _maybe_wrap(checkpoint_wrapper(nn.Linear(16, 4))), - _maybe_wrap(checkpoint_wrapper(nn.Linear(4, 8))), - ) - else: - self.module = nn.Sequential( - _maybe_wrap(nn.Linear(8, 4)), - _maybe_wrap(nn.Linear(4, 16)), - _maybe_wrap(nn.Linear(16, 4)), - _maybe_wrap(nn.Linear(4, 8)), - ) - - def get_input(self, device): - torch.manual_seed(1 + self.rank) # keep everything deterministic - return (torch.rand(4, 8, device=device),) - - def forward(self, x): - return self.module(x) - - def get_loss(self, input, output): - loss = output.sum() - return loss - - def run_backward(self, loss): - loss.backward() - - -def spawn_and_init(fn, args=None, **spawn_kwargs): - if args is None: - args = () - - run_fn = functools.partial(init_and_run, fn, args) - - # Below 3 lines are to easily enable single-process debugging - # _, filename = tempfile.mkstemp() - # _, filename_rpc = tempfile.mkstemp() - # run_fn(0, 1, filename, filename_rpc) - - spawn_for_all_world_sizes(run_fn, **spawn_kwargs) - - -def init_and_run(fn, args, rank, world_size, filename, filename_rpc): - dist_init(rank, world_size, filename, filename_rpc) - group = torch.distributed.new_group() - fn(rank, group, *args) - - -if __name__ == "__main__": - unittest.main()