[dist_cp] Using no_dist with load still requires a process group #115591
Labels
module: distributed_checkpoint
oncall: distributed
🐛 Describe the bug
Using the no_dist parameter to load a distributed checkpoint without a process group does not work with dist_cp.load after #114304.

Minimal reproduction:
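A minimal sketch of the failing pattern (the single-tensor state dict and the /tmp/checkpoint path are placeholders; the save_state_dict call only exists to produce a single-process checkpoint to read back):

```python
import torch
import torch.distributed.checkpoint as dist_cp

# Note: no process group is initialized anywhere in this script.
state_dict = {"weight": torch.randn(4, 4)}

# Write a single-process checkpoint; with no_dist=True, save skips collectives.
dist_cp.save_state_dict(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter("/tmp/checkpoint"),
    no_dist=True,
)

# Reading it back through the newer load API fails even though no_dist=True.
dist_cp.load(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader("/tmp/checkpoint"),
    no_dist=True,
)
```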
This fails with a RuntimeError complaining that the default process group has not been initialized.
The expected behavior is that a process group is not required when no_dist is set. We use no_dist in our torch_xla distributed checkpointing tests, and this issue was caught while migrating our tests to use load instead of load_state_dict, which still works as sketched below.
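For contrast, a minimal sketch of the older entry point, assuming a checkpoint already exists at the placeholder path /tmp/checkpoint:

```python
import torch
import torch.distributed.checkpoint as dist_cp

# The older path: completes without a process group when no_dist=True.
state_dict = {"weight": torch.empty(4, 4)}
dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader("/tmp/checkpoint"),
    no_dist=True,
)
```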
Thank you!
Versions
Collecting environment information...
PyTorch version: 2.2.0a0+gitfd79995
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Debian GNU/Linux 10 (buster) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31
Python version: 3.8.17 (default, Jun 13 2023, 16:14:12) [GCC 8.3.0] (64-bit runtime)
Python platform: Linux-5.19.0-1022-gcp-x86_64-with-glibc2.2.5
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU(s): 240
On-line CPU(s) list: 0-239
Thread(s) per core: 2
Core(s) per socket: 60
Socket(s): 2
NUMA node(s): 2
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD EPYC 7B12
Stepping: 0
CPU MHz: 2250.000
BogoMIPS: 4500.00
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 32K
L1i cache: 32K
L2 cache: 512K
L3 cache: 16384K
NUMA node0 CPU(s): 0-59,120-179
NUMA node1 CPU(s): 60-119,180-239
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr arat npt nrip_save umip rdpid
Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] torch==2.2.0a0+gitfd79995
[pip3] torch-xla==2.2.0+gitcb32668
[pip3] torchvision==0.16.1
[conda] Could not collect
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225