
[dist_cp] Using no_dist with load still requires a process group #115591

Closed
jonb377 opened this issue Dec 11, 2023 · 2 comments
Labels: module: distributed_checkpoint, oncall: distributed

Comments

jonb377 (Contributor) commented Dec 11, 2023

🐛 Describe the bug

Using the no_dist parameter to load a distributed checkpoint without a process group does not work with dist_cp.load after #114304.

Minimal reproduction:

import torch.distributed.checkpoint as dist_cp
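# With no_dist=True, a single-process load should not require init_process_group.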
dist_cp.load(state_dict={}, storage_reader=dist_cp.FileSystemReader('/tmp/foo'), no_dist=True)

This results in the following:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/workspaces/work/pytorch/torch/distributed/checkpoint/state_dict_loader.py", line 112, in load
    keys = _all_gather_keys(state_dict)
  File "/workspaces/work/pytorch/torch/distributed/checkpoint/utils.py", line 53, in _all_gather_keys
    gathered_keys: List[List[Any]] = [None] * dist.get_world_size()  # type: ignore[list-item]
  File "/workspaces/work/pytorch/torch/distributed/distributed_c10d.py", line 1592, in get_world_size
    return _get_group_size(group)
  File "/workspaces/work/pytorch/torch/distributed/distributed_c10d.py", line 836, in _get_group_size
    default_pg = _get_default_group()
  File "/workspaces/work/pytorch/torch/distributed/distributed_c10d.py", line 977, in _get_default_group
    raise ValueError(
ValueError: Default process group has not been initialized, please make sure to call init_process_group.

The expected behavior is that no process group is required when no_dist is set. We use no_dist in our torch_xla distributed checkpointing tests, and this issue was caught while migrating our tests from load_state_dict to load.
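
For comparison, here is a minimal sketch of the pre-migration usage through the legacy entry point (assuming a checkpoint already exists at /tmp/foo), which does not require a process group:

import torch.distributed.checkpoint as dist_cp

# load_state_dict honors no_dist=True and skips the collective path,
# so init_process_group is not needed for a single-process load.
state_dict = {}
dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader("/tmp/foo"),
    no_dist=True,
)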

Thank you!

Versions

Collecting environment information...
PyTorch version: 2.2.0a0+gitfd79995
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 10 (buster) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.8.17 (default, Jun 13 2023, 16:14:12) [GCC 8.3.0] (64-bit runtime)
Python platform: Linux-5.19.0-1022-gcp-x86_64-with-glibc2.2.5
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU(s): 240
On-line CPU(s) list: 0-239
Thread(s) per core: 2
Core(s) per socket: 60
Socket(s): 2
NUMA node(s): 2
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD EPYC 7B12
Stepping: 0
CPU MHz: 2250.000
BogoMIPS: 4500.00
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 32K
L1i cache: 32K
L2 cache: 512K
L3 cache: 16384K
NUMA node0 CPU(s): 0-59,120-179
NUMA node1 CPU(s): 60-119,180-239
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr arat npt nrip_save umip rdpid

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] torch==2.2.0a0+gitfd79995
[pip3] torch-xla==2.2.0+gitcb32668
[pip3] torchvision==0.16.1
[conda] Could not collect

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225

jonb377 (Contributor, Author) commented Dec 11, 2023

cc @LucasLLC

malfet added the oncall: distributed and module: distributed_checkpoint labels on Dec 12, 2023
wz337 (Contributor) commented Dec 12, 2023

@LucasLLC This should be a simple fix. We can skip the all-gather of keys when no_dist is True.
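
A standalone sketch of that guard (helper name hypothetical; the actual change landed in pytorch#115660): only gather keys across ranks when running distributed, otherwise fall back to the local keys.

import torch.distributed as dist

def _local_or_gathered_keys(state_dict, no_dist=False):
    # With no_dist=True (or no initialized process group) there is nothing to
    # gather, so return the local keys and never touch the default group.
    if no_dist or not dist.is_initialized():
        return sorted(state_dict.keys())
    # Distributed path: agree on the union of keys across all ranks.
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, sorted(state_dict.keys()))
    return sorted({key for rank_keys in gathered for key in rank_keys})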

guilhermeleobas pushed a commit to guilhermeleobas/pytorch that referenced this issue Dec 18, 2023
…d` (pytorch#115660)

Fixes expected behavior when `no_dist=True` in `state_dict_loader.load`

Fixes pytorch#115591

Pull Request resolved: pytorch#115660
Approved by: https://github.com/wz337, https://github.com/fegin
dmenig pushed a commit to dmenig/pytorch that referenced this issue Dec 21, 2023
…d` (pytorch#115660)

Fixes expected behavior when `no_dist=True` in `state_dict_loader.load`

Fixes pytorch#115591

Pull Request resolved: pytorch#115660
Approved by: https://github.com/wz337, https://github.com/fegin