Trouble with the backward pass in ZeRO 3 #846

Closed
StellaAthena opened this issue Mar 10, 2021 · 18 comments

@StellaAthena commented Mar 10, 2021

I have a custom Megatron model and a corresponding custom DeepSpeed fork. I believe that I have incorporated your recent update correctly, but when I try to train a ZeRO 3 model I get the error RuntimeError: The size of tensor a (171) must match the size of tensor b (169) at non-singleton dimension 0.

When I turn off CPU Adam, I instead get this error: RuntimeError: start (0) + length (174763) exceeds dimension size (174761)

I notice that in both cases the size of a tensor seems to be off by 2, but I have no idea what's causing this. My code is overall extremely similar to yours, though as I note at deepspeedai/DeepSpeedExamples#92 I cannot get your code to run either (though for different reasons).

@StellaAthena (Author)

Here's the full stacktrace

10.141.246.153: setting training data start iteration to 0
10.141.246.153: setting validation data start iteration to 0
10.141.246.153: done with setups ...
10.141.246.153: time (ms) | model and optimizer: 19888.00 | train/valid/test data iterators: 662.94
10.141.246.153: training ...
10.141.246.153: /home/mchorse/gpt-neox/megatron/mpu/cross_entropy.py:110: UserWarning: Output 0 of _ReduceFromModelParallelRegionBackward is a view and its base or another view of its base has been modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is deprecated and will be forbidden starting version 1.6. You can remove this warning by cloning the output of the custom Function. (Triggered internally at  /pytorch/torch/csrc/autograd/variable.cpp:547.)
10.141.246.153:   return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)
[the same UserWarning is repeated five more times]
10.141.246.153: Traceback (most recent call last):
10.141.246.153:   File "pretrain_gpt2.py", line 191, in <module>
10.141.246.153:     pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
10.141.246.153:   File "/home/mchorse/gpt-neox/megatron/training.py", line 109, in pretrain
10.141.246.153:     iteration = train(forward_step_func,
10.141.246.153:   File "/home/mchorse/gpt-neox/megatron/training.py", line 547, in train
10.141.246.153:     loss_dict, skipped_iter = train_step(forward_step_func,
10.141.246.153:   File "/home/mchorse/gpt-neox/megatron/training.py", line 378, in train_step
10.141.246.153:     backward_step(optimizer, model, loss)
10.141.246.153:   File "/home/mchorse/gpt-neox/megatron/training.py", line 323, in backward_step
10.141.246.153:     model.backward(loss)
10.141.246.153:   File "/home/mchorse/gpt-neox/src/deepspeed/deepspeed/runtime/engine.py", line 988, in backward
10.141.246.153:     self.optimizer.backward(loss)
10.141.246.153:   File "/home/mchorse/gpt-neox/src/deepspeed/deepspeed/runtime/zero/stage3.py", line 2532, in backward
10.141.246.153:     self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
10.141.246.153:   File "/home/mchorse/gpt-neox/src/deepspeed/deepspeed/runtime/fp16/loss_scaler.py", line 53, in backward
10.141.246.153:     scaled_loss.backward(retain_graph=retain_graph)
10.141.246.153:   File "/usr/local/lib/python3.8/dist-packages/torch/tensor.py", line 245, in backward
10.141.246.153:     torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
10.141.246.153:   File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 145, in backward
10.141.246.153:     Variable._execution_engine.run_backward(
10.141.246.153:   File "/usr/local/lib/python3.8/dist-packages/torch/autograd/function.py", line 89, in apply
10.141.246.153:     return self._forward_cls.backward(self, *args)  # type: ignore
10.141.246.153:   File "/home/mchorse/gpt-neox/megatron/mpu/random.py", line 310, in backward
10.141.246.153:     torch.autograd.backward(outputs, args)
10.141.246.153:   File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 145, in backward
10.141.246.153:     Variable._execution_engine.run_backward(
10.141.246.153:   File "/home/mchorse/gpt-neox/src/deepspeed/deepspeed/runtime/zero/stage3.py", line 1398, in reduce_partition_and_remove_grads
10.141.246.153:     self.reduce_ready_partitions_and_remove_grads(param, i)
10.141.246.153:   File "/home/mchorse/gpt-neox/src/deepspeed/deepspeed/runtime/zero/stage3.py", line 1715, in reduce_ready_partitions_and_remove_grads
10.141.246.153:     self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
10.141.246.153:   File "/home/mchorse/gpt-neox/src/deepspeed/deepspeed/runtime/zero/stage3.py", line 1428, in reduce_independent_p_g_buckets_and_remove_grads
10.141.246.153:     self.reduce_ipg_grads()
10.141.246.153:   File "/home/mchorse/gpt-neox/src/deepspeed/deepspeed/runtime/zero/stage3.py", line 1685, in reduce_ipg_grads
10.141.246.153:     self.partition_previous_reduced_grads()
10.141.246.153:   File "/home/mchorse/gpt-neox/src/deepspeed/deepspeed/runtime/zero/stage3.py", line 1665, in partition_previous_reduced_grads
10.141.246.153:     self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(
10.141.246.153:   File "/home/mchorse/gpt-neox/src/deepspeed/deepspeed/runtime/zero/stage3.py", line 1566, in async_inplace_copy_grad_to_fp32_buffer_from_gpu
10.141.246.153:     fp32_grad_tensor.copy_(src_tensor, non_blocking=True)
10.141.246.153: RuntimeError: The size of tensor a (171) must match the size of tensor b (169) at non-singleton dimension 0

cc'ing my colleagues @sdtblck @ShivanshuPurohit

@samyam (Contributor) commented Mar 10, 2021

@StellaAthena, thank you for checking out Z3 Offload and for pointing out this issue. I may have a potential fix and am working on it now. Will keep you posted.

@ShadenSmith ShadenSmith linked a pull request Mar 10, 2021 that will close this issue
@samyam (Contributor) commented Mar 11, 2021

@StellaAthena it will take us a few days to merge the fix in. In the meantime, can you try the branch with the fix: https://github.com/microsoft/DeepSpeed/tree/fix-misaligned-grad

@StellaAthena (Author)

@samyam I pulled that and it runs! Now time to figure out if my model is actually running faster....

@StellaAthena (Author) commented Mar 12, 2021

Actually, it looks like the loss isn't going down, and those FLOPS/s/GPU numbers look extremely suspicious.

@samyam (Contributor) commented Mar 12, 2021

@StellaAthena that's strange. Did you just turn on Stage 3 in the config file, or did you also register the necessary external parameters?

If you have not registered the external parameters yet, you can find instructions on which parameters need to be registered here: https://www.deepspeed.ai/tutorials/zero/#training-trillion-scale-models-with-zero-3-offload
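For reference, registering an external parameter looks roughly like the sketch below (a minimal, hypothetical example in the spirit of the tutorial's tied-embedding case; the Encoder/Decoder classes are illustrative and not code from either repo):

import torch
import deepspeed

class Encoder(torch.nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden_size)

    def forward(self, tokens):
        return self.embed(tokens)

class Decoder(torch.nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        # This module reads encoder.embed.weight in its own forward pass, so
        # under ZeRO-3 that parameter must be registered as external to this
        # module; otherwise it will not be gathered when this forward runs.
        deepspeed.zero.register_external_parameter(self, encoder.embed.weight)

    def forward(self, hidden):
        # Tied output projection using the externally owned embedding weight.
        return torch.nn.functional.linear(hidden, self.encoder.embed.weight)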

If you have already done this, can you please point me to the commit that contains the changes you made for Z3? Then I can take a look and see if anything else is missing.

@StellaAthena (Author) commented Mar 12, 2021

The commit history is a bit of a mess, but we made a consolidated squash commit here that shows the difference between our main branch and the ZeRO-3 integration. The code generally follows your Megatron repo, though one major difference is that we've reformulated and consolidated how we handle arguments. I'm using the config file configs/medium-ZeRO3.yml. I also tried 2-7B-ZeRO3.yml but couldn't shake the OOM errors, even when using 6 A100s (if offload is working correctly that shouldn't happen, right?). The run I shared above was also on 6 A100s.

@salanki helps us with the hardware and network topology side of things.

@samyam (Contributor) commented Mar 12, 2021

@StellaAthena the model size you can run depends on how much CPU memory you have with offload. Generally, a 10B-parameter model will take about 200 GB of CPU memory with offload. If you can give some more details on the system you are running (exact number of GPUs per node, number of nodes, exact amount of CPU memory per node), I can give you an estimate of the max model size you should be able to run with Z3 Offload.
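As a rough back-of-envelope for that figure (an estimate only, not DeepSpeed's exact accounting; the per-parameter byte counts below are assumptions about Adam in mixed precision):

# fp32 master weights + Adam momentum + Adam variance live on the CPU with
# offload (4 bytes each), plus fp16 parameters and gradients if those are
# offloaded as well (2 bytes each).
n_params = 10e9                          # 10B-parameter model
bytes_per_param = 4 + 4 + 4 + 2 + 2      # ~16 bytes of state per parameter
cpu_gib = n_params * bytes_per_param / 1024**3
print(f"~{cpu_gib:.0f} GiB")             # ~149 GiB; buffers and fragmentation
                                         # push this toward the ~200 GB above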

Regarding your port of Z3, I think the issue might be that you are initializing some of the embedding parameters outside the class where the parameters were created. To do it correctly, you need to first gather those parameters before initializing them, as shown here: https://github.com/microsoft/DeepSpeedExamples/blob/20ea07a2a069696abec212e25476a9bf76aced70/Megatron-LM-v1.1.5-ZeRO3/megatron/model/language_model.py#L133.
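A minimal sketch of that gather-then-initialize pattern (hypothetical module and init function, purely illustrative; the actual change belongs at the linked language_model.py lines):

import torch
import deepspeed

def init_method(weight):
    # Stand-in initializer; Megatron uses a scaled normal init.
    torch.nn.init.normal_(weight, mean=0.0, std=0.02)

class Embeddings(torch.nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
        # Under ZeRO-3 the weight may already be partitioned across ranks, so
        # gather it before modifying it in place. With modifier_rank=0, only
        # rank 0's modification is kept and re-partitioned on exit.
        with deepspeed.zero.GatheredParameters(self.word_embeddings.weight,
                                               modifier_rank=0):
            init_method(self.word_embeddings.weight)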

This was the very last step of our tutorial (https://www.deepspeed.ai/tutorials/zero/#training-trillion-scale-models-with-zero-3-offload), so it's easy to miss.

From a cursory look at your code base, the places where you need to make this change are here:
https://github.com/EleutherAI/gpt-neox/blob/630575ff1b84e491921da616ca5e3c34eb02d865/megatron/model/language_model.py#L162
https://github.com/EleutherAI/gpt-neox/blob/630575ff1b84e491921da616ca5e3c34eb02d865/megatron/model/language_model.py#L175

If there are other places where you access a parameter outside of the module where it was created, then you need to do the Gather there as well, except if it's in the forward pass; in that case it will be handled by register_external_parameters.

Please let us know if this fixes your issue.

@ShivanshuPurohit

Just to be clear, it means every time self.init_method() is called, it should be called as:

with deepspeed.zero.GatheredParameters(self.any_embedding.weight, modifier_rank=0):
    self.init_method(self.any_embedding.weight)

Do I understand it correctly? I ask because deepspeed.zero.GatheredParameters() is only called once in the DeepSpeedExamples language_model.py.

@StellaAthena (Author)

@samyam It looks like that made no difference. You can see the log for the new run here and the changes I made here. I made the two changes you suggested, along with the same change at L200.

I set modifier_rank=0 as that's what the docs seem to suggest.

@samyam (Contributor) commented Mar 15, 2021

@ShivanshuPurohit your understanding is correct.

@StellaAthena I cloned your branch from here and was able to repro your results with the small-zero3.yml config. I do see that the loss does not drop, but I also noticed that it does not drop regardless of whether ZeRO-3 is enabled or ZeRO is disabled completely using stage 0. I also tried turning off DeepSpeed entirely, and the loss still doesn't drop. So the issue does not seem to be related to DeepSpeed or ZeRO Stage 3.

I did notice that when you use DeepSpeed with a pipeline-parallel model using config/small.yml, the loss drops, but when I set pipeline_parallelism=0, the loss stays the same regardless of whether DeepSpeed is enabled or disabled.

May I suggest trying out the Megatron example we created to work with ZeRO-3 from here.

Adding @ShadenSmith for visibility and any feedback he may have regarding the Megatron with DeepSpeed pipeline-parallelism example. @StellaAthena, I am assuming here that your current code is based off of this example?

@StellaAthena (Author)

Very interesting.

I ran configs/medium.yml today on the main branch and it worked fine. When you did your successful run with configs/small.yml, were you using the main branch or code that already had ZeRO-3 integrated into it?

I made a code diff between your example code and the code on our main branch. On first pass nothing looks anomalous, but I’ll go over it with a fine-toothed comb whenever I can make time (unfortunately I’ve got to do my day job at some point!).

@samyam (Contributor) commented Mar 15, 2021

No worries. Please take your time. I was using the branch with ZeRO-3 integrated, I believe:
[screenshot]

@StellaAthena (Author)

Okay, so there are two potential failure points:

  1. Going from PipelineParallel to non-pipeline
  2. Integrating ZeRO-3, aka going from main to zero-3-merged

Seems like the obvious next step is to double-check that non-pipeline works on main. I’m pretty sure it does, which should significantly narrow down where the bug might be.

@ShivanshuPurohit can you take a look at this?

@ShivanshuPurohit

Can confirm. For some reason, setting pipeline parallelism to 0 (currently on main) keeps the loss from reducing. Here is the run with "pipe-parallel-stage": 0 on main, while keeping it at 1 stabilizes learning, as seen here. The only difference between these runs is the pipeline parallel stage being 0 and 1, respectively.

@StellaAthena (Author)

Look at that beautiful learning curve! The problem was on our end: we were handling non-pipeline models incorrectly. Once we got that fixed, ZeRO-3 ran straight away. The model is still not as efficient as I had hoped (6.1e12 FLOPS/s/GPU), but this is with extremely unoptimized settings. Time to do benchmarking!

[screenshot of the learning curve]

@samyam (Contributor) commented Mar 16, 2021

@StellaAthena this is great to hear! Please do keep us posted on the benchmarking results. :)

@jeffra (Collaborator) commented Apr 20, 2021

Please reopen if there are further issues
