
[Compiled_autograd] Running nn.LayerNorm fails for torch.compile with compiled_autograd under DeepSpeed ZeRO-3 #140091

Open
yitingw1 opened this issue Nov 8, 2024 · 13 comments
Labels: module: compiled autograd · oncall: distributed · oncall: pt2 · triaged

yitingw1 commented Nov 8, 2024

🐛 Describe the bug

When running a simple model that includes torch.nn.LayerNorm with DeepSpeed ZeRO-3 under torch.compile and compiled_autograd, an error occurs:

site-packages/torch/_subclasses/fake_tensor.py:2017] RuntimeError: Attempting to broadcast a dimension of length 0 at -1! Mismatching argument at index 1 had torch.Size([0]); but expected shape should be broadcastable to [100, 120]

We first found this error in a BERT model running DeepSpeed ZeRO-3 with torch.compile and compiled_autograd.

  • DeepSpeed ZeRO-1/2 with torch.compile and compiled_autograd works fine.
  • DeepSpeed ZeRO-3 with torch.compile but without compiled_autograd works fine.
  • There are many graph breaks and recompiles with DeepSpeed ZeRO-3 under torch.compile, and ZeRO-3 partitions model parameters through hooks.
  • To simplify the issue, I made a small reproducer that isolates the failing op (torch.nn.LayerNorm).

Investigation
The error: "RuntimeError: Attempting to broadcast a dimension of length 0 at -1! Mismatching argument at index 1 had torch.Size([0]); but expected shape should be broadcastable to [100, 120]"
It occurs when compiled autograd tries to trace the backward graph.
It appears in the LayerNorm backward decomposition, which tries to broadcast weight_cast (torch.Size([0])) to grad_out_cast's shape ([100, 120]) and fails:

# from torch/_decomp/decompositions.py, native_layer_norm_backward
if weight_cast is not None:
    grad_x_hat = grad_out_cast * weight_cast

If we bypass the LayerNorm weight by setting nn.LayerNorm(120, eps=1e-12, elementwise_affine=False) instead of elementwise_affine=True in the file deepspeed_reproducer_cpu.py, the script runs fine.
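
For illustration, the same shape mismatch can be reproduced in plain eager PyTorch (a minimal sketch, not the DeepSpeed code path; the exact error text differs between eager and the fake-tensor/refs path, but the root cause is the same: a partitioned, zero-sized weight cannot broadcast against the [100, 120] gradient):

import torch

grad_out_cast = torch.randn(100, 120)   # shape of the incoming gradient in the repro
weight_cast = torch.empty(0)            # what the ZeRO-3-partitioned LayerNorm weight looks like

try:
    grad_x_hat = grad_out_cast * weight_cast   # same expression as in the decomposition
except RuntimeError as e:
    print(e)  # broadcast failure: torch.Size([0]) vs torch.Size([100, 120])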

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @chauhang @penguinwu @xmfan @yf225 @ezyang @albanD @gqchen @pearu @nikitaved @soulitzer @Varal7

Error logs

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home1/yitingw1/habana/deepspeed_demo/deepspeed_reproducer_cpu.py", line 83, in <module>
[rank0]:     model_engine.backward(loss)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2020, in backward
[rank0]:     self.optimizer.backward(loss, retain_graph=retain_graph)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 2250, in backward
[rank0]:     self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
[rank0]:     scaled_loss.backward(retain_graph=retain_graph)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/utils/_stats.py", line 21, in wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1308, in __torch_dispatch__
[rank0]:     return proxy_call(self, func, self.pre_dispatch, args, kwargs)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 906, in proxy_call
[rank0]:     out = func(*args, **kwargs)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/_ops.py", line 716, in __call__
[rank0]:     return self._op(*args, **kwargs)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/utils/_stats.py", line 21, in wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1238, in __torch_dispatch__
[rank0]:     return self.dispatch(func, types, args, kwargs)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1692, in dispatch
[rank0]:     return self._cached_dispatch_impl(func, types, args, kwargs)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1348, in _cached_dispatch_impl
[rank0]:     output = self._dispatch_impl(func, types, args, kwargs)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1943, in _dispatch_impl
[rank0]:     return decomposition_table[func](*args, **kwargs)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 1729, in native_layer_norm_backward
[rank0]:     grad_x_hat = grad_out_cast * weight_cast
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/utils/_stats.py", line 21, in wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1238, in __torch_dispatch__
[rank0]:     return self.dispatch(func, types, args, kwargs)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1692, in dispatch
[rank0]:     return self._cached_dispatch_impl(func, types, args, kwargs)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1339, in _cached_dispatch_impl
[rank0]:     output = self._dispatch_impl(func, types, args, kwargs)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2013, in _dispatch_impl
[rank0]:     r = func(*args, **kwargs)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/_ops.py", line 716, in __call__
[rank0]:     return self._op(*args, **kwargs)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 273, in _fn
[rank0]:     result = fn(*args, **kwargs)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 141, in _fn
[rank0]:     result = fn(**bound.arguments)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/_refs/__init__.py", line 1049, in _ref
[rank0]:     a, b = _maybe_broadcast(a, b)
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/_refs/__init__.py", line 422, in _maybe_broadcast
[rank0]:     common_shape = _broadcast_shapes(
[rank0]:   File "/home1/yitingw1/mambaforge/envs/wyt_pt/lib/python3.10/site-packages/torch/_refs/__init__.py", line 411, in _broadcast_shapes
[rank0]:     raise RuntimeError(
[rank0]: RuntimeError: Attempting to broadcast a dimension of length 0 at -1! Mismatching argument at index 1 had torch.Size([0]); but expected shape should be broadcastable to [100, 120]

Minified repro

Running script:
TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 TORCH_LOGS="+dynamo,graph,graph_code,graph_breaks,recompiles,aot_graphs,aot_joint_graph,compiled_autograd_verbose" deepspeed --num_nodes 1 --num_gpus 1 deepspeed_reproducer_cpu.py

Below is deepspeed_reproducer_cpu.py

import torch
import torchvision
import torchvision.transforms as transforms
import torch.distributed as dist
import deepspeed
from deepspeed.accelerator import get_accelerator
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(32 * 32 * 3, 120)
        self.fc2 = nn.Linear(120, 10)
        self.LayerNorm1 = nn.LayerNorm(120, eps=1e-12, elementwise_affine=True)

    def forward(self, x):
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = self.LayerNorm1(x)
        x = self.fc2(x)
        return x

compile_kwargs = {"dynamic": False}
device = torch.device('cpu')

model = Net()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model_engine, optimizer, *_ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    optimizer=optimizer,
    config="./deepspeed_config.json",
)
# torch_compile
model_engine.compile(
    compile_kwargs=compile_kwargs,
)

# dataset
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
batch_size = 100
trainset = torchvision.datasets.CIFAR10(
    root="./DATA/CIFAR10", train=True, download=True, transform=transform
)
# process dataset
trainloader = DataLoader(
    trainset,
    batch_size=batch_size,
    sampler=DistributedSampler(trainset, shuffle=True),
    num_workers=16,
    pin_memory=True,
)
progress_bar = tqdm(
    total=len(trainloader),
    desc=f"Training 1/1 epoch",
    position=0,
    leave=True,
    disable=dist.is_initialized() and dist.get_rank() != 0,
)
for epoch in range(100):
    with torch._dynamo.compiled_autograd.enable(
                torch.compile(backend=get_accelerator().get_compile_backend(), **compile_kwargs)):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            
            # forward + backward + optimize
            outputs = model_engine(inputs)
            loss = criterion(outputs, labels)
            model_engine.backward(loss)
            model_engine.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
                running_loss = 0.0
            progress_bar.update(1)
print("Finished Training")

Below is deepspeed_config.json

{
    "train_batch_size": 32, 
    "optimizer": {
        "type": "SGD",
        "params": {
            "lr": 0.001,
            "momentum": 0.9
        }
    },
    "zero_allow_untested_optimizer": true,
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": false,
        "reduce_scatter" : false,
        "contiguous_gradients" : false
    }
}

Versions

Collecting environment information...
PyTorch version: 2.5.1+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.27.4
Libc version: glibc-2.35

Python version: 3.10.0 | packaged by conda-forge | (default, Nov 20 2021, 02:24:10) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-102-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] torch==2.5.1+cpu
[pip3] torchaudio==2.5.1+cpu
[pip3] torchvision==0.20.1+cpu
[pip3] deepspeed==0.15.3

bdhirsh added the module: compiled autograd label Nov 8, 2024

xmfan commented Nov 9, 2024

Thanks for the detailed write-up and small repro. DeepSpeed seems to be doing parameter offloading somewhere (I suspect it is implemented in a backward hook), and that is not compatible with our current implementation. I don't have a timeline for the fix; it requires a design for how offloading and torch.compile should work together.

So the issue is that compiled autograd is inlining everything but hooks. My theory of what's happening is:

  • compiled autograd sees the hook that should be loading the parameter back into memory, but it doesn't execute it
  • compiled autograd then sees the LayerNorm, executes it, and errors out because the parameter is size 0 (still offloaded)

Splitting this into 2 parts:

  1. Compiled autograd shouldn't inline into some nodes and not others; we want to achieve this by end of month.

With this change, I expect it to no longer error, but to graph break on each hook that tries to load the parameter into memory.

  2. To avoid the graph breaks, one way could be through PT2 Custom Ops, providing a meta implementation to tell the compiler about what shape will be loaded in (see the sketch after this comment).

@zou3519 wdyt about using custom ops to support traceable offloading
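
For illustration, a minimal, hypothetical sketch of that custom-op idea (the op name zero3_sketch::gather_param and its body are placeholders, not DeepSpeed or PyTorch offloading APIs; it only shows how a fake/meta implementation can tell the compiler the full shape that will be materialized):

import torch
from typing import Sequence

@torch.library.custom_op("zero3_sketch::gather_param", mutates_args=())
def gather_param(partitioned: torch.Tensor, full_shape: Sequence[int]) -> torch.Tensor:
    # Eager implementation: this is where the real reload/all-gather would go.
    # The sketch just materializes a tensor of the full shape.
    return partitioned.new_empty(tuple(full_shape))

@gather_param.register_fake
def _(partitioned: torch.Tensor, full_shape: Sequence[int]) -> torch.Tensor:
    # Fake/meta implementation: even though `partitioned` is currently size 0,
    # the compiler learns the shape that will be loaded back in.
    return partitioned.new_empty(tuple(full_shape))

# e.g. inside a traced region: weight = gather_param(partitioned_weight, [120])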

yitingw1 (Author) commented:

Thanks @xmfan, that's a reasonable explanation. Yes, DeepSpeed ZeRO-3 does parameter-related processing in some backward hooks; I'll focus on investigating this part.
In this case, nn.Linear seems to work well with DeepSpeed ZeRO-3 and compiled autograd, while nn.LayerNorm does not. I'll try to find out what causes the difference; or do you have any ideas?
By the way, could you let me know when the first part (about compiled autograd inlining) lands, or point me to the PR? Thanks a lot!


zou3519 commented Nov 11, 2024

to avoid the graph breaks, one way could be through PT2 Custom Ops, and providing a meta implementation to tell the compiler about what shape will be loaded in
@zou3519 wdyt about using custom ops to support traceable offloading

It sounds feasible. The thing to be careful about here is that we (torch.compile) aren't good at tracing through set_/resize_, so I hope that DeepSpeed isn't using set_/resize_ to do the offloading. Hiding these in a custom op doesn't work all the time.

shunting314 added the triaged label Nov 15, 2024
yitingw1 (Author) commented:

Hi, @xmfan @zou3519, the PR about compiled autograd inlining is #139098, right?


yitingw1 commented Nov 20, 2024

One more question, about my observation on bypassing the LayerNorm weight when running DeepSpeed ZeRO-3:

If we bypass the LayerNorm weight by setting nn.LayerNorm(120, eps=1e-12, elementwise_affine=False) instead of elementwise_affine=True in the file deepspeed_reproducer_cpu.py, the script runs fine.

That conclusion holds when using PT 2.5.0/2.5.1 for this small reproducer.

But on PT 2.4.1, another error happens after the compiled autograd graph is generated, while that graph is being traced through Dynamo. And when running BERT with DeepSpeed ZeRO-3 instead of this reproducer, the error still occurs on PT 2.5.1/2.5/2.4.1.

On PT 2.4.1, the error happens while Dynamo traces the second grad hook (call_hook_1). When tracing the second grad hook, it checks self.params_in_ipg_bucket[0], the first param that was put into self.params_in_ipg_bucket while tracing the first grad hook (it is kept there for a later allreduce). The error occurs because Dynamo finds that the size of this first param is (0,), which mismatches its true size. This first param is the weight of self.fc2 = nn.Linear(120, 10), which belongs to the AddmmBackward0 node in the compiled autograd graph.

Do you have any idea about this? Is this the same error mechanism? Thanks!

yf225 added the oncall: distributed label Nov 20, 2024

xmfan commented Nov 21, 2024

The error occurs because Dynamo finds that the size of the first param is (0,), which mismatches its true size.

This is the same issue, but a different part of the stack. The previous error was about DeepSpeed offload not being compatible with compiled autograd; this one is about DeepSpeed offload not being compatible with torch.compile.

Are you saying that this is already fixed in 2.5? The safest workaround is to graph break the code that's loading the params into memory again. By chance, do you know where those are implemented in deepspeed?


yitingw1 commented Nov 22, 2024

Are you saying that this is already fixed in 2.5?

I don't think so. The error disappears with PT 2.5.0/2.5.1 for this small reproducer, but for the real workload the error still occurs when running BERT with DeepSpeed ZeRO-3 on PT 2.5.0/2.5.1.

The safest workaround is to graph break the code that's loading the params into memory again.

If I suppress this exception and fall back to eager by setting torch._dynamo.config.suppress_errors = True, this error disappears for both this reproducer and the real BERT workload.
And if I comment out the _post_backward_module_hook that partitions params in DeepSpeed, DeepSpeed no longer releases params after backward and the error also disappears for both this reproducer and the real BERT workload.


yitingw1 commented Nov 22, 2024

By chance, do you know where those are implemented in deepspeed?

The grad-related hooks (for allreduce) are in deepspeed/runtime/zero/stage3.py:

class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
    ...
    def create_reduce_and_remove_grad_hooks(self):
        ...
        def wrapper(param):
            param_tmp = param.expand_as(param)
            grad_acc = param_tmp.grad_fn.next_functions[0][0]

            @instrument_w_nvtx
            def reduce_partition_and_remove_grads(*notneeded):
                self.reduce_ready_partitions_and_remove_grads(param)

            self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads))
            self.grad_accs.append(grad_acc)

    def reduce_ready_partitions_and_remove_grads(self, param):
        self.reduce_independent_p_g_buckets_and_remove_grads(param)

    ############### Independent Partition Gradient ########################
    def reduce_independent_p_g_buckets_and_remove_grads(self, param):
        if self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size and self.elements_in_ipg_bucket > 0:
            self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.ds_numel)
            self.__reduce_and_partition_ipg_grads()

        self.__add_grad_to_ipg_bucket(param)

    @property
    def elements_in_ipg_bucket(self):
        return sum(p.ds_numel for p in self.params_in_ipg_bucket)

The error happens while Dynamo traces the generated compiled autograd graph, specifically while executing reduce_independent_p_g_buckets_and_remove_grads() for the second grad hook. It evaluates the if-statement (self.elements_in_ipg_bucket), and since this is the second grad hook, the first param is already in self.params_in_ipg_bucket; but that param has been partitioned, so its size is (0,). So when tracing return sum(p.ds_numel for p in self.params_in_ipg_bucket), the error occurs:

[rank0]:V1110 23:55:36.202000 140682076239680 torch/_dynamo/output_graph.py:2033] [45/0] create_graph_input L_hooks_7_closure_0_cell_contents_closure_1_cell_contents_params_in_ipg_bucket_0_ L['hooks'][7].__closure__[0].cell_contents.__closure__[1].cell_contents.params_in_ipg_bucket[0]
[rank0]:V1110 23:55:36.202000 140682076239680 torch/_dynamo/variables/builder.py:2268] [45/0] wrap_to_fake L['hooks'][7].__closure__[0].cell_contents.__closure__[1].cell_contents.params_in_ipg_bucket[0] (0,) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>], constraint_sizes=[None], view_base_context=None, tensor_source=GetItemSource(base=AttrSource(base=AttrSource(base=GetItemSource(base=AttrSource(base=AttrSource(base=GetItemSource(base=AttrSource(base=GetItemSource(base=LocalSource(local_name='hooks', cell_or_freevar=False), index=7, index_is_slice=False), member='__closure__'), index=0, index_is_slice=False), member='cell_contents'), member='__closure__'), index=1, index_is_slice=False), member='cell_contents'), member='params_in_ipg_bucket'), index=0, index_is_slice=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.nn.parameter.Parameter'>
[rank0]:E1110 23:55:36.202000 140682076239680 torch/fx/experimental/recording.py:281] [45/0] failed while running _create_symbolic_sizes_strides_storage_offset(*((10, 120), (120, 1), 0, [False, False], AttrSource(base=GetItemSource(base=AttrSource(base=AttrSource(base=GetItemSource(base=AttrSource(base=AttrSource(base=GetItemSource(base=AttrSource(base=GetItemSource(base=LocalSource(local_name='hooks', cell_or_freevar=False), index=7, index_is_slice=False), member='__closure__'), index=0, index_is_slice=False), member='cell_contents'), member='__closure__'), index=1, index_is_slice=False), member='cell_contents'), member='params_in_ipg_bucket'), index=0, index_is_slice=False), member='grad')), **{'symbolic_context': StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>], constraint_sizes=[None], view_base_context=None, tensor_source=GetItemSource(base=AttrSource(base=AttrSource(base=GetItemSource(base=AttrSource(base=AttrSource(base=GetItemSource(base=AttrSource(base=GetItemSource(base=LocalSource(local_name='hooks', cell_or_freevar=False), index=7, index_is_slice=False), member='__closure__'), index=0, index_is_slice=False), member='cell_contents'), member='__closure__'), index=1, index_is_slice=False), member='cell_contents'), member='params_in_ipg_bucket'), index=0, index_is_slice=False), shape_env_to_source_to_symbol_cache={140681157770208: {"L['hooks'][7].__closure__[0].cell_contents.__closure__[1].cell_contents.params_in_ipg_bucket[0].size()[0]": 0, "L['hooks'][7].__closure__[0].cell_contents.__closure__[1].cell_contents.params_in_ipg_bucket[0].storage_offset()": 0}})})

So torch.compile without compiled autograd eagerly executes the grad-related hooks (for allreduce), while torch.compile with compiled autograd tries to include the grad-related hooks in the backward graph, and then the error occurs, right? So why is DeepSpeed offload not compatible with torch.compile while eager mode is?
Is it because Dynamo has tighter restrictions? Here is my guess: Dynamo installs a guard on p's size. When tracing return sum(p.ds_numel for p in self.params_in_ipg_bucket), Dynamo finds that p's size has changed (it became (0,)) and mismatches the guard, which leads to the error. In eager execution there is no such guard, so it runs successfully.
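
As a side note on that hypothesis, here is a small, self-contained sketch (unrelated to DeepSpeed) of how Dynamo guards on a tensor's static shape; in plain torch.compile a failed shape guard normally just triggers a recompile rather than a hard error, which is part of why eager and compiled behave differently here:

import torch

torch._logging.set_logs(recompiles=True)

def bucket_numel(params):
    # analogous to: sum(p.ds_numel for p in self.params_in_ipg_bucket)
    return sum(p.numel() for p in params)

compiled = torch.compile(bucket_numel, dynamic=False)
compiled([torch.randn(10, 120)])  # traces and installs a guard on the (10, 120) shape
compiled([torch.empty(0)])        # shape is now (0,): the guard fails and Dynamo recompiles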

yitingw1 (Author) commented:

This one is about DeepSpeed offload not being compatible with torch.compile.

Is it related to param.data = replicated_tensor.data, as microsoft/DeepSpeed#6773 suggests?


xmfan commented Nov 25, 2024

If I suppress this exception and fall back to eager by setting torch._dynamo.config.suppress_errors = True, this error disappears for both this reproducer and the real BERT workload.

I'd advise caution when using suppress_errors: it falls back the entire graph you are capturing to eager. Here, I believe only a small portion of that graph needs a graph break.

I think this should work:

@torch._dynamo.disable
def reduce_partition_and_remove_grads(...)
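
For context, here is a self-contained sketch (hypothetical, for illustration only) of how that decorator combines with the AccumulateGrad hook registration pattern quoted from DeepSpeed above; the hook body is a placeholder where DeepSpeed would call self.reduce_ready_partitions_and_remove_grads(param):

import torch

@torch._dynamo.disable
def reduce_partition_and_remove_grads(*notneeded):
    # Runs eagerly; Dynamo / compiled autograd are meant to graph break around it
    # instead of tracing into the partitioning logic.
    print("grad hook ran eagerly")

param = torch.nn.Parameter(torch.randn(4))
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]   # AccumulateGrad node, as in the DeepSpeed wrapper
grad_acc.register_hook(reduce_partition_and_remove_grads)

(param * 2).sum().backward()   # the hook fires when param's gradient is accumulated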

So when tracing return sum(p.ds_numel for p in self.params_in_ipg_bucket), the error occurs

Is there a 2.5.1 repro for this error? I suspect this is issue 1 that I previously mentioned: compiled autograd shouldn't inline into some nodes and not others.

yitingw1 (Author) commented:

Thanks for the reminder about suppress_errors usage!
And I've tried

@torch._dynamo.disable
def reduce_partition_and_remove_grads(...)

It works for the current DeepSpeed code. But after applying a more robust param.register_post_accumulate_grad_hook(..) as in microsoft/DeepSpeed#6773, there are still errors. More details in #141646.

As for

So when tracing return sum(p.ds_numel for p in self.params_in_ipg_bucket), the error occurs

I can't make a repro for PT 2.5.1: it runs successfully in this reproducer but still fails in the real BERT workload. But the error in #141646 seems similar; they are all related to ZeRO-3 grad hooks, and a param size of (0,) causes the error.


yitingw1 commented Nov 27, 2024

Is there any chance for a torch.nn.parameter.Parameter to keep its size while releasing its data storage at the same time?
In DeepSpeed ZeRO-3, the code below is used to release the data storage, but it also causes the size to become (0,).

def free_param(param: Parameter) -> None:
    ...
    param.data = torch.empty(0, dtype=param.dtype, device=param.device)

I've tried param.data = param.data.to("meta"), but it raises an error:

RuntimeError: Attempted to call `variable.set_data(tensor)`, 
but `variable` and `tensor` have incompatible tensor type.
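
Not an answer from the PyTorch side, but for reference one pattern (similar to what FSDP does internally with storage resizing) keeps the Parameter's size/stride metadata and frees only the underlying storage; note zou3519's caution above that torch.compile does not trace set_/resize_ well, and any access to the data while the storage is freed is invalid, so this is only a sketch of the eager-side bookkeeping:

import torch

def free_param_keep_shape(param: torch.nn.Parameter) -> None:
    # Free the backing storage; param.shape and strides stay intact.
    param.data.untyped_storage().resize_(0)

def realloc_param(param: torch.nn.Parameter) -> None:
    # Re-allocate (uninitialized) storage of the right size before the data is
    # filled back in (e.g. by an all-gather / copy_).
    param.data.untyped_storage().resize_(param.numel() * param.element_size())

p = torch.nn.Parameter(torch.randn(120))
free_param_keep_shape(p)
print(p.shape, p.untyped_storage().size())   # torch.Size([120]) 0
realloc_param(p)
print(p.shape, p.untyped_storage().size())   # torch.Size([120]) 480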


xmfan commented Dec 6, 2024

Hi, @xmfan @zou3519, the PR about compiled autograd inlining is #139098, right?

We've been working on a branch over at https://github.com/pytorch/pytorch/tree/fca2. We recently got parity, so we'll clean up and begin landing it.

Is there any chance for a torch.nn.parameter.Parameter to keep its size while releasing its data storage at the same time?

It sounds wrong for size and storage to mismatch each other, at least for the base Tensor class. We haven't thought much about torch.compile + offloading use cases, so I'm not sure whether a framework-level solution will be coming soon.

Fundamentally, we need our tracing mechanism to know about the offloading state of a tensor. But I don't see a consistent implementation of offloading across frameworks, and looking at the DeepSpeed implementation, it seems unlikely that we can get FakeTensors to understand all the state manipulation of Tensor.ds_tensor. The easiest way for the compiler to understand something is to describe the logic in PyTorch operators, and to add operators if none fit the bill.
