Are ZeRO CPU offload and gradient accumulation compatible? #671

Closed
jncasey opened this issue Jan 15, 2021 · 5 comments

Comments

jncasey commented Jan 15, 2021

I'm trying out @stas00's HuggingFace DeepSpeed integration and it's super cool!

But I'm running into an error when I try to enable both cpu offload and gradient accumulation at the same time, and I'm not sure if my problem is on the HuggingFace side, or the DeepSpeed side, or (most likely) between my chair and keyboard. Since this post is in the DeepSpeed project, I'll leave out the HuggingFace specifics for now.

My training script will run just fine with either cpu_offload=true or --gradient_accumulation_steps > 1, but if I try using both, it throws the following:

  File "bin/train.py", line 306, in <module>
    main()
  File "bin/train.py", line 265, in main
    train_result = trainer.train()
  File "/opt/miniconda3/envs/hf/lib/python3.8/site-packages/transformers/trainer.py", line 921, in train
    self.optimizer.step()
  File "/opt/miniconda3/envs/hf/lib/python3.8/site-packages/deepspeed/runtime/zero/stage2.py", line 1378, in step
    self.complete_grad_norm_calculation_for_cpu_offload(
  File "/opt/miniconda3/envs/hf/lib/python3.8/site-packages/deepspeed/runtime/zero/stage2.py", line 881, in complete_grad_norm_calculation_for_cpu_offload
    param_norm = self.norm_for_param_grads[param_id]
KeyError: 0
Traceback (most recent call last):
  File "bin/train.py", line 306, in <module>
    main()
  File "bin/train.py", line 265, in main
    train_result = trainer.train()
  File "/opt/miniconda3/envs/hf/lib/python3.8/site-packages/transformers/trainer.py", line 921, in train
    self.optimizer.step()
  File "/opt/miniconda3/envs/hf/lib/python3.8/site-packages/deepspeed/runtime/zero/stage2.py", line 1378, in step
    self.complete_grad_norm_calculation_for_cpu_offload(
  File "/opt/miniconda3/envs/hf/lib/python3.8/site-packages/deepspeed/runtime/zero/stage2.py", line 881, in complete_grad_norm_calculation_for_cpu_offload
    param_norm = self.norm_for_param_grads[param_id]
KeyError: 130

I'm assuming it's because I haven't configured DeepSpeed or my optimizer correctly. But before I dig too much deeper, I wanted to make sure that using both was supported. I haven't seen anything in the documentation that would indicate that it wasn't.

@stas00 have you tried both simultaneously in your HuggingFace integration testing?

This is my DeepSpeed config JSON:

{
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },

  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 1e8,
    "reduce_scatter": true,
    "reduce_bucket_size": 1e8,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "cpu_offload": true
  },

  "optimizer": {
    "type": "Adam",
    "params": {
      "adam_w_mode": true,
      "lr": 3e-5,
      "betas": [ 0.9, 0.999 ],
      "eps": 1e-8,
      "weight_decay": 0.1
    }
  },

  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": 0,
      "warmup_max_lr": 3e-5,
      "warmup_num_steps": 500
    }
  }
}
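For reference when combining the two features: the DeepSpeed configuration documentation describes train_batch_size as train_micro_batch_size_per_gpu × gradient_accumulation_steps × number of GPUs. The config above sets none of these keys (the accumulation steps come from the --gradient_accumulation_steps flag instead), so this is only a minimal sketch of that arithmetic; the helper names and the example values are hypothetical and not taken from my setup.

```python
import json


# Hypothetical helper. Relation from the DeepSpeed config docs:
# train_batch_size = train_micro_batch_size_per_gpu * gradient_accumulation_steps * number of GPUs
def effective_batch_size(micro_batch_per_gpu: int, grad_accum_steps: int, world_size: int) -> int:
    """Samples that contribute to one optimizer update across all GPUs."""
    return micro_batch_per_gpu * grad_accum_steps * world_size


def batch_settings_consistent(cfg: dict, world_size: int) -> bool:
    """Check that the three batch-related keys agree (assumes all three are set)."""
    return cfg["train_batch_size"] == effective_batch_size(
        cfg["train_micro_batch_size_per_gpu"],
        cfg["gradient_accumulation_steps"],
        world_size,
    )


# Illustrative values only; they are NOT from the config posted above.
example = json.loads("""
{
  "train_batch_size": 32,
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 4,
  "zero_optimization": {"stage": 2, "cpu_offload": true}
}
""")

print(effective_batch_size(4, 4, 2))                     # 32
print(batch_settings_consistent(example, world_size=2))  # True
```

With 2 GPUs the three values line up; on a single GPU the same config would be inconsistent, because 4 × 4 × 1 = 16.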
stas00 (Collaborator) commented Jan 15, 2021

> @stas00 have you tried both simultaneously in your HuggingFace integration testing?

I hadn't tried that combination yet, and when I try it now I get the same error as you.

Let me investigate to ensure it's not something missing on our side.

stas00 (Collaborator) commented Jan 15, 2021

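Oh, OK, I didn't follow the DeepSpeed instructions, so the problem is in my code: the Trainer calls self.optimizer.step() directly instead of stepping through the DeepSpeed engine with self.deepspeed.step().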

Try this patch:

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index a58119f8..a6ca42a9 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -931,7 +931,9 @@ class Trainer:
                             )

                     # Optimizer step
-                    if is_torch_tpu_available():
+                    if self.deepspeed:
+                        self.deepspeed.step()
+                    elif is_torch_tpu_available():
                         xm.optimizer_step(self.optimizer)
                     elif self.use_amp:
                         self.scaler.step(self.optimizer)
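The DeepSpeed getting-started examples drive training through the engine returned by deepspeed.initialize(): engine.backward(loss) and engine.step() are called on every micro-batch, and the engine decides internally when a gradient accumulation boundary has been reached. Calling self.optimizer.step() directly, as the traceback shows the Trainer doing, skips that bookkeeping. The toy below is a pure-Python sketch of the pattern under those assumptions, not DeepSpeed's code: ToyAccumulatingEngine is a hypothetical class with a single scalar weight, and its is_gradient_accumulation_boundary() only borrows the name of the real engine method.

```python
# Toy, pure-Python model of gradient accumulation (NOT DeepSpeed's implementation).
# The loop calls backward() and step() on every micro-batch; the engine itself
# decides when an accumulation window is complete and the weight should change.


class ToyAccumulatingEngine:
    """Hypothetical stand-in for a DeepSpeed-style engine with one scalar weight."""

    def __init__(self, weight: float, lr: float, grad_accum_steps: int):
        self.weight = weight
        self.lr = lr
        self.grad_accum_steps = grad_accum_steps
        self.micro_steps = 0   # micro-batches processed so far
        self.updates = 0       # optimizer updates actually applied
        self._grad_sum = 0.0   # averaged gradient accumulated since the last update

    def is_gradient_accumulation_boundary(self) -> bool:
        # True on the last micro-batch of each accumulation window.
        return (self.micro_steps + 1) % self.grad_accum_steps == 0

    def backward(self, grad: float) -> None:
        # Accumulate this micro-batch's gradient, averaged over the window.
        self._grad_sum += grad / self.grad_accum_steps

    def step(self) -> None:
        # Called on EVERY micro-batch; only applies an update at a boundary.
        if self.is_gradient_accumulation_boundary():
            self.weight -= self.lr * self._grad_sum
            self._grad_sum = 0.0
            self.updates += 1
        self.micro_steps += 1


engine = ToyAccumulatingEngine(weight=1.0, lr=0.5, grad_accum_steps=2)
for grad in [1.0, 3.0, 2.0, 2.0]:  # one fake gradient per micro-batch
    engine.backward(grad)
    engine.step()

print(engine.updates)  # 2
print(engine.weight)   # -1.0
```

With grad_accum_steps=2, four micro-batches yield two updates (averaged gradients of 2.0 each, so 1.0 - 0.5*2.0 - 0.5*2.0 = -1.0). Keeping the boundary check inside the engine means the training loop never needs its own modulo test to decide when to update.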

jncasey (Author) commented Jan 15, 2021

Thanks! I'll test this later this afternoon.

stas00 (Collaborator) commented Jan 15, 2021

huggingface/transformers#9622 should fix it, and it adds a test.

jncasey (Author) commented Jan 15, 2021

Awesome!
