[Compiled_autograd] running nn.LayerNorm failed for torch.compile with compiled_autograd when using DeepSpeed Zero3 #140091
Comments
Thanks for the detailed write-up and small repro. Deepspeed seems to be doing parameter offloading somewhere (I suspect it is implemented in some backward hook) and is not compatible with our current implementation. I don't have a timeline for the fix; it requires a design for how we handle offloading with torch.compile. So the issue is that compiled autograd is inlining everything but hooks. My theory of what's happening is:
Splitting this into 2 parts:
With this change, I expect it to no longer error, but to have a graph break on each hook that tries to load the parameter into memory.
@zou3519 wdyt about using custom ops to support traceable offloading?
Thanks @xmfan. That's really a reasonable explanation. Yes, Deepspeed Zero3 does parameter-related processing in some backward hooks. I'll focus on investigating this part.
It sounds feasible; the thing to be careful about here is that we (torch.compile) aren't good at tracing through
And one more question, regarding my observation about bypassing the LayerNorm weight when running deepspeed zero3:
That's the right conclusion when using PT2.5.0/PT2.5.1 for this small reproducer. But for PT2.4.1, another error happens after the compiled autograd graph is generated, while that compiled autograd graph is traced through dynamo. When running BERT with deepspeed zero3 instead of this reproducer, the error still occurs (PT2.5.1/PT2.5.0/PT2.4.1). For PT2.4.1, the error happens while dynamo traces the second grad hook (call_hook_1). When tracing the second grad hook, it will check
Do you have any idea about this? Is this the same error mechanism? Thanks!
This is the same issue, but a different part of the stack. The previous issue was about deepspeed offload not being compatible with compiled autograd. This one is about deepspeed offload not being compatible with torch.compile. Are you saying that this is already fixed in 2.5? The safest workaround is to graph break on the code that's loading the params into memory again. By chance, do you know where those hooks are implemented in deepspeed?
I don't think so. The error disappears when using PT2.5.0/PT2.5.1 for this small reproducer, but for the real workload, the error still occurs when running BERT with deepspeed zero3 on PT2.5.0/PT2.5.1.
If we suppress this exception and fall back to eager by setting
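(The specific setting appears to have been cut off here; presumably it is dynamo's error-suppression flag, sketched below for reference.)

```python
import torch._dynamo

# When compiling a frame raises an exception, fall back to running that frame
# eagerly instead of propagating the error.
torch._dynamo.config.suppress_errors = True
```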
The grad-related hooks (for allreduce) are in deepspeed/runtime/zero/stage3.py:

```python
class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
    ...

    def create_reduce_and_remove_grad_hooks(self):
        ...

        def wrapper(param):
            param_tmp = param.expand_as(param)
            grad_acc = param_tmp.grad_fn.next_functions[0][0]

            @instrument_w_nvtx
            def reduce_partition_and_remove_grads(*notneeded):
                self.reduce_ready_partitions_and_remove_grads(param)

            self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads))
            self.grad_accs.append(grad_acc)

    def reduce_ready_partitions_and_remove_grads(self, param):
        self.reduce_independent_p_g_buckets_and_remove_grads(param)

    ############### Independent Partition Gradient ########################
    def reduce_independent_p_g_buckets_and_remove_grads(self, param):
        if self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size and self.elements_in_ipg_bucket > 0:
            self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.ds_numel)
            self.__reduce_and_partition_ipg_grads()

        self.__add_grad_to_ipg_bucket(param)

    @property
    def elements_in_ipg_bucket(self):
        return sum(p.ds_numel for p in self.params_in_ipg_bucket)
```

The error occurs when dynamo is tracing the generated compiled autograd graph, specifically when executing
So torch.compile without compiled autograd will eagerly execute the grad-related hooks (for allreduce), while torch.compile with compiled autograd tries to include the grad-related hooks in the backward graph, and then the error occurs, right? So why is deepspeed offload not compatible with torch.compile while eager is compatible?
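For reference, a minimal sketch of how compiled autograd is typically switched on (using the `torch._dynamo.compiled_autograd.enable` context manager; the reproducer may instead use the `torch._dynamo.config.compiled_autograd` flag). With it enabled, backward hooks are captured into the compiled autograd graph rather than executed eagerly:

```python
import torch
from torch._dynamo import compiled_autograd

model = torch.compile(torch.nn.Linear(4, 4))
x = torch.randn(2, 4)

# Under compiled autograd, the whole backward (including registered hooks) is
# captured into a graph and compiled, instead of each hook running eagerly.
with compiled_autograd.enable(torch.compile):
    loss = model(x).sum()
    loss.backward()
```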
Is it related to
I'd advise caution when using suppress_errors: it will fall back the entire graph you are capturing to eager. Here, I believe only a small portion of that graph needs to be graph broken on. I think this should work:
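The concrete snippet appears to have been lost from this comment, but judging from the follow-up reply it amounted to wrapping the offending DeepSpeed hook in `torch._dynamo.disable`, so dynamo graph-breaks around it instead of tracing its body. A standalone sketch (the hook body here is a stand-in, not DeepSpeed's real logic):

```python
import torch

def offload_style_hook(grad):
    # Stand-in for reduce_partition_and_remove_grads: whatever happens in here
    # now runs eagerly, outside the captured graph.
    return grad

# Used as a decorator or a wrapper, torch._dynamo.disable forces a graph break
# at the call site instead of inlining the hook into the compiled graph.
hook = torch._dynamo.disable(offload_style_hook)

p = torch.randn(4, requires_grad=True)
p.register_hook(hook)
(p * 2).sum().backward()
```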
Is there a 2.5.1 repro for this error? I suspect this to be issue 1 that I previously mentioned:
Thanks for the reminder about suppress_errors usage!

```python
@torch._dynamo.disable
def reduce_partition_and_remove_grads(...)
```

It works for the current deepspeed code. But after applying the more robust param.register_post_accumulate_grad_hook(..) as in microsoft/DeepSpeed#6773, there are still errors. More details in #141646. As for
I can't make a repro for PT2.5.1. It runs successfully with this reproducer but still fails in the real BERT workload. But the error in #141646 seems similar: they are all related to zero3 grad hooks, and the param size (0,) causes the error.
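For context, `register_post_accumulate_grad_hook` is the newer parameter-level hook that microsoft/DeepSpeed#6773 moves the ZeRO-3 logic onto; a minimal usage sketch (the hook body is illustrative only):

```python
import torch

p = torch.nn.Parameter(torch.randn(4))

def on_grad_accumulated(param):
    # Fires after the gradient has been fully accumulated into param.grad;
    # this is the point where ZeRO-3 would reduce/partition and free the grad.
    print(param.grad.shape)

p.register_post_accumulate_grad_hook(on_grad_accumulated)
(p * 2.0).sum().backward()
```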
Is there any chance for

```python
def free_param(param: Parameter) -> None:
    ...
    param.data = torch.empty(0, dtype=param.dtype, device=param.device)
```

I've tried
We've been working on a branch over at https://github.com/pytorch/pytorch/tree/fca2. We recently got parity, so we'll clean up and begin landing.
It sounds wrong for size and storage to mismatch each other, at least for the base Tensor class. We haven't thought too much about torch.compile + offloading use cases, so I'm not sure a framework solution will be coming soon. Fundamentally, we need our tracing mechanism to know about the offloading state of a tensor. But I don't see a consistent implementation of offloading across frameworks, and looking at the deepspeed implementation, it seems unlikely that we can get FakeTensors to understand all the state manipulation of Tensor.ds_tensor. The easiest way for the compiler to understand something is to describe the logic in PyTorch operators, and to add operators if none fit the bill.
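As a rough illustration of the "describe the logic in PyTorch operators" direction (and of the custom-op idea raised earlier in the thread), the gather step could be wrapped in a custom op whose fake implementation reports the full, un-offloaded shape, so FakeTensor tracing never sees the 0-element storage. The op name and signature below are made up for the sketch; this is not an existing DeepSpeed or PyTorch API:

```python
import torch
from typing import Sequence

# Hypothetical op: a real implementation would all-gather the partitioned
# parameter shards here instead of returning zeros.
@torch.library.custom_op("zero3_compat::gather_param", mutates_args=())
def gather_param(ds_tensor: torch.Tensor, full_shape: Sequence[int]) -> torch.Tensor:
    return ds_tensor.new_zeros(list(full_shape))

# The fake (meta) implementation is all the compiler traces through: it only
# describes the output metadata, i.e. the full parameter shape.
@gather_param.register_fake
def _(ds_tensor, full_shape):
    return ds_tensor.new_empty(list(full_shape))

freed = torch.empty(0)                 # parameter storage after ZeRO-3 frees it
w = gather_param(freed, [100, 120])    # the compiler sees a [100, 120] tensor
print(w.shape)
```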
🐛 Describe the bug
When running a simple model that includes torch.nn.LayerNorm under deepspeed zero3 with torch.compile and compiled_autograd, an error occurs:
We first found this error in a BERT model running deepspeed Zero3 with torch.compile and compiled_autograd.
Investigation
The error: "RuntimeError: Attempting to broadcast a dimension of length 0 at -1! Mismatching argument at index 1 had torch.Size([0]); but expected shape should be broadcastable to [100, 120]"
It occurs when compiled autograd tries to trace the backward graph.
It appears in the LayerNorm backward decomposition, which tries to broadcast weight_cast (torch.Size([0])) to grad_out_cast's shape ([100, 120]) and fails.
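The failing broadcast can be reproduced in isolation: after ZeRO-3 frees the parameter, the weight the decomposition sees is a 0-element tensor. A minimal illustration with shapes taken from the error message (this is not the decomposition code itself):

```python
import torch

weight = torch.nn.Parameter(torch.ones(120))
weight.data = torch.empty(0)   # what the LayerNorm weight looks like once ZeRO-3 has freed it

grad_out = torch.randn(100, 120)
grad_out * weight              # raises: torch.Size([0]) is not broadcastable to [100, 120]
```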
If we bypass the LayerNorm weight by setting `nn.LayerNorm(120, eps=1e-12, elementwise_affine=False)` instead of `elementwise_affine=True` in deepspeed_reproducer_cpu.py, the run succeeds.

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @chauhang @penguinwu @xmfan @yf225 @ezyang @albanD @gqchen @pearu @nikitaved @soulitzer @Varal7
Error logs
Minified repro
Running script:
TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 TORCH_LOGS="+dynamo,graph,graph_code,graph_breaks,recompiles,aot_graphs,aot_joint_graph,compiled_autograd_verbose" deepspeed --num_nodes 1 --num_gpus 1 deepspeed_reproducer_cpu.py
Below is deepspeed_reproducer_cpu.py
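The original deepspeed_reproducer_cpu.py is not included in this capture. The sketch below is a hypothetical reconstruction from the details given in the report (a LayerNorm(120, eps=1e-12, elementwise_affine=True) model, a [100, 120] input, ZeRO stage 3, torch.compile plus compiled_autograd on CPU); the surrounding model, the training loop, and the way compiled autograd is enabled are all assumptions:

```python
import torch
import torch._dynamo
import torch.nn as nn
import deepspeed


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(120, 120)
        self.norm = nn.LayerNorm(120, eps=1e-12, elementwise_affine=True)

    def forward(self, x):
        return self.norm(self.linear(x))


def main():
    torch._dynamo.config.compiled_autograd = True  # assumption: flag-based enabling
    model = torch.compile(SimpleModel())
    engine, _, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config="deepspeed_config.json",  # assumed to set {"zero_optimization": {"stage": 3}, ...}
    )

    x = torch.randn(100, 120)
    loss = engine(x).sum()
    engine.backward(loss)
    engine.step()


if __name__ == "__main__":
    main()
```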
Below is deepspeed_config.json
Versions
Collecting environment information...
PyTorch version: 2.5.1+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.27.4
Libc version: glibc-2.35
Python version: 3.10.0 | packaged by conda-forge | (default, Nov 20 2021, 02:24:10) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-102-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] torch==2.5.1+cpu
[pip3] torchaudio==2.5.1+cpu
[pip3] torchvision==0.20.1+cpu
[pip3] deepspeed==0.15.3