only main process should call _save on deepspeed zero3 #25959

Merged
merged 1 commit into huggingface:main on Sep 11, 2023

Conversation

zjjMaiMai
Contributor

Background

After #25817, trainer._save is called on every process, which raises a FileExistsError when the model is saved.

What does this PR do?

This PR fixes it: trainer._save is now called on the main process only.
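
To illustrate the failure mode, a minimal sketch (hypothetical helper names, not the actual Trainer source) of why an unguarded save races and how gating on args.should_save avoids it:

import os

# Hypothetical illustration: after #25817 every rank executed the save
# path, so two processes raced on the same output files and one of them
# could hit FileExistsError.

def save_unguarded(output_dir: str) -> None:
    os.makedirs(output_dir)  # second rank to arrive raises FileExistsError
    # ... write weights / config / tokenizer files ...

def save_guarded(trainer, output_dir: str, state_dict) -> None:
    # args.should_save is True only on the main process, so the
    # consolidated files are written exactly once.
    if trainer.args.should_save:
        trainer._save(output_dir, state_dict=state_dict)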

@amyeroberts
Collaborator

cc @muellerzr @pacman100

Comment on lines 2796 to 2800
if self.args.should_save:
    self._save(output_dir, state_dict=state_dict)
    # remove the dummy state_dict
    remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
self.model_wrapped.save_checkpoint(output_dir)
Contributor

This does not feel like the right solution, as should_save will check if we are the main process or not, and we should only ever be saving once. You can have self.model_wrapped.save_checkpoint under that self.args.should_save I believe, but I'm not sure this is the right "fix". Definitely cc @pacman100 here

Contributor Author

zjjMaiMai commented Sep 5, 2023

self.model_wrapped.save_checkpoint needs to be called on each process, because each process only holds part of the model weights under DeepSpeed ZeRO-3.
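
As an illustration of that split (a sketch, not the actual Trainer code; model_wrapped is the DeepSpeed engine here):

def save_zero3_sketch(trainer, output_dir: str, state_dict) -> None:
    # Collective call: under ZeRO-3 each rank holds only a shard of the
    # parameters, so EVERY rank must call save_checkpoint, or the ranks
    # block waiting on each other's shards.
    trainer.model_wrapped.save_checkpoint(output_dir)

    # Main-process-only call: writes the consolidated Hugging Face files
    # (config.json, tokenizer files, gathered weights).
    if trainer.args.should_save:
        trainer._save(output_dir, state_dict=state_dict)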

Contributor

Got it, thanks for the clarification!

@muellerzr muellerzr requested a review from pacman100 September 5, 2023 13:06
Contributor

pacman100 left a comment

Hello, thank you for the fix! However, it needs a correction, as explained in the comment:

if self.args.should_save:
    self._save(output_dir, state_dict=state_dict)
    # remove the dummy state_dict
    remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
Contributor

This would remove the legitimate model checkpoint when stage3_gather_16bit_weights_on_model_save=True. remove_dummy_checkpoint should only be called in the exception block.
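
Roughly the shape being asked for (a sketch of the corrected flow, not necessarily the exact merged diff):

try:
    # Succeeds only when stage3_gather_16bit_weights_on_model_save=True:
    # the full 16-bit state dict is gathered onto the main process.
    state_dict = self.accelerator.get_state_dict(self.deepspeed)
    if self.args.should_save:
        # A legitimate gathered checkpoint -- must not be removed below.
        self._save(output_dir, state_dict=state_dict)
except ValueError:
    # Gathering is disabled, so save with a dummy state_dict to still get
    # the config/tokenizer files, then delete only the dummy weight files.
    if self.args.should_save:
        self._save(output_dir, state_dict={})
    # remove the dummy state_dict -- only ever in this exception block
    remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
    # every rank writes its shard of the DeepSpeed checkpoint
    self.model_wrapped.save_checkpoint(output_dir)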

Contributor Author

Thanks for your explanation; it has been fixed.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@zjjMaiMai
Contributor Author

@pacman100 is there anything I need to do?

Contributor

pacman100 left a comment

Thank you for iterating, LGTM!

@amyeroberts
Collaborator

@zjjMaiMai One of the hub tests is failing, complaining that the base_model is empty when pushing to the Hub. Could you try running this test locally to see whether it's a result of the changes in this PR?

@zjjMaiMai
Contributor Author

zjjMaiMai commented Sep 6, 2023

> @zjjMaiMai One of the hub tests is failing, complaining that the base_model is empty when pushing to the Hub. Could you try running this test locally to see whether it's a result of the changes in this PR?

$ pytest tests/trainer/test_trainer.py -k 'test_push_to_hub'
================================================================================================================================ test session starts ================================================================================================================================
platform linux -- Python 3.9.2, pytest-7.4.1, pluggy-1.3.0
configfile: setup.cfg
plugins: timeout-2.1.0, hypothesis-6.84.2, dash-2.13.0, xdist-3.3.1, anyio-3.7.1
collected 95 items / 91 deselected / 4 selected                                                                                                                                                                                                                                     

tests/trainer/test_trainer.py ssss                                                                                                                                                                                                                                            [100%]

================================================================================================================================= warnings summary ==================================================================================================================================
../../../../../../home/.local/lib/python3.9/site-packages/_pytest/config/__init__.py:1376
  /home/.local/lib/python3.9/site-packages/_pytest/config/__init__.py:1376: PytestConfigWarning: Unknown config option: doctest_glob
  
    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=================================================================================================================== 4 skipped, 91 deselected, 1 warning in 1.69s ====================================================================================================================
$ git branch 
* fix_save_deepspeed_3
  main

@amyeroberts
Collaborator

@zjjMaiMai Could you try and rebase on main? This should resolve the failing tests.

@zjjMaiMai
Contributor Author

All green! @amyeroberts

@amyeroberts amyeroberts merged commit 7fd2d68 into huggingface:main Sep 11, 2023
parambharat pushed a commit to parambharat/transformers that referenced this pull request Sep 26, 2023
only main process should call _save on deepspeed zero3 (huggingface#25959)