
zero3 hangs in inference #860

Closed
stas00 opened this issue Mar 14, 2021 · 15 comments

@stas00
Collaborator

stas00 commented Mar 14, 2021

So training works with zero3, and then I do inference calling deepspeed.forward(). While it works on a very small sample, with a slightly bigger sample it hangs at 100% gpu utilization:

Thread 0x00007f57caf71740 (most recent call first):
  File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/cuda/streams.py", line 95 in synchronize
  File "/mnt/nvme1/code/github/00optimize/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 490 in _synchronize_communication
  File "/mnt/nvme1/code/github/00optimize/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 406 in fetch_sub_module
  File "/mnt/nvme1/code/github/00optimize/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 1139 in pre_sub_module_forward_function
  File "/mnt/nvme1/code/github/00optimize/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 1071 in _pre_forward_module_hook
  File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 881 in _call_impl
  File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/models/t5/modeling_t5.py", line 451 in project
  File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/models/t5/modeling_t5.py", line 474 in forward
  File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 892 in _call_impl
  File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/models/t5/modeling_t5.py", line 540 in forward
  File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 892 in _call_impl
  File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/models/t5/modeling_t5.py", line 633 in forward
  File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 892 in _call_impl
  File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/models/t5/modeling_t5.py", line 954 in forward
  File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 892 in _call_impl
  File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/models/t5/modeling_t5.py", line 1505 in forward
  File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 892 in _call_impl
  File "/mnt/nvme1/code/github/00optimize/DeepSpeed/deepspeed/runtime/engine.py", line 893 in forward
  File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 872 in _call_impl
  File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/trainer_seq2seq.py", line 185 in prediction_step
  File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/trainer.py", line 1800 in prediction_loop
  File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/trainer.py", line 1647 in evaluate
  File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/trainer_seq2seq.py", line 74 in evaluate
  File "examples/seq2seq/run_seq2seq.py", line 607 in main
  File "examples/seq2seq/run_seq2seq.py", line 655 in <module>

The trace is from faulthandler, so please read it in reverse.

I'm not sure if you have inference tests - maybe this can be reproduced with just model.eval()?
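If it helps, this is roughly the kind of eval-only repro I have in mind (an untested sketch - model, ds_config and eval_dataloader are placeholders for whatever the test setup provides):

import torch
import deepspeed

# ds_config: a dict holding the zero stage 3 config shown below
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config_params=ds_config)
engine.eval()
with torch.no_grad():
    for batch in eval_dataloader:
        out = engine(**batch)  # presumably this is where it would hang on >1 gpu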

Config:

{
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "zero_optimization": {
        "stage": 3,
        "cpu_offload": true,
        "cpu_offload_params": true,
        "cpu_offload_use_pin_memory" : true,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e8,
        "stage3_prefetch_bucket_size": 2e5,
        "stage3_param_persitance_threshold": 1e5,
        "reduce_bucket_size": 3e6,
        "prefetch_bucket_size": 3e6,
        "sub_group_size": 1e6
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 3e-5,
            "betas": [0.8, 0.999],
            "eps": 1e-8,
            "weight_decay": 3e-7
        }
    },

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 3e-5,
            "warmup_num_steps": 500
        }
    },

    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}

Thanks.

@samyam
Contributor

samyam commented Mar 15, 2021

@stas00 how many gpus are you running this on? If it's more than one, can you please try running it on one GPU and see if it works? If it does, I am suspicious of two issues here:

  1. Different GPUs are running model.forward() for a different number of iterations. This could happen if different GPUs are processing a different number of samples, or generating different sequence lengths. Could you test for this by simply printing out the number of times model.forward() is called for each rank? (see the rough sketch after this list)

  2. The issue could also be that different ranks are taking different code paths. I remember talking to you a few weeks back about the different code branches in T5. Could something like this be happening here?
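For item 1, something as simple as this should do (a rough sketch only - wrap the top-level forward with a per-rank counter wherever is convenient in your code):

import torch.distributed as dist

n_forward = 0
orig_forward = model.forward

def counted_forward(*args, **kwargs):
    # count and log every forward call made on this rank
    global n_forward
    n_forward += 1
    print(f"rank {dist.get_rank()}: forward call #{n_forward}", flush=True)
    return orig_forward(*args, **kwargs)

model.forward = counted_forward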

@stas00
Collaborator Author

stas00 commented Mar 15, 2021

@stas00 how many gpus are you running this on? If it's more than one, can you please try running it on one GPU and see if it works?

2 gpus, and it works fine with a single gpu.

I'll report back on your suggestions shortly.

Thank you.

@stas00
Collaborator Author

stas00 commented Mar 15, 2021

The issue could also be that different ranks are taking different code paths. I remember talking to you a few weeks back about the different code branches in T5. Could something like this be happening here?

Very possible. I tried mbart and it didn't have this problem, albeit once it failed in inference with this:

Traceback (most recent call last):
 File "examples/seq2seq/run_seq2seq.py", line 655, in <module>
 main()
 File "examples/seq2seq/run_seq2seq.py", line 607, in main
 metrics = trainer.evaluate(
 File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/trainer_seq2seq.py", line 74, in evaluate
 return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
 File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/trainer.py", line 1647, in evaluate
 output = self.prediction_loop(
 File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/trainer.py", line 1800, in prediction_loop
 loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
 File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/trainer_seq2seq.py", line 167, in prediction_step
 generated_tokens = self.model.generate(
 File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
 return func(*args, **kwargs)
 File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/generation_utils.py", line 1012, in generate
 return self.beam_search(
 File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/generation_utils.py", line 1669, in beam_search
 outputs = self(
 File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 892, in _call_impl
 result = forward_call(*input, **kwargs)
 File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/models/mbart/modeling_mbart.py", line 1289, in forward
 outputs = self.model(
 File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 892, in _call_impl
 result = forward_call(*input, **kwargs)
 File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/models/mbart/modeling_mbart.py", line 1177, in forward
 decoder_outputs = self.decoder(
 File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 892, in _call_impl
 result = forward_call(*input, **kwargs)
 File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/models/mbart/modeling_mbart.py", line 1048, in forward
 layer_outputs = decoder_layer(
 File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 892, in _call_impl
 result = forward_call(*input, **kwargs)
 File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/models/mbart/modeling_mbart.py", line 426, in forward
 hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
 File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 892, in _call_impl
Timeout (0:00:20)!
Thread 0x00007fc210c9b700 (most recent call first):
<no Python frame>

Thread 0x00007fc3a796c740 (most recent call first):
<no Python frame>
 result = forward_call(*input, **kwargs)
 File "/mnt/nvme1/code/huggingface/transformers-ds-zero-3/src/transformers/models/mbart/modeling_mbart.py", line 271, in forward
 attn_output = self.out_proj(attn_output)
 File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 881, in _call_impl
 result = hook(self, input)
 File "/mnt/nvme1/code/github/00optimize/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 1064, in _pre_forward_module_hook
 self.pre_sub_module_forward_function(module)
 File "/mnt/nvme1/code/github/00optimize/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 1137, in pre_sub_module_forward_function
 self.param_coordinator.prefetch_next_sub_modules(sub_module,
 File "/mnt/nvme1/code/github/00optimize/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 320, in prefetch_next_sub_modules
 params_to_prefetch = self.prefetch_coordinator.get_params_to_prefetch(
 File "/mnt/nvme1/code/github/00optimize/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 198, in get_params_to_prefetch
 if sub_module.id != self.sub_module_trace[self.step_id]:
IndexError: list index out of range

The failure seems to be intermittent, so it's possible there is some race condition.

@stas00
Collaborator Author

stas00 commented Mar 15, 2021

Different GPUs are running model.forward() for a different number of iterations. This could happen if different GPUs are processing a different number of samples, or generating different sequence lengths. Could you test for this by simply printing out the number of times model.forward() is called for each rank?

You're spot on with this suggestion: one gpu seems to want to run an additional forward for a specific number of samples, but not for others. I will investigate and report back.

@stas00
Collaborator Author

stas00 commented Mar 16, 2021

I understood the problem.

During inference we have an option to generate predictions which can then be scored against BLEU, etc. Different sequences may take a different number of forward passes to complete this task.

So when one gpu finishes generating its predictions sooner than the others - say its stopping criterion decides it's done at a length of 10 tokens while the others aren't done yet and max_length is 15 - the others are now stuck waiting for the first gpu to keep running forward, which it will not do.

To confirm that this is so, I hacked the code so that the while loop runs until it hits max_length on all gpus, repeating the last forward call, and the problem went away.
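The hack is essentially this (a paraphrased sketch, not the actual code - the names stand in for the real generate() loop variables):

while cur_len < max_length:
    if this_rank_is_done:
        # keep this gpu in lock-step with the others: re-run the last real
        # inputs and throw the result away
        _ = self(**last_model_inputs)
    else:
        outputs = self(**model_inputs)
        # ... the normal next-token selection and stopping-criteria checks ...
    cur_len += 1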

I am not at all sure this hack will be acceptable as:

  1. The code where we run the generate loop is quite a few call frames away from the trainer and, as mentioned earlier, doesn't know anything about such special circumstances or that it's running under deepspeed (or fairscale), since it was designed to work with any model.
  2. It wastes resources running dummy forward calls and throwing the results away.

But I will totally understand if you don't have any brilliant ideas for how to overcome this hurdle, in which case we will find some way around it.

@samyam
Contributor

samyam commented Mar 16, 2021

Thank you @stas00 for digging into this. I am glad you were able to get to the core of the problem.

I understood the problem.

During inference we have an option to generate predictions which can then be scored against BLEU, etc. Different sequences may take a different number of forward passes to complete this task.

So when one gpu finishes generating its predictions sooner than the others - say its stopping criterion decides it's done at a length of 10 tokens while the others aren't done yet and max_length is 15 - the others are now stuck waiting for the first gpu to keep running forward, which it will not do.

This makes sense. This is pretty much what I was expecting as well. Since ZeRO-3 is a single program multiple data (SPMD) approach to parallelism with coordinated data movement, all processes must be running the same program - in this case the forward on the model on each process - for it to work correctly.

To confirm that this is so, I hacked the code so that the while loop runs until it hits max_length on all gpus, repeating the last forward call, and the problem went away.

I am not at all sure this hack will be acceptable as:

  1. The code where we run the generate loop is quite a few call frames away from the trainer and, as mentioned earlier, doesn't know anything about such special circumstances or that it's running under deepspeed (or fairscale), since it was designed to work with any model.

I agree that the hack is limiting, but I have a slightly different view on the "designed to work with any model" part. It seems that the code is actually designed to work only with single-GPU models, and is limited in that sense. As long as the model runs on a single GPU, it will work, but it will not work with any multi-GPU model, regardless of whether it is ZeRO-3 or model parallel (tensor slicing) or pipeline parallel, since each of them requires some form of special treatment that is inherent in the parallelism itself. For example, model parallelism would require the data loader to give the same sample to all GPUs, and pipeline parallelism would require the data loader to give samples only to the first-stage GPU.

Could a potential solution here be to extend the code to support multi-GPU inference, by allowing for adaptable variations based on the type of parallelism being used?

  2. It wastes resources running dummy forward calls and throwing the results away.

This, I think, can be mitigated to the point that the wasted resources are minimal. Two potential solutions:

  • If the generate code can support a batch size > 1, then run with a large batch size, all running for max_len. During inference a larger batch will in general give better throughput, and with a large batch size the probability of generating at least one long sequence increases, so the expected waste of resources goes down. A large batch size will also significantly reduce the communication overhead of ZeRO-3.
  • If batch size > 1 is not supported, run all the generation for all the samples one after another until you are done with all the samples, before doing anything else (see the rough sketch after this list). As you noticed, as long as each process is running a forward on something, it will run fine. There will still be some wasted resources at the very end due to the difference in the total number of generated tokens across the queries, but this will be much less than running fake forwards for each query.
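For the second option the structure could be as simple as this (a rough sketch only - eval_dataloader, compute_metrics etc. are placeholders for whatever your pipeline uses, and the tail-end imbalance still needs to be padded with fake forwards as noted above):

all_generated = []
for batch in eval_dataloader:
    # every rank keeps issuing forwards back to back, with nothing else in between
    all_generated.append(model.generate(**batch, max_length=max_length))

# score/decode only after every rank has finished all of its forward work
metrics = compute_metrics(all_generated)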

But I will totally understand if you don't have any brilliant ideas for how to overcome this hurdle, in which case we will find some way around it.

@stas00
Collaborator Author

stas00 commented Mar 16, 2021

I totally hear you, @samyam, that we need to adapt this code to support parallelization. I obviously wanted to hear first if we can avoid that ;)

Thank you for your detailed reply! I totally agree!

Wrt item 2, I think the main concern is for situations where max_len happens to be much bigger than needed for the whole batch; then it's no longer about doing a few extra forward passes on a few gpus, but about running extra forward passes on all gpus.

To try to explain better, let's say our max_len is 100 and all gpus satisfied their stopping criteria for a predicted output in 50 tokens; now we are going to run unnecessarily for 50 more tokens on each gpu! Is there a way to signal that all gpus have reached their criteria and synchronize so as not to continue running forward?

That is, there is a need for a new condition that, if it returns true, lets the loop be exited on all gpus in a synchronous way.

Implementation-wise I'm thinking that will take 2 APIs:

  1. for each gpu to signal to the others or to gpu0 when it is done with the real work and could exit this loop
  2. for each gpu to query a global state or gpu0 to know whether it should keep running the loop even when it doesn't need to

With such an API in place, if all gpus finished their work in < max_len they could synchronously exit the loop early.
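Very roughly, I'm imagining something like this (a sketch only - the helper names are invented, not an existing API):

import torch
import torch.distributed as dist

def signal_done(done, device):
    # API 1: every gpu publishes whether it has finished its real work
    mine = torch.tensor([0.0 if done else 1.0], device=device)
    flags = [torch.zeros(1, device=device) for _ in range(dist.get_world_size())]
    dist.all_gather(flags, mine)
    return flags

def keep_looping(flags):
    # API 2: keep running the generate loop while any gpu is still working
    return any(f.item() > 0.0 for f in flags)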

(And yes generate works with batches)

@samyam
Contributor

samyam commented Mar 16, 2021

I totally hear you, @samyam, that we need to adapt this code to support parallelization. I obviously wanted to hear first if we can avoid that ;)

Of course :) I thought a bit about whether this would be possible, but could not figure out a solution without a code change on the client side, and eventually realized that the core of the issue is support for different types of parallelism.

Thank you for your detailed reply! I totally agree!

Wrt item 2, I think the main concern is for situations where max_len happens to be much bigger than needed for the whole batch; then it's no longer about doing a few extra forward passes on a few gpus, but about running extra forward passes on all gpus.

To try to explain better, let's say our max_len is 100 and all gpus satisfied their stopping criteria for a predicted output in 50 tokens; now we are going to run unnecessarily for 50 more tokens on each gpu! Is there a way to signal that all gpus have reached their criteria and synchronize so as not to continue running forward?

This makes sense. I was just pointing out that the likelihood of at least one of the samples generating all 100 tokens increases with batch size, and ZeRO-3 should allow for much larger batch sizes. But since it's ok to make changes to the generation code pipeline, I completely agree with doing something smarter like the solution you have below.

That is, there is a need for a new condition that, if it returns true, lets the loop be exited on all gpus in a synchronous way.

Implementation-wise I'm thinking that will take 2 APIs:

  1. for each gpu to signal to the others or to gpu0 when it is done with the real work and could exit this loop
  2. for each gpu to query a global state or gpu0 to know whether it should keep running the loop even when it doesn't need to

With such an API in place, if all gpus finished their work in < max_len they could synchronously exit the loop early.

This makes sense. We could maybe simplify by doing a single all_reduce, where gpus that are done contribute a tensor with 0.0 and those that are not done contribute 1.0. If the result of the all_reduce is 0.0 then everyone can stop; otherwise the gpus that are done will do a fake forward.

# 1.0 means "this rank is still generating", 0.0 means it is done
sync = torch.tensor(1.0, device=device)
while sync.item() > 0.0:
    p = model.forward(fake_input if am_i_done() else real_input)
    sync = torch.tensor(0.0 if am_i_done() else 1.0, device=device)
    # sum across ranks: the result is 0.0 only when every rank is done
    torch.distributed.all_reduce(sync)

I think you can extend this even further by checking for the early termination condition not at the end of a batch but at the end of the entire epoch, if you can store the results of the predictions for all the samples and process them at the very end. This would result in even fewer wasted forwards. I might not have understood your pipeline fully though, so take this suggestion with a grain of salt.

(And yes generate works with batches)


@stas00
Collaborator Author

stas00 commented Mar 16, 2021

Thank you for the recipe, @samyam!

I think it's safest to use the last valid input, to ensure that all forward passes of the sub-layers get to run in case there is some condition on the data.

I don't think in our particular setup we could do the synchronization at the epoch-level loop.

I will save it for later, as for now we really want to make training fast and efficient, and inference possible. We want inference so that we can quickly evaluate the outcome of training.

@stas00
Collaborator Author

stas00 commented Mar 18, 2021

Closing this for now, as I have a working solution. Thanks again, @samyam

@stas00 stas00 closed this as completed Mar 18, 2021
@wenlai-lavine

@samyam @stas00 I also got the same error when using zero3 to finetune 'm2m100'; it is fine with other models (mbart/t5) on the same data. Can anyone help me solve this?

@stas00
Collaborator Author

stas00 commented Feb 9, 2022

This is a really old ticket and a ton of things have changed since then in both Deepspeed and HF Transformers, so I'd recommend opening a new ticket explaining the problem, including the exact versions of the packages you're using and how the issue can be reproduced.

@evros-chris

Hi @lavine-lmu, I am facing the same issue when trying to use zero3 to fine-tune m2m100.

It is fine when I use zero2 to fine-tune m2m100. It is also fine when I use zero3 to fine-tune t5 with the same data.

Did you manage to solve this error?

I tried running the following command, as you also mentioned in: huggingface/transformers#15570

deepspeed examples/pytorch/translation/run_translation.py \
--deepspeed tests/deepspeed/ds_config_zero2.json \
--model_name_or_path facebook/m2m100_418M \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--output_dir output_dir --overwrite_output_dir \
--fp16 \
--do_train --do_eval --do_predict \
--max_train_samples 500 --max_eval_samples 50 --max_predict_samples 50 \
--num_train_epochs 3 \
--dataset_name wmt16 --dataset_config "ro-en" \
--source_lang en --target_lang ro \
--predict_with_generate --forced_bos_token ro

It works with ds_config_zero2.json but it does not work when I use ds_config_zero3.json. Can you help please? :)

@stas00
Collaborator Author

stas00 commented Apr 10, 2022

@evros-chris, as I suggested above, this is a really old Issue and any further comments are very likely irrelevant since the code base has changed since then.

Please open a new issue with the full traceback of the error you experience, and your environment, and any other details that can be used to reproduce the problem you are experiencing.

Please avoid using "it doesn't work" alone when reporting bugs, since we have no idea what that means.

You may tag me on the issue and I'd be happy to take a look.

Thank you.

@evros-chris

Thank you @stas00! I have opened a new issue and tagged you here:
huggingface/transformers#16688
