
Training with DeepSpeed takes more GPU memory than without DeepSpeed #10929

Closed
oriyor opened this issue Mar 27, 2021 · 4 comments

@oriyor

oriyor commented Mar 27, 2021

Environment info

  • transformers version: 4.5.0.dev0
  • deepspeed version: 0.3.13
  • Platform: Linux-4.15.0-66-generic-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.6.8
  • PyTorch version (GPU?): 1.8.0 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help

@stas00

Information

I'm interested in training the large T5 models with deepspeed and huggingface. More specifically, I'm interested in fine-tuning a T5-11B model on one RTX-8000 48 GB GPU (similarly to https://huggingface.co/blog/zero-deepspeed-fairscale, #9996).
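For a sense of scale, here is a rough back-of-the-envelope sketch of how much memory the training state of an 11B-parameter model would need. The byte counts are assumptions, not figures from this issue: mixed-precision Adam with fp16 weights and gradients plus fp32 master weights, momentum, and variance. Activations and temporary buffers are ignored, and the parameter count is approximate.

```python
# Rough per-parameter memory model for mixed-precision Adam training.
# All byte counts below are assumptions for illustration only.
BYTES_FP16_WEIGHTS = 2      # fp16 copy of the weights
BYTES_FP16_GRADS = 2        # fp16 gradients
BYTES_FP32_OPTIMIZER = 12   # fp32 master weights + Adam momentum + variance


def training_state_gib(num_params, optimizer_on_gpu=True):
    """Estimate GPU memory (GiB) for weights, gradients, and optimizer state."""
    per_param = BYTES_FP16_WEIGHTS + BYTES_FP16_GRADS
    if optimizer_on_gpu:
        per_param += BYTES_FP32_OPTIMIZER
    return num_params * per_param / 2**30


T5_11B_PARAMS = 11_000_000_000  # approximate parameter count

everything_on_gpu = training_state_gib(T5_11B_PARAMS)
optimizer_offloaded = training_state_gib(T5_11B_PARAMS, optimizer_on_gpu=False)
print(f"all state on GPU:       {everything_on_gpu:.1f} GiB")
print(f"optimizer state on CPU: {optimizer_offloaded:.1f} GiB")
```

Under these assumptions the full training state is far larger than 48 GB, which is why keeping the optimizer state off the GPU matters for this use case.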

However, when I try to use DeepSpeed, GPU memory usage goes up instead of down. For example, running the example seq2seq/run_summarization.py script with T5-Small takes ~6GB without DeepSpeed and ~8GB with DeepSpeed.

Model I am using: T5

The problem arises when using: The official examples/seq2seq/run_summarization.py script.

Without deepspeed:
python examples/seq2seq/run_summarization.py --model_name_or_path t5-small --do_train --do_eval --dataset_name cnn_dailymail --dataset_config "3.0.0" --source_prefix "summarize: " --output_dir /tmp/tst-summarization --per_device_train_batch_size=4 --per_device_eval_batch_size=4 --overwrite_output_dir --predict_with_genera

With deepspeed:
deepspeed examples/seq2seq/run_summarization.py --model_name_or_path t5-small --do_train --do_eval --dataset_name cnn_dailymail --dataset_config "3.0.0" --source_prefix "summarize: " --output_dir /tmp/tst-summarization --per_device_train_batch_size=4 --per_device_eval_batch_size=4 --overwrite_output_dir --predict_with_generate --deepspeed examples/tests/deepspeed/ds_config.json

The task I am working on is:
Sequence to sequence generation.

To reproduce

Steps to reproduce the behavior:

  1. Clone transformers repo
  2. Install requirements (including deepspeed: pip install deepspeed)
  3. Run summarization example without deepspeed:
    python examples/seq2seq/run_summarization.py --model_name_or_path t5-small --do_train --do_eval --dataset_name cnn_dailymail --dataset_config "3.0.0" --source_prefix "summarize: " --output_dir /tmp/tst-summarization --per_device_train_batch_size=4 --per_device_eval_batch_size=4 --overwrite_output_dir --predict_with_genera
  4. Run summarization example with deepspeed:
    deepspeed examples/seq2seq/run_summarization.py --model_name_or_path t5-small --do_train --do_eval --dataset_name cnn_dailymail --dataset_config "3.0.0" --source_prefix "summarize: " --output_dir /tmp/tst-summarization --per_device_train_batch_size=4 --per_device_eval_batch_size=4 --overwrite_output_dir --predict_with_generate --deepspeed examples/tests/deepspeed/ds_config.json

Expected behavior

I would expect that using DeepSpeed would reduce the amount of GPU memory used.

@oriyor
Author

oriyor commented Mar 27, 2021

Also adding the logs from the beginning of training with deepspeed:

deepspeed examples/seq2seq/run_summarization.py --model_name_or_path t5-small --do_train --do_eval --dataset_name cnn_dailymail --dataset_config "3.0.0" --source_prefix "summarize: " --output_dir /tmp/tst-summarization --per_device_train_batch_size=4 --per_device_eval_batch_size=4 --overwrite_output_dir --predict_with_generate --deepspeed examples/tests/deepspeed/ds_config.json
[2021-03-27 17:02:34,357] [WARNING] [runner.py:117:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
[2021-03-27 17:02:34,381] [INFO] [runner.py:358:main] cmd = /media/disk1/oriyor/hf_venv_3.6/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMF19 --master_addr=127.0.0.1 --master_port=29500 examples/seq2seq/run_summarization.py --model_name_or_path t5-small --do_train --do_eval --dataset_name cnn_dailymail --dataset_config 3.0.0 --source_prefix summarize: --output_dir /tmp/tst-summarization --per_device_train_batch_size=4 --per_device_eval_batch_size=4 --overwrite_output_dir --predict_with_generate --deepspeed examples/tests/deepspeed/ds_config.json
[2021-03-27 17:02:34,981] [INFO] [launch.py:80:main] WORLD INFO DICT: {'localhost': [0]}
[2021-03-27 17:02:34,981] [INFO] [launch.py:89:main] nnodes=1, num_local_procs=1, node_rank=0
[2021-03-27 17:02:34,981] [INFO] [launch.py:101:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0]})
[2021-03-27 17:02:34,981] [INFO] [launch.py:102:main] dist_world_size=1
[2021-03-27 17:02:34,981] [INFO] [launch.py:105:main] Setting CUDA_VISIBLE_DEVICES=0
[2021-03-27 17:02:36,820] [INFO] [distributed.py:47:init_distributed] Initializing torch distributed with backend: nccl
WARNING:main:Process rank: 0, device: cuda:0, n_gpu: 1distributed training: True, 16-bits training: False
INFO:main:Training/evaluation parameters Seq2SeqTrainingArguments(output_dir='/tmp/tst-summarization', overwrite_output_dir=True, do_train=True, do_eval=True, do_predict=False, evaluation_strategy=<IntervalStrategy.NO: 'no'>, prediction_loss_only=False, per_device_train_batch_size=4, per_device_eval_batch_size=4, per_gpu_train_batch_size=None, per_gpu_eval_batch_size=None, gradient_accumulation_steps=1, eval_accumulation_steps=None, learning_rate=5e-05, weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, max_grad_norm=1.0, num_train_epochs=3.0, max_steps=-1, lr_scheduler_type=<SchedulerType.LINEAR: 'linear'>, warmup_ratio=0.0, warmup_steps=0, logging_dir='runs/Mar27_17-02-36_rack-jonathan-g04', logging_strategy=<IntervalStrategy.STEPS: 'steps'>, logging_first_step=False, logging_steps=500, save_strategy=<IntervalStrategy.STEPS: 'steps'>, save_steps=500, save_total_limit=None, no_cuda=False, seed=42, fp16=False, fp16_opt_level='O1', fp16_backend='auto', fp16_full_eval=False, local_rank=0, tpu_num_cores=None, tpu_metrics_debug=False, debug=False, dataloader_drop_last=False, eval_steps=500, dataloader_num_workers=0, past_index=-1, run_name='/tmp/tst-summarization', disable_tqdm=False, remove_unused_columns=True, label_names=None, load_best_model_at_end=False, metric_for_best_model=None, greater_is_better=None, ignore_data_skip=False, sharded_ddp=[], deepspeed='examples/tests/deepspeed/ds_config.json', label_smoothing_factor=0.0, adafactor=False, group_by_length=False, report_to=['tensorboard'], ddp_find_unused_parameters=None, dataloader_pin_memory=True, skip_memory_metrics=False, sortish_sampler=False, predict_with_generate=True)
WARNING:datasets.builder:Reusing dataset cnn_dailymail (/home/oriy/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/0a01b1abede4f646130574f203de57a293ded8a7a11e3406a539453afdfeb2c0)
loading configuration file https://huggingface.co/t5-small/resolve/main/config.json from cache at /home/oriy/.cache/huggingface/transformers/fe501e8fd6425b8ec93df37767fcce78ce626e34cc5edc859c662350cf712e41.406701565c0afd9899544c1cb8b93185a76f00b31e5ce7f6e18bbaef02241985
Model config T5Config {
"architectures": [
"T5WithLMHeadModel"
],
"d_ff": 2048,
"d_kv": 64,
"d_model": 512,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"feed_forward_proj": "relu",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"n_positions": 512,
"num_decoder_layers": 6,
"num_heads": 8,
"num_layers": 6,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"task_specific_params": {
"summarization": {
"early_stopping": true,
"length_penalty": 2.0,
"max_length": 200,
"min_length": 30,
"no_repeat_ngram_size": 3,
"num_beams": 4,
"prefix": "summarize: "
},
"translation_en_to_de": {
"early_stopping": true,
"max_length": 300,
"num_beams": 4,
"prefix": "translate English to German: "
},
"translation_en_to_fr": {
"early_stopping": true,
"max_length": 300,
"num_beams": 4,
"prefix": "translate English to French: "
},
"translation_en_to_ro": {
"early_stopping": true,
"max_length": 300,
"num_beams": 4,
"prefix": "translate English to Romanian: "
}
},
"transformers_version": "4.5.0.dev0",
"use_cache": true,
"vocab_size": 32128
}

loading configuration file https://huggingface.co/t5-small/resolve/main/config.json from cache at /home/oriy/.cache/huggingface/transformers/fe501e8fd6425b8ec93df37767fcce78ce626e34cc5edc859c662350cf712e41.406701565c0afd9899544c1cb8b93185a76f00b31e5ce7f6e18bbaef02241985
Model config T5Config {
"architectures": [
"T5WithLMHeadModel"
],
"d_ff": 2048,
"d_kv": 64,
"d_model": 512,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"feed_forward_proj": "relu",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"n_positions": 512,
"num_decoder_layers": 6,
"num_heads": 8,
"num_layers": 6,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"task_specific_params": {
"summarization": {
"early_stopping": true,
"length_penalty": 2.0,
"max_length": 200,
"min_length": 30,
"no_repeat_ngram_size": 3,
"num_beams": 4,
"prefix": "summarize: "
},
"translation_en_to_de": {
"early_stopping": true,
"max_length": 300,
"num_beams": 4,
"prefix": "translate English to German: "
},
"translation_en_to_fr": {
"early_stopping": true,
"max_length": 300,
"num_beams": 4,
"prefix": "translate English to French: "
},
"translation_en_to_ro": {
"early_stopping": true,
"max_length": 300,
"num_beams": 4,
"prefix": "translate English to Romanian: "
}
},
"transformers_version": "4.5.0.dev0",
"use_cache": true,
"vocab_size": 32128
}

loading file https://huggingface.co/t5-small/resolve/main/spiece.model from cache at /home/oriy/.cache/huggingface/transformers/65fc04e21f45f61430aea0c4fedffac16a4d20d78b8e6601d8d996ebefefecd2.3b69006860e7b5d0a63ffdddc01ddcd6b7c318a6f4fd793596552c741734c62d
loading file https://huggingface.co/t5-small/resolve/main/tokenizer.json from cache at /home/oriy/.cache/huggingface/transformers/06779097c78e12f47ef67ecb728810c2ae757ee0a9efe9390c6419783d99382d.8627f1bd5d270a9fd2e5a51c8bec3223896587cc3cfe13edeabb0992ab43c529
loading file https://huggingface.co/t5-small/resolve/main/added_tokens.json from cache at None
loading file https://huggingface.co/t5-small/resolve/main/special_tokens_map.json from cache at None
loading file https://huggingface.co/t5-small/resolve/main/tokenizer_config.json from cache at None
loading weights file https://huggingface.co/t5-small/resolve/main/pytorch_model.bin from cache at /home/oriy/.cache/huggingface/transformers/fee5a3a0ae379232608b6eed45d2d7a0d2966b9683728838412caccc41b4b0ed.ddacdc89ec88482db20c676f0861a336f3d0409f94748c209847b49529d73885
All model checkpoint weights were used when initializing T5ForConditionalGeneration.

All the weights of T5ForConditionalGeneration were initialized from the model checkpoint at t5-small.
If your task is similar to the task the model of the checkpoint was trained on, you can already use T5ForConditionalGeneration for predictions without further training.
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/oriy/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/0a01b1abede4f646130574f203de57a293ded8a7a11e3406a539453afdfeb2c0/cache-3c2d8ad9af1d1a3e.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/oriy/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/0a01b1abede4f646130574f203de57a293ded8a7a11e3406a539453afdfeb2c0/cache-2e7e82c8de410d07.arrow
Updating the scheduler config from examples/tests/deepspeed/ds_config.json with other command line arguments
setting optimizer.params.lr to 5e-05
setting optimizer.params.betas to [0.9, 0.999]
setting optimizer.params.eps to 1e-08
setting optimizer.params.weight_decay to 0.0
Updating the scheduler config from examples/tests/deepspeed/ds_config.json with other command line arguments
setting scheduler.params.warmup_max_lr to 5e-05
setting scheduler.params.warmup_num_steps to 0
[2021-03-27 17:02:46,871] [INFO] [logging.py:60:log_dist] [Rank 0] DeepSpeed info: version=0.3.13, git-hash=unknown, git-branch=unknown
[2021-03-27 17:02:48,970] [INFO] [engine.py:77:_initialize_parameter_parallel_groups] data_parallel_size: 1, parameter_parallel_size: 1
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using tokenizers before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using tokenizers before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using tokenizers before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using tokenizers before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using tokenizers before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Using /home/oriy/.cache/torch_extensions as PyTorch extensions root...
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using tokenizers before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using tokenizers before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using tokenizers before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Detected CUDA files, patching ldflags
Emitting ninja build file /home/oriy/.cache/torch_extensions/cpu_adam/build.ninja...
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using tokenizers before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
ninja: no work to do.
Loading extension module cpu_adam...
Time to load cpu_adam op: 0.43370747566223145 seconds
Adam Optimizer #0 is created with AVX2 arithmetic capability.
Config: alpha=0.000050, betas=(0.900000, 0.999000), weight_decay=0.000000, adam_w=1
[2021-03-27 17:02:52,144] [INFO] [engine.py:602:_configure_optimizer] Using DeepSpeed Optimizer param name adam as basic optimizer
[2021-03-27 17:02:52,145] [INFO] [engine.py:606:_configure_optimizer] DeepSpeed Basic Optimizer = DeepSpeedCPUAdam
Checking ZeRO support for optimizer=DeepSpeedCPUAdam type=<class 'deepspeed.ops.adam.cpu_adam.DeepSpeedCPUAdam'>
[2021-03-27 17:02:52,145] [INFO] [logging.py:60:log_dist] [Rank 0] Creating fp16 ZeRO stage 2 optimizer
Using /home/oriy/.cache/torch_extensions as PyTorch extensions root...
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using tokenizers before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using tokenizers before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using tokenizers before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Emitting ninja build file /home/oriy/.cache/torch_extensions/utils/build.ninja...
Building extension module utils...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using tokenizers before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
ninja: no work to do.
Loading extension module utils...
Time to load utils op: 0.29197263717651367 seconds
[2021-03-27 17:02:52,437] [INFO] [stage2.py:130:init] Reduce bucket size 200000000.0
[2021-03-27 17:02:52,438] [INFO] [stage2.py:131:init] Allgather bucket size 200000000.0
[2021-03-27 17:02:52,438] [INFO] [stage2.py:132:init] CPU Offload: True
[2021-03-27 17:02:52,846] [INFO] [stage2.py:399:init] optimizer state initialized
[2021-03-27 17:02:52,846] [INFO] [logging.py:60:log_dist] [Rank 0] DeepSpeed Final Optimizer = adam
[2021-03-27 17:02:52,847] [INFO] [engine.py:439:_configure_lr_scheduler] DeepSpeed using configured LR scheduler = WarmupLR
[2021-03-27 17:02:52,847] [INFO] [logging.py:60:log_dist] [Rank 0] DeepSpeed LR Scheduler = <deepspeed.runtime.lr_schedules.WarmupLR object at 0x7fea742ef2b0>
[2021-03-27 17:02:52,847] [INFO] [logging.py:60:log_dist] [Rank 0] step=0, skipped=0, lr=[5e-05], mom=[[0.9, 0.999]]
[2021-03-27 17:02:52,847] [INFO] [config.py:737:print] DeepSpeedEngine configuration:
[2021-03-27 17:02:52,847] [INFO] [config.py:741:print] activation_checkpointing_config {
"contiguous_memory_optimization": false,
"cpu_checkpointing": false,
"number_checkpoints": null,
"partition_activations": false,
"profile": false,
"synchronize_checkpoint_boundary": false
}
[2021-03-27 17:02:52,847] [INFO] [config.py:741:print] allreduce_always_fp32 ........ False
[2021-03-27 17:02:52,847] [INFO] [config.py:741:print] amp_enabled .................. False
[2021-03-27 17:02:52,847] [INFO] [config.py:741:print] amp_params ................... False
[2021-03-27 17:02:52,847] [INFO] [config.py:741:print] checkpoint_tag_validation_enabled True
[2021-03-27 17:02:52,847] [INFO] [config.py:741:print] checkpoint_tag_validation_fail False
[2021-03-27 17:02:52,847] [INFO] [config.py:741:print] disable_allgather ............ False
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] dump_state ................... False
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] dynamic_loss_scale_args ...... {'init_scale': 4294967296, 'scale_window': 1000, 'delayed_shift': 2, 'min_scale': 1}
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] elasticity_enabled ........... False
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] flops_profiler_config ........ {
"detailed": true,
"enabled": false,
"module_depth": -1,
"profile_step": 1,
"top_modules": 3
}
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] fp16_enabled ................. True
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] global_rank .................. 0
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] gradient_accumulation_steps .. 1
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] gradient_clipping ............ 1.0
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] gradient_predivide_factor .... 1.0
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] initial_dynamic_scale ........ 4294967296
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] loss_scale ................... 0
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] memory_breakdown ............. False
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] optimizer_legacy_fusion ...... False
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] optimizer_name ............... adam
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] optimizer_params ............. {'lr': 5e-05, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0.0}
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] pipeline ..................... {'stages': 'auto', 'partition': 'best', 'seed_layers': False, 'activation_checkpoint_interval': 0}
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] pld_enabled .................. False
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] pld_params ................... False
[2021-03-27 17:02:52,848] [INFO] [config.py:741:print] prescale_gradients ........... False
[2021-03-27 17:02:52,849] [INFO] [config.py:741:print] scheduler_name ............... WarmupLR
[2021-03-27 17:02:52,849] [INFO] [config.py:741:print] scheduler_params ............. {'warmup_min_lr': 0, 'warmup_max_lr': 5e-05, 'warmup_num_steps': 0}
[2021-03-27 17:02:52,849] [INFO] [config.py:741:print] sparse_attention ............. None
[2021-03-27 17:02:52,849] [INFO] [config.py:741:print] sparse_gradients_enabled ..... False
[2021-03-27 17:02:52,849] [INFO] [config.py:741:print] steps_per_print .............. 10
[2021-03-27 17:02:52,849] [INFO] [config.py:741:print] tensorboard_enabled .......... False
[2021-03-27 17:02:52,849] [INFO] [config.py:741:print] tensorboard_job_name ......... DeepSpeedJobName
[2021-03-27 17:02:52,849] [INFO] [config.py:741:print] tensorboard_output_path ......
[2021-03-27 17:02:52,849] [INFO] [config.py:741:print] train_batch_size ............. 4
[2021-03-27 17:02:52,849] [INFO] [config.py:741:print] train_micro_batch_size_per_gpu 4
[2021-03-27 17:02:52,849] [INFO] [config.py:741:print] wall_clock_breakdown ......... False
[2021-03-27 17:02:52,849] [INFO] [config.py:741:print] world_size ................... 1
[2021-03-27 17:02:52,849] [INFO] [config.py:741:print] zero_allow_untested_optimizer False
[2021-03-27 17:02:52,849] [INFO] [config.py:741:print] zero_config .................. {
"allgather_bucket_size": 200000000.0,
"allgather_partitions": true,
"contiguous_gradients": true,
"cpu_offload": true,
"cpu_offload_params": false,
"cpu_offload_use_pin_memory": false,
"elastic_checkpoint": true,
"load_from_fp32_weights": true,
"max_live_parameters": 1000000000,
"max_reuse_distance": 1000000000,
"overlap_comm": true,
"param_persistence_threshold": 100000,
"prefetch_bucket_size": 50000000,
"reduce_bucket_size": 200000000.0,
"reduce_scatter": true,
"stage": 2,
"sub_group_size": 1000000000000
}
[2021-03-27 17:02:52,849] [INFO] [config.py:741:print] zero_enabled ................. True
[2021-03-27 17:02:52,849] [INFO] [config.py:741:print] zero_optimization_stage ...... 2
[2021-03-27 17:02:52,850] [INFO] [config.py:748:print] json = {
"fp16":{
"enabled":true,
"hysteresis":2,
"loss_scale":0,
"loss_scale_window":1000,
"min_loss_scale":1
},
"gradient_accumulation_steps":1,
"gradient_clipping":1.0,
"optimizer":{
"params":{
"betas":[
0.9,
0.999
],
"eps":1e-08,
"lr":5e-05,
"weight_decay":0.0
},
"type":"Adam"
},
"scheduler":{
"params":{
"warmup_max_lr":5e-05,
"warmup_min_lr":0,
"warmup_num_steps":0
},
"type":"WarmupLR"
},
"train_micro_batch_size_per_gpu":4,
"zero_optimization":{
"allgather_bucket_size":200000000.0,
"allgather_partitions":true,
"contiguous_gradients":true,
"cpu_offload":true,
"overlap_comm":true,
"reduce_bucket_size":200000000.0,
"reduce_scatter":true,
"stage":2
}
}
Using /home/oriy/.cache/torch_extensions as PyTorch extensions root...
No modifications detected for re-loaded extension module utils, skipping build step...
Loading extension module utils...
Time to load utils op: 0.0005881786346435547 seconds
***** Running training *****
Num examples = 287113
Num Epochs = 3
Instantaneous batch size per device = 4
Total train batch size (w. parallel, distributed & accumulation) = 4
Gradient Accumulation steps = 1
Total optimization steps = 215337
0%| | 0/215337 [00:00<?, ?it/s][2021-03-27 17:02:53,333] [INFO] [stage2.py:1391:step] [deepspeed] fp16 dynamic loss scale overflow! Rank 0 Skipping step. Attempted loss scale: 4294967296, reducing to 4294967296
0%| | 1/215337 [00:00<26:38:16, 2.25it/s][2021-03-27 17:02:53,687] [INFO] [stage2.py:1391:step] [deepspeed] fp16 dynamic loss scale overflow! Rank 0 Skipping step. Attempted loss scale: 4294967296, reducing to 2147483648.0
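The last two log lines show the loss scale staying at 4294967296 after the first overflow and halving only after the second. A simplified model of dynamic loss scaling with hysteresis reproduces that pattern. This is a sketch, not DeepSpeed's actual code: the class name and the `scale_factor` of 2 are assumptions, while the initial scale, hysteresis of 2, and minimum scale of 1 are taken from the config printed in the log.

```python
class DynamicLossScaleSketch:
    """Simplified overflow handling for dynamic loss scaling (hypothetical class)."""

    def __init__(self, init_scale=4294967296, hysteresis=2, min_scale=1, scale_factor=2):
        self.scale = init_scale
        self.min_scale = min_scale
        self.scale_factor = scale_factor  # assumed; consistent with the halving in the log
        self.remaining = hysteresis       # overflows tolerated before the scale shrinks

    def on_overflow(self):
        """Record an overflowed (skipped) step and return the scale for the next step."""
        if self.remaining <= 1:
            # Tolerance used up: shrink the scale, but never below the minimum.
            self.scale = max(self.scale / self.scale_factor, self.min_scale)
        else:
            # Still within tolerance: keep the scale and consume one allowance.
            self.remaining -= 1
        return self.scale


scaler = DynamicLossScaleSketch()
first = scaler.on_overflow()
second = scaler.on_overflow()
print(first, second)
```

Skipped steps like these at the very start of fp16 training are the scaler searching for a usable scale, so on their own they do not indicate a problem.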

@stas00
Contributor

stas00 commented Mar 27, 2021

Next week I hope #10753 will be finished, but for now here are the results on an RTX 3090 24GB card with the unfinished ZeRO-3 PR.

As you can see, DeepSpeed ZeRO-3's CPU offload is way more memory-efficient:

# baseline

BS=4; CUDA_VISIBLE_DEVICES=0 PYTHONPATH=src USE_TF=0 python examples/seq2seq/run_translation.py --model_name_or_path t5-small --output_dir /tmp/zero3 --overwrite_output_dir --max_train_samples 64 --max_val_samples 64 --max_source_length 128 --max_target_length 128 --val_max_target_length 128 --do_train --num_train_epochs 1 --per_device_train_batch_size $BS --per_device_eval_batch_size $BS --learning_rate 3e-3 --warmup_steps 500 --predict_with_generate --logging_steps 0 --save_steps 0 --eval_steps 1 --group_by_length  --adafactor --dataset_name wmt16 --dataset_config ro-en --source_lang en --target_lang ro --source_prefix "translate English to Romanian: " 

***** train metrics *****
  epoch                      =   1.0
  init_mem_cpu_alloc_delta   =   3MB
  init_mem_cpu_peaked_delta  =   0MB
  init_mem_gpu_alloc_delta   = 230MB
  init_mem_gpu_peaked_delta  =   0MB
  train_mem_cpu_alloc_delta  =  60MB
  train_mem_cpu_peaked_delta =   0MB
  train_mem_gpu_alloc_delta  = 231MB
  train_mem_gpu_peaked_delta = 226MB
  train_runtime              = 3.619
  train_samples              =    64
  train_samples_per_second   = 4.421

# zero2


BS=4; PYTHONPATH=src USE_TF=0 deepspeed --num_gpus 1 examples/seq2seq/run_translation.py --model_name_or_path t5-small --output_dir /tmp/zero3 --overwrite_output_dir --max_train_samples 64 --max_val_samples 64 --max_source_length 128 --max_target_length 128 --val_max_target_length 128 --do_train --num_train_epochs 1 --per_device_train_batch_size $BS --per_device_eval_batch_size $BS --learning_rate 3e-3 --warmup_steps 500 --predict_with_generate --logging_steps 0 --save_steps 0 --eval_steps 1 --group_by_length  --adafactor --dataset_name wmt16 --dataset_config ro-en --source_lang en --target_lang ro --source_prefix "translate English to Romanian: " --deepspeed examples/tests/deepspeed/ds_config_zero2.json


***** train metrics *****
  epoch                      =    1.0
  init_mem_cpu_alloc_delta   =    7MB
  init_mem_cpu_peaked_delta  =    0MB
  init_mem_gpu_alloc_delta   =    0MB
  init_mem_gpu_peaked_delta  =    0MB
  train_mem_cpu_alloc_delta  =   70MB
  train_mem_cpu_peaked_delta =    0MB
  train_mem_gpu_alloc_delta  =  148MB
  train_mem_gpu_peaked_delta = 3559MB
  train_runtime              = 5.0669
  train_samples              =     64
  train_samples_per_second   =  3.158


# zero3


BS=4; PYTHONPATH=src USE_TF=0 deepspeed --num_gpus 1 examples/seq2seq/run_translation.py --model_name_or_path t5-small --output_dir /tmp/zero3 --overwrite_output_dir --max_train_samples 64 --max_val_samples 64 --max_source_length 128 --max_target_length 128 --val_max_target_length 128 --do_train --num_train_epochs 1 --per_device_train_batch_size $BS --per_device_eval_batch_size $BS --learning_rate 3e-3 --warmup_steps 500 --predict_with_generate --logging_steps 0 --save_steps 0 --eval_steps 1 --group_by_length  --adafactor --dataset_name wmt16 --dataset_config ro-en --source_lang en --target_lang ro --source_prefix "translate English to Romanian: " --deepspeed examples/tests/deepspeed/ds_config_zero3.json

***** train metrics *****
  epoch                      =    1.0
  init_mem_cpu_alloc_delta   =    7MB
  init_mem_cpu_peaked_delta  =    0MB
  init_mem_gpu_alloc_delta   =    0MB
  init_mem_gpu_peaked_delta  =    0MB
  train_mem_cpu_alloc_delta  =   71MB
  train_mem_cpu_peaked_delta =    0MB
  train_mem_gpu_alloc_delta  =  -52MB
  train_mem_gpu_peaked_delta =  244MB
  train_runtime              = 7.6324
  train_samples              =     64
  train_samples_per_second   =  2.096
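To compare the three runs at a glance, the GPU numbers above can be combined per run. This sketch assumes that adding `alloc_delta` and `peaked_delta` for a stage approximates the extra GPU memory that stage needed at its peak; that reading of the metrics is an assumption, and the dictionary layout and function name are made up for illustration. The values are copied from the three metric dumps.

```python
# GPU memory deltas in MB, copied from the three "train metrics" dumps above.
runs = {
    "baseline": {"init_alloc": 230, "init_peaked": 0, "train_alloc": 231, "train_peaked": 226},
    "zero2": {"init_alloc": 0, "init_peaked": 0, "train_alloc": 148, "train_peaked": 3559},
    "zero3": {"init_alloc": 0, "init_peaked": 0, "train_alloc": -52, "train_peaked": 244},
}


def train_stage_peak_mb(metrics):
    """Approximate extra GPU memory needed at the peak of the train stage.

    Assumes peak = memory still allocated at the end of the stage plus the
    amount that was allocated and freed again on top of it.
    """
    return metrics["train_alloc"] + metrics["train_peaked"]


peaks = {name: train_stage_peak_mb(m) for name, m in runs.items()}
for name, peak in sorted(peaks.items(), key=lambda item: item[1]):
    print(f"{name:>8}: {peak} MB")
```

Read this way, the ZeRO-2 run has by far the largest transient peak during training, while the ZeRO-3 run stays below the baseline.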

The config files are from the PR I linked to in the first paragraph.

So please give us a few more days. This also depends on DeepSpeed merging several PRs and making a new release.

@stas00 stas00 self-assigned this Mar 27, 2021
@stas00
Contributor

stas00 commented Mar 27, 2021

I suspect my cpu memory profiling functions are missing some allocations, which is odd. Surely, there must be more cpu memory used with cpu_offload. I will investigate this.

I suspect that tracemalloc doesn't track C++ allocations, which is what DeepSpeed makes. I might have to switch to sampling, but the GIL makes it hard to get correct results from Python threads.

edit: this should fix it: #10937
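For context on the tracemalloc point, here is a minimal sketch of what it does measure. `tracemalloc` hooks Python's own memory allocators, so a large Python-level buffer shows up in its peak. Memory that native code obtains from its own allocator would not appear unless the extension reports it explicitly. The helper names below are hypothetical.

```python
import tracemalloc


def traced_peak_bytes(fn):
    """Run fn and return the peak memory that tracemalloc observed, in bytes."""
    tracemalloc.start()
    try:
        fn()
        _current, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return peak


def allocate_python_buffer():
    # Allocated through Python's allocator, so tracemalloc can see it.
    buf = bytearray(10 * 1024 * 1024)
    return len(buf)


peak = traced_peak_bytes(allocate_python_buffer)
print(f"peak traced by tracemalloc: {peak / 2**20:.1f} MiB")
```

A process-level measure such as resident set size would be needed to capture allocations made outside the Python allocator.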

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed May 4, 2021