Adam-mini can't save checkpoints when used with FSDP in the Hugging Face Trainer #5

Open
hahuyhoang411 opened this issue Jun 27, 2024 · 19 comments

@hahuyhoang411

Hi, it's me again. Training works great, but when it comes to saving the checkpoint I get the error below. Any ideas?

[rank0]:   File "/workspace/train.py", line 230, in <module>
[rank0]:     trainer_stats = trainer.train()
[rank0]:                     ^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/sft_trainer.py", line 440, in train
[rank0]:     output = super().train(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 1885, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2291, in _inner_training_loop
[rank0]:     self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2732, in _maybe_log_save_evaluate
[rank0]:     self._save_checkpoint(model, trial, metrics=metrics)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2815, in _save_checkpoint
[rank0]:     self._save_optimizer_and_scheduler(output_dir)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2920, in _save_optimizer_and_scheduler
[rank0]:     save_fsdp_optimizer(
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/accelerate/utils/fsdp_utils.py", line 163, in save_fsdp_optimizer
[rank0]:     optim_state = FSDP.optim_state_dict(model, optimizer)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1840, in optim_state_dict
[rank0]:     return FullyShardedDataParallel._optim_state_dict_impl(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1263, in _optim_state_dict_impl
[rank0]:     return _optim_state_dict(
[rank0]:            ^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_optim_utils.py", line 1971, in _optim_state_dict
[rank0]:     fsdp_osd_state = convert_fn(
[rank0]:                      ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_optim_utils.py", line 1794, in _convert_state_with_orig_params
[rank0]:     _gather_all_orig_param_state(
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_optim_utils.py", line 1688, in _gather_all_orig_param_state
[rank0]:     output_states = _allgather_orig_param_states(
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_optim_utils.py", line 1625, in _allgather_orig_param_states
[rank0]:     local_shard = torch.cat(local_buffers)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Tensors must have same number of dimensions: got 2 and 1
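
For what it's worth, the final RuntimeError is just torch.cat refusing to concatenate a 2-D and a 1-D tensor; a standalone repro of the same error message (independent of FSDP itself):

import torch

# torch.cat requires all inputs to have the same number of dimensions,
# which is exactly what fails inside _allgather_orig_param_states above.
torch.cat([torch.zeros(2, 3), torch.zeros(3)])
# RuntimeError: Tensors must have same number of dimensions: got 2 and 1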
@tikikun

tikikun commented Jun 27, 2024

Hi, I'm from the same team as @hahuyhoang411. We have been able to mitigate the issue by setting this option to False:

[screenshot: FSDP config with use_orig_params set to False]

Why is that the case? I noticed some speed loss after turning it off. Is there any way to enable it again without causing the issue?
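
For reference, a minimal sketch of what the workaround corresponds to at the raw PyTorch FSDP level; with the Hugging Face Trainer this flag is normally set through the FSDP/accelerate config rather than by wrapping manually, so treat this as illustrative only:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Sketch only: assumes torch.distributed is already initialized (e.g. via torchrun)
# and a CUDA device is available.
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).cuda()

# use_orig_params=False exposes only FSDP's flattened parameters; this is the
# setting that made checkpoint saving work in this thread, at some speed cost.
fsdp_model = FSDP(model, use_orig_params=False)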

@zyushun
Owner

zyushun commented Jun 27, 2024

Thanks @tikikun for the update. We are still working on this checkpoint-saving issue, but it is good to know that you worked out a solution.

I wonder what you mean by "speed loss". Did you observe:

Case 1: a slowdown in loss vs. iteration.
Case 2: loss vs. iteration is unaffected, but the time per iteration increases.

I personally guess you mean case 2, because the "use_orig_params" option does not seem to affect the optimizer trajectory. For case 2, a few possibilities come to mind:

Possibility 1: "use_orig_params = False" slows down some other operations that are not related to the optimizer.

Possibility 2: the current implementation of Adam-mini involves several "view(-1)"-type operations, which may cause the slowdown. Note that these "view(-1)" calls are not triggered when using "use_orig_params = True" because everything is already flattened before training.

We are still trying to fix the saving issue. Please feel free to share more of your findings; it would help a lot.
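
A quick way to tell the two cases apart is to log both the loss and the wall-clock time per step; a toy, self-contained sketch of the kind of measurement meant here (not tied to Adam-mini or FSDP):

import time
import torch

# Dummy model and optimizer, just to illustrate tracking sec/iter alongside loss.
model = torch.nn.Linear(512, 512)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)

for step in range(50):
    x = torch.randn(32, 512)
    t0 = time.perf_counter()
    loss = model(x).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    dt = time.perf_counter() - t0
    if step % 10 == 0:
        print(f"step {step}: loss={loss.item():.4f}, sec/iter={dt:.4f}")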

@tikikun

tikikun commented Jun 27, 2024

@zyushun Yes, it is case 2: each iteration takes more time. Thank you for the swift response. The results are very good; Adam-mini is truly as good as or better than AdamW with much less VRAM (much better than some previous optimizers we tried, such as Lion). It's really amazing work.

We will update you if we have any other issues or info; we are using this internally at the moment.

@han508

han508 commented Jun 29, 2024

@tikikun @hahuyhoang411 @zyushun Thanks for sharing. Have you tried this combination of the Trainer and DeepSpeed?

@tikikun

tikikun commented Jun 30, 2024

For DeepSpeed it's working; only FSDP fails.

@tikikun

tikikun commented Jun 30, 2024

> @tikikun @hahuyhoang411 @zyushun Thanks for sharing. Have you tried this combination of the Trainer and DeepSpeed?

Yes:
trainer + deepspeed = working well
trainer + fsdp = failing

Same results with zero_3 = True | False.

@zyushun
Owner

zyushun commented Jun 30, 2024

Hi @tikikun @hahuyhoang411 @han508, thanks a lot for the valuable discussion on this checkpoint-saving issue! We tried several fixes but unfortunately, so far the only effective way is to set "use_orig_params = False", as suggested by @tikikun and @hahuyhoang411.

To sum up, so far we have:

trainer + deepspeed = can save/load ckpt
trainer + fsdp = fail to save ckpt
trainer + fsdp + "use_orig_params = False" = can save/load ckpt

We have updated this info in the readme.md. We will keep updating if we find other approaches to fix this checkpoint saving issue. Thanks a lot for all the great suggestions and discussions!

@Hyperparticle

Hyperparticle commented Jun 30, 2024

The use of use_orig_params=True is required when using torch.compile with FSDP, as explained in these docs. Compilation offers another, very complementary training speedup, so it would be really awesome if there were any way to support use_orig_params=True.

@han508

han508 commented Jun 30, 2024

Can you provide an example with the Trainer? I tried to override the create_optimizer method, but it failed.

@chcoliang
Collaborator

chcoliang commented Jul 1, 2024

> Can you provide an example with the Trainer? I tried to override the create_optimizer method, but it failed.

Thanks for the great suggestion! Here is an example of create_optimizer. We have also included this example in the readme.md.

def create_optimizer(self) -> "torch.optim.Optimizer":
    if self.optimizer is None:
        if self.finetuning_args.use_adammini:
            # Build Adam-mini here; super().create_optimizer() keeps self.optimizer if it is already set.
            self.optimizer = Adam_mini(
                model=self.model, lr=self.args.learning_rate, weight_decay=self.args.weight_decay,
                beta1=self.args.adam_beta1, beta2=self.args.adam_beta2, model_sharding=True,
                n_embd=4096, n_head=32, n_query_groups=32,
            )
    return super().create_optimizer()
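
Note that n_embd, n_head, and n_query_groups above are model-specific: they should match the hidden size, the number of attention heads, and the number of query groups (as in grouped-query attention) of the model being fine-tuned, rather than being copied verbatim.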

@han508

han508 commented Jul 1, 2024

@chcoliang
Thank you very much for your patient response.
I ran into a problem after rewriting the trainer this way and deleting the AdamW configuration from the DeepSpeed config file: "Adam-mini is using model_sharding" is printed in the terminal, but there is no significant change in GPU memory or speed compared to AdamW. I used 8 T4 cards.
This is the launch command:
deepspeed --include=localhost:0,1,2,3,4,5,6,7 /home/han/training/train.py \
    --model_path /home/han/model/LLM-Research/Meta-Llama-3-8B \
    --data_path /home/han/data_sft/dataset_test_code \
    --group_loss False \
    --train_type 'sft' \
    --packing False \
    --use_unsloth False \
    --max_length 2048 \
    --output_dir /home/han/0509_7b \
    --do_train True \
    --do_eval False \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --gradient_checkpointing True \
    --save_strategy "epoch" \
    --logging_steps 10 \
    --lr_scheduler_type 'cosine' \
    --warmup_ratio 0.03 \
    --num_train_epochs 4 \
    --learning_rate 2e-5 \
    --fp16 True \
    --save_safetensors \
    --seed 2025 \
    --deepspeed /home/han/config/ds.json

@liziniu
Collaborator

liziniu commented Jul 2, 2024

@han508 Hi, I am one of the authors of Adam-mini. According to our calculation, Adam-mini should save about 4 GB per GPU card in your case. I suspect that something is wrong and the optimizer actually being used is still AdamW. Could you provide the DeepSpeed log to us, or tell us how to reproduce your results?
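
As a rough back-of-the-envelope check: Adam's second-moment state for a ~7.6B-parameter model in fp32 is about 7.6e9 × 4 bytes ≈ 30 GB, or roughly 4 GB per card when sharded across 8 GPUs under ZeRO, and that is the state Adam-mini mostly eliminates.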

An example of deepspeed log is given below:

[2024-06-29 16:39:24,514] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed info: version=0.14.2, git-hash=unknown, git-branch=unknown
[2024-06-29 16:39:24,514] [INFO] [comm.py:662:init_distributed] Distributed backend already initialized
[2024-06-29 16:39:28,419] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False
[2024-06-29 16:39:28,422] [INFO] [logging.py:96:log_dist] [Rank 0] Using client Optimizer as basic optimizer
[2024-06-29 16:39:28,423] [INFO] [logging.py:96:log_dist] [Rank 0] Removing param_group that has no 'params' in the basic Optimizer
[2024-06-29 16:39:28,462] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Basic Optimizer = DeepSpeedCPUAdam
[2024-06-29 16:39:28,463] [INFO] [utils.py:56:is_zero_supported_optimizer] Checking ZeRO support for optimizer=DeepSpeedCPUAdam type=<class 'deepspeed.ops.adam.cpu_adam.DeepSpeedCPUAdam'>
[2024-06-29 16:39:28,463] [INFO] [logging.py:96:log_dist] [Rank 0] Creating torch.bfloat16 ZeRO stage 2 optimizer
[2024-06-29 16:39:28,463] [INFO] [stage_1_and_2.py:148:init] Reduce bucket size 500,000,000
[2024-06-29 16:39:28,463] [INFO] [stage_1_and_2.py:149:init] Allgather bucket size 500,000,000
[2024-06-29 16:39:28,463] [INFO] [stage_1_and_2.py:150:init] CPU Offload: True
[2024-06-29 16:39:28,463] [INFO] [stage_1_and_2.py:151:init] Round robin gradient partitioning: False

@han508

han508 commented Jul 3, 2024

Thanks for your reply.
This is my DeepSpeed log.

[2024-07-03 22:39:26,401] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  NVIDIA Inference is only supported on Ampere and newer architectures
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.2
 [WARNING]  using untested triton version (2.2.0), only 1.0.0 is known to be compatible
[2024-07-03 22:39:28,926] [WARNING] [runner.py:202:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
[2024-07-03 22:39:28,926] [INFO] [runner.py:568:main] cmd = /data/hanze/miniconda3/envs/test/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMCwgMSwgMiwgMywgNCwgNSwgNiwgN119 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None /data/hanze/emo/train.py --model_path /data/hanze/pt_model/qwen2/ --data_path /data/hanze/dataset/dataset_test_code --group_loss False --train_type sft --packing False --use_unsloth False --max_length 2048 --output_dir /data/hanze/test --do_train True --do_eval False --per_device_train_batch_size 1 --gradient_accumulation_steps 4 --gradient_checkpointing True --save_strategy epoch --logging_steps 10 --lr_scheduler_type cosine --warmup_ratio 0.03 --num_train_epochs 4 --learning_rate 2e-5 --fp16 True --save_safetensors --seed 2025 --deepspeed /data/hanze/emo/config/ds1.json
[2024-07-03 22:39:30,810] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  NVIDIA Inference is only supported on Ampere and newer architectures
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.2
 [WARNING]  using untested triton version (2.2.0), only 1.0.0 is known to be compatible
[2024-07-03 22:39:33,483] [INFO] [launch.py:146:main] WORLD INFO DICT: {'localhost': [0, 1, 2, 3, 4, 5, 6, 7]}
[2024-07-03 22:39:33,483] [INFO] [launch.py:152:main] nnodes=1, num_local_procs=8, node_rank=0
[2024-07-03 22:39:33,483] [INFO] [launch.py:163:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0, 1, 2, 3, 4, 5, 6, 7]})
[2024-07-03 22:39:33,483] [INFO] [launch.py:164:main] dist_world_size=8
[2024-07-03 22:39:33,483] [INFO] [launch.py:168:main] Setting CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

[2024-07-03 22:39:38,019] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
[2024-07-03 22:39:38,294] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
[2024-07-03 22:39:38,338] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-07-03 22:39:38,359] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  NVIDIA Inference is only supported on Ampere and newer architectures
[2024-07-03 22:39:38,430] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
[2024-07-03 22:39:38,497] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.2
 [WARNING]  using untested triton version (2.2.0), only 1.0.0 is known to be compatible
[2024-07-03 22:39:38,514] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-07-03 22:39:38,540] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  using untested triton version (2.2.0), only 1.0.0 is known to be compatible
[2024-07-03 22:39:38,717] [INFO] [comm.py:637:init_distributed] cdb=None
[2024-07-03 22:39:38,917] [INFO] [comm.py:637:init_distributed] cdb=None
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.2
 [WARNING]  using untested triton version (2.2.0), only 1.0.0 is known to be compatible
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.2
 [WARNING]  using untested triton version (2.2.0), only 1.0.0 is known to be compatible
[2024-07-03 22:39:38,992] [INFO] [comm.py:637:init_distributed] cdb=None
[2024-07-03 22:39:38,992] [INFO] [comm.py:637:init_distributed] cdb=None
[2024-07-03 22:39:39,112] [INFO] [comm.py:637:init_distributed] cdb=None
[2024-07-03 22:39:39,132] [INFO] [comm.py:637:init_distributed] cdb=None
[2024-07-03 22:39:39,191] [INFO] [comm.py:637:init_distributed] cdb=None
[2024-07-03 22:39:39,196] [INFO] [comm.py:637:init_distributed] cdb=None
[2024-07-03 22:39:39,197] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
[2024-07-03 22:39:50,609] [INFO] [partition_parameters.py:345:__exit__] finished initializing model - num_params = 339, num_elems = 7.62B
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:10<00:00,  2.54s/it]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:10<00:00,  2.55s/it]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:10<00:00,  2.57s/it]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:10<00:00,  2.55s/it]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:10<00:00,  2.55s/it]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:10<00:00,  2.57s/it]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:10<00:00,  2.57s/it]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:10<00:00,  2.60s/it]
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Adam-mini is using model_sharding
Adam-mini is using model_sharding
Adam-mini is using model_sharding
Adam-mini is using model_sharding
Adam-mini is using model_sharding
Adam-mini is using model_sharding
Adam-mini is using model_sharding
Adam-mini is using model_sharding
Adam-mini is using model_sharding
Adam-mini is using model_sharding
Adam-mini is using model_sharding
Adam-mini is using model_sharding
Adam-mini is using model_sharding
Adam-mini is using model_sharding
Adam-mini is using model_sharding
Adam-mini is using model_sharding
Using /home/centos7/.cache/torch_extensions/py39_cu118 as PyTorch extensions root...
Using /home/centos7/.cache/torch_extensions/py39_cu118 as PyTorch extensions root...
Using /home/centos7/.cache/torch_extensions/py39_cu118 as PyTorch extensions root...
Using /home/centos7/.cache/torch_extensions/py39_cu118 as PyTorch extensions root...
Using /home/centos7/.cache/torch_extensions/py39_cu118 as PyTorch extensions root...
Using /home/centos7/.cache/torch_extensions/py39_cu118 as PyTorch extensions root...
Emitting ninja build file /home/centos7/.cache/torch_extensions/py39_cu118/cpu_adam/build.ninja...
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Using /home/centos7/.cache/torch_extensions/py39_cu118 as PyTorch extensions root...
ninja: no work to do.
Loading extension module cpu_adam...
Time to load cpu_adam op: 0.9709291458129883 seconds
Loading extension module cpu_adam...
Time to load cpu_adam op: 0.7428872585296631 seconds
Loading extension module cpu_adam...
Time to load cpu_adam op: 0.7151861190795898 seconds
Loading extension module cpu_adam...
Using /home/centos7/.cache/torch_extensions/py39_cu118 as PyTorch extensions root...
Time to load cpu_adam op: 0.7295939922332764 seconds
Emitting ninja build file /home/centos7/.cache/torch_extensions/py39_cu118/cpu_adam/build.ninja...
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module cpu_adam...
Time to load cpu_adam op: 1.0122733116149902 seconds
Loading extension module cpu_adam...
Loading extension module cpu_adam...
Time to load cpu_adam op: 1.2715320587158203 seconds
Time to load cpu_adam op: 1.0640928745269775 seconds
Loading extension module cpu_adam...
Time to load cpu_adam op: 1.1445198059082031 seconds
Parameter Offload: Total persistent parameters: 333312 in 141 params
wandb: Tracking run with wandb version 0.17.3
wandb: Run data is saved locally in /tmp/wandb/run-20240703_224103-l7rsc9ho
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run /data/hanze/test
wandb: ⭐️ View project at https://wandb.ai/hanze/huggingface
wandb: 🚀 View run at https://wandb.ai/hanze/huggingface/runs/l7rsc9ho
  0%|                                                                                                                         | 0/48 [00:00<?, ?it/s]

Application details:

  1. Use DeepSpeed (ZeRO-3 + offload) + Trainer.
  2. Remove the optimizer-related content from the DeepSpeed config file.
  3. I only rewrote the Trainer's create_optimizer method, using the example you provided.

@chcoliang
Collaborator

@han508 Hi. According to the log line "Loading extension module cpu_adam...", I think you are using cpu_adam instead of Adam-mini. From what I can tell, DeepSpeed only falls back to trainer.create_optimizer() when the "optimizer" entry in the DeepSpeed config is absent, so some "optimizer" content may still be in your config file. Please double-check the config file, as well as any default configuration that is not written in the file.

Further, the current version of Adam-mini does not support CPU offload for DeepSpeed according to our experiments, so we do not recommend using CPU offload when running Adam-mini with DeepSpeed.
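
For illustration, a minimal sketch of a ZeRO-3 config with no "optimizer" block and no optimizer offload, so that the Trainer's create_optimizer() (and hence Adam-mini) is actually used; the keys below are a hypothetical minimal example and should be adapted to your setup:

# Hypothetical minimal DeepSpeed config: no "optimizer" section, no optimizer CPU offload,
# so the HF Trainer's create_optimizer() supplies the optimizer instead of DeepSpeedCPUAdam.
ds_config = {
    "zero_optimization": {"stage": 3},
    "fp16": {"enabled": "auto"},
    "gradient_accumulation_steps": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
}
# Pass it as TrainingArguments(deepspeed=ds_config), or save it as a JSON file and
# point --deepspeed at that file.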

@han508

han508 commented Jul 3, 2024

@chcoliang
When I remove the cpu-offload parameter from the config file, it works.
Are there any plans to support CPU offload in the future?

@chcoliang
Collaborator

Sure! It is vitally important to support CPU offload. As we mentioned in the readme.md, the current version of Adam-mini supports CPU offload in FSDP, but not in DeepSpeed (due to some unexpected error). We are working on it and hopefully it will be done soon.

@awgu

awgu commented Jul 7, 2024

I may not have caught up on all of the context, but I wanted to mention that for FullyShardedDataParallel with use_orig_params=False, the parameter names in named_parameters() will not be the original parameter names but rather something like ._flat_param. This means that the name checks will return False, which is not what we want:

Adam-mini/Adam_mini.py, lines 53 to 56 in f98a1cd:

for name, param in self.model.named_parameters():
    if param.requires_grad:
        dic = {}
        dic["name"] = name

and, further down in the same file:

elif ("attn.attn.weight" in name or "attn.qkv.weight" in name):

@zyushun
Owner

zyushun commented Jul 8, 2024

Hi @awgu! Thanks a lot for mentioning this to us! This is a great catch.

So it seems we need another way to save checkpoints for FSDP. We will work on it and will update here as soon as we make any progress.

@han508

han508 commented Aug 21, 2024

Is there any recent progress on "deepspeed cpu offload"? I am looking forward to it.
