Migrate distributed state dict API (#2138)
mori360 authored Jan 8, 2025
Parent: 27fd3a1 · Commit: 38bf427
Showing 12 changed files with 250 additions and 196 deletions.
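
The call-site changes below all follow one migration pattern, condensed here for orientation. The signatures are inferred from the hunks in this commit, not from the library source, so treat this as a hedged sketch rather than the definitive torchtune API: load_from_full_model_state_dict drops its is_rank_zero argument, gather_cpu_state_dict now takes the sharded model itself (plus an adapter_weights_only flag) instead of a pre-materialized state_dict(), and both optimizer-state helpers gain a leading model argument.

    # Sketch of the before/after call pattern; names come from the hunks below.
    # Before this commit:
    training.load_from_full_model_state_dict(
        model, model_state_dict, self._device, self._is_rank_zero,
        strict=True, cpu_offload=fsdp_cpu_offload,
    )
    cpu_state_dict = training.gather_cpu_state_dict(
        self._model.state_dict(), self._is_rank_zero, device=self._device
    )
    opt_state_dict = training.get_full_optimizer_state_dict(
        self._optimizer, self._is_rank_zero, device=self._device
    )
    training.load_from_full_optimizer_state_dict(
        optimizer, opt_state_dict, self._device
    )

    # After this commit: each helper receives the (sharded) model, presumably
    # so sharding can be resolved from the model rather than a state dict.
    training.load_from_full_model_state_dict(
        model, model_state_dict, self._device,
        strict=True, cpu_offload=fsdp_cpu_offload,
    )
    cpu_state_dict = training.gather_cpu_state_dict(
        self._model, self._is_rank_zero, device=self._device,
        adapter_weights_only=self._save_adapter_weights_only,
    )
    opt_state_dict = training.get_full_optimizer_state_dict(
        self._model, self._optimizer, self._is_rank_zero, device=self._device
    )
    training.load_from_full_optimizer_state_dict(
        self._model, optimizer, opt_state_dict, self._device
    )
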
6 changes: 3 additions & 3 deletions recipes/dev/early_exit_finetune_distributed.py
@@ -556,7 +556,6 @@ def _setup_model(
             model,
             model_state_dict,
             self._device,
-            self._is_rank_zero,
             strict=True,
             cpu_offload=fsdp_cpu_offload,
         )
@@ -757,7 +756,7 @@ def save_checkpoint(
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
         cpu_state_dict = training.gather_cpu_state_dict(
-            self._model.state_dict(),
+            self._model,
             self._is_rank_zero,
             device=self._device,
         )
@@ -773,6 +772,7 @@ def save_checkpoint(
                 log.info("Getting optimizer state dict...")
             if not self._optimizer_in_bwd:
                 opt_state_dict = training.get_full_optimizer_state_dict(
+                    self._model,
                     self._optimizer,
                     self._is_rank_zero,
                     device=self._device,
@@ -781,7 +781,7 @@ def save_checkpoint(
                 opt_state_dict = {}
                 for param, opt in self._optim_ckpt_wrapper.optim_map.items():
                     opt_state_dict[param] = training.get_full_optimizer_state_dict(
-                        opt, self._is_rank_zero, device=self._device
+                        self._model, opt, self._is_rank_zero, device=self._device
                     )
             if self._is_rank_zero:
                 log.info(
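
The last two hunks above touch both checkpointing branches of this recipe. As a recap (names from the diff; the checkpoint wrapper's internals are assumptions based on the optim_map usage shown): a single optimizer yields one consolidated state dict, while optimizer-in-backward keeps one optimizer per parameter, so the recipe gathers one full state dict per optim_map entry; after this commit both calls also pass self._model.

    # Single-optimizer branch: one consolidated optimizer state dict.
    opt_state_dict = training.get_full_optimizer_state_dict(
        self._model, self._optimizer, self._is_rank_zero, device=self._device
    )

    # Optimizer-in-backward branch: one state dict per parameter, a dict
    # comprehension equivalent to the for-loop in the hunk above.
    opt_state_dict = {
        param: training.get_full_optimizer_state_dict(
            self._model, opt, self._is_rank_zero, device=self._device
        )
        for param, opt in self._optim_ckpt_wrapper.optim_map.items()
    }
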
3 changes: 2 additions & 1 deletion recipes/full_finetune_distributed.py
@@ -547,7 +547,6 @@ def _setup_model(
             model,
             model_state_dict,
             self._device,
-            self._is_rank_zero,
             strict=True,
             cpu_offload=fsdp_cpu_offload,
         )
@@ -602,6 +601,7 @@ def _setup_optimizer(
             for param in opt_state_dict.keys():
                 try:
                     training.load_from_full_optimizer_state_dict(
+                        self._model,
                         self._optim_ckpt_wrapper.state_dict()[param],
                         opt_state_dict[param],
                         self._device,
@@ -617,6 +617,7 @@ def _setup_optimizer(
             optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
             if opt_state_dict:
                 training.load_from_full_optimizer_state_dict(
+                    self._model,
                     optimizer,
                     opt_state_dict,
                     self._device,
7 changes: 3 additions & 4 deletions recipes/knowledge_distillation_distributed.py
@@ -461,7 +461,6 @@ def _setup_model(
                 model,
                 lora_weights_state_dict,
                 self._device,
-                self._is_rank_zero,
                 cpu_offload=fsdp_cpu_offload,
             )
         else:
@@ -486,7 +485,6 @@ def _setup_model(
             model,
             base_model_state_dict,
             self._device,
-            self._is_rank_zero,
             cpu_offload=fsdp_cpu_offload,
         )
         for m in model.modules():
@@ -574,7 +572,6 @@ def _setup_teacher_model(
             model,
             model_state_dict,
             self._device,
-            self._is_rank_zero,
             strict=True,
             cpu_offload=fsdp_cpu_offload,
         )
@@ -611,6 +608,7 @@ def _setup_optimizer(
         optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
         if opt_state_dict:
             training.load_from_full_optimizer_state_dict(
+                self._model,
                 optimizer,
                 opt_state_dict,
                 self._device,
@@ -705,13 +703,14 @@ def save_checkpoint(self, epoch: int) -> None:
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
         cpu_state_dict = training.gather_cpu_state_dict(
-            self._model.state_dict(),
+            self._model,
             self._is_rank_zero,
             device=self._device,
         )

         if intermediate_checkpoint:
             opt_state_dict = training.get_full_optimizer_state_dict(
+                self._model,
                 self._optimizer,
                 self._is_rank_zero,
                 device=self._device,
11 changes: 4 additions & 7 deletions recipes/lora_dpo_distributed.py
@@ -385,7 +385,6 @@ def _setup_model(
                 model,
                 lora_weights_state_dict,
                 self._device,
-                self._is_rank_zero,
                 cpu_offload=fsdp_cpu_offload,
             )
         else:
@@ -410,7 +409,6 @@ def _setup_model(
             model,
             base_model_state_dict,
             self._device,
-            self._is_rank_zero,
             cpu_offload=fsdp_cpu_offload,
         )
         is_dora = False
@@ -458,6 +456,7 @@ def _setup_optimizer(
         optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
         if opt_state_dict:
             training.load_from_full_optimizer_state_dict(
+                self._model,
                 optimizer,
                 opt_state_dict,
                 self._device,
@@ -546,17 +545,15 @@ def save_checkpoint(
         intermediate_checkpoint = epoch + 1 < self.total_epochs
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        state_dict = self._model.state_dict()
-        if self._save_adapter_weights_only:
-            state_dict = get_adapter_state_dict(state_dict, device=None)
-
         cpu_state_dict = training.gather_cpu_state_dict(
-            state_dict,
+            self._model,
             self._is_rank_zero,
             device=self._device,
+            adapter_weights_only=self._save_adapter_weights_only,
         )
         if intermediate_checkpoint:
             opt_state_dict = training.get_full_optimizer_state_dict(
+                self._model,
                 self._optimizer,
                 self._is_rank_zero,
                 device=self._device,
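
The save_checkpoint hunk above also shows where adapter filtering went: the recipe no longer materializes and filters a state dict before gathering. A before/after sketch taken from this hunk (get_adapter_state_dict is the helper the old code called; whether its import is also removed is not visible in this commit):

    # Before: filter adapter weights in the recipe, then gather to CPU.
    state_dict = self._model.state_dict()
    if self._save_adapter_weights_only:
        state_dict = get_adapter_state_dict(state_dict, device=None)
    cpu_state_dict = training.gather_cpu_state_dict(
        state_dict, self._is_rank_zero, device=self._device
    )

    # After: hand over the model plus a flag; filtering happens inside
    # gather_cpu_state_dict.
    cpu_state_dict = training.gather_cpu_state_dict(
        self._model,
        self._is_rank_zero,
        device=self._device,
        adapter_weights_only=self._save_adapter_weights_only,
    )
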
11 changes: 4 additions & 7 deletions recipes/lora_finetune_distributed.py
@@ -480,7 +480,6 @@ def _setup_model(
                 model,
                 lora_weights_state_dict,
                 self._device,
-                self._is_rank_zero,
                 cpu_offload=fsdp_cpu_offload,
             )
         else:
@@ -505,7 +504,6 @@ def _setup_model(
             model,
             base_model_state_dict,
             self._device,
-            self._is_rank_zero,
             cpu_offload=fsdp_cpu_offload,
         )
         for m in model.modules():
@@ -549,6 +547,7 @@ def _setup_optimizer(
         optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
         if opt_state_dict:
             training.load_from_full_optimizer_state_dict(
+                self._model,
                 optimizer,
                 opt_state_dict,
                 self._device,
@@ -656,14 +655,11 @@ def save_checkpoint(

         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        state_dict = self._model.state_dict()
-        if self._save_adapter_weights_only:
-            state_dict = get_adapter_state_dict(state_dict, device=None)
-
         cpu_state_dict = training.gather_cpu_state_dict(
-            state_dict,
+            self._model,
             self._is_rank_zero,
             device=self._device,
+            adapter_weights_only=self._save_adapter_weights_only,
         )
         utils.log_rank_zero(
             log,
@@ -673,6 +669,7 @@ def save_checkpoint(
         if intermediate_checkpoint:
             utils.log_rank_zero(log, "Retrieving optimizer state dict...")
             opt_state_dict = training.get_full_optimizer_state_dict(
+                self._model,
                 self._optimizer,
                 self._is_rank_zero,
                 device=self._device,
2 changes: 0 additions & 2 deletions recipes/lora_finetune_distributed_multi_dataset.py
@@ -473,7 +473,6 @@ def _setup_model(
                 model,
                 lora_weights_state_dict,
                 self._device,
-                self._is_rank_zero,
                 cpu_offload=fsdp_cpu_offload,
             )
         else:
@@ -500,7 +499,6 @@ def _setup_model(
             model,
             base_model_state_dict,
             self._device,
-            self._is_rank_zero,
             cpu_offload=fsdp_cpu_offload,
         )
         for m in model.modules():
8 changes: 5 additions & 3 deletions recipes/qat_distributed.py
@@ -508,7 +508,6 @@ def _setup_model(
             model,
             model_state_dict,
             self._device,
-            self._is_rank_zero,
             strict=True,
             cpu_offload=fsdp_cpu_offload,
         )
@@ -562,6 +561,7 @@ def _setup_optimizer(
             for param in opt_state_dict.keys():
                 try:
                     training.load_from_full_optimizer_state_dict(
+                        self._model,
                         self._optim_ckpt_wrapper.state_dict()[param],
                         opt_state_dict[param],
                         self._device,
@@ -577,6 +577,7 @@ def _setup_optimizer(
             optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
             if opt_state_dict:
                 training.load_from_full_optimizer_state_dict(
+                    self._model,
                     optimizer,
                     opt_state_dict,
                     self._device,
@@ -667,7 +668,7 @@ def save_checkpoint(
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
         cpu_state_dict = training.gather_cpu_state_dict(
-            self._model.state_dict(),
+            self._model,
             self._is_rank_zero,
             device=self._device,
         )
@@ -682,6 +683,7 @@ def save_checkpoint(
             utils.log_rank_zero(log, "Getting optimizer state dict...")
             if not self._optimizer_in_bwd:
                 opt_state_dict = training.get_full_optimizer_state_dict(
+                    self._model,
                     self._optimizer,
                     self._is_rank_zero,
                     device=self._device,
@@ -690,7 +692,7 @@ def save_checkpoint(
                 opt_state_dict = {}
                 for param, opt in self._optim_ckpt_wrapper.optim_map.items():
                     opt_state_dict[param] = training.get_full_optimizer_state_dict(
-                        opt, self._is_rank_zero, device=self._device
+                        self._model, opt, self._is_rank_zero, device=self._device
                     )
             utils.log_rank_zero(
                 log,
11 changes: 4 additions & 7 deletions recipes/qat_lora_finetune_distributed.py
@@ -525,7 +525,6 @@ def _setup_model(
                 model,
                 lora_weights_state_dict,
                 self._device,
-                self._is_rank_zero,
                 cpu_offload=fsdp_cpu_offload,
             )
         else:
@@ -550,7 +549,6 @@ def _setup_model(
             model,
             base_model_state_dict,
             self._device,
-            self._is_rank_zero,
             cpu_offload=fsdp_cpu_offload,
         )
         validate_missing_and_unexpected_for_lora(
@@ -589,6 +587,7 @@ def _setup_optimizer(
         optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
         if opt_state_dict:
             training.load_from_full_optimizer_state_dict(
+                self._model,
                 optimizer,
                 opt_state_dict,
                 self._device,
@@ -699,14 +698,11 @@ def save_checkpoint(

         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        state_dict = self._model.state_dict()
-        if self._save_adapter_weights_only:
-            state_dict = get_adapter_state_dict(state_dict, device=None)
-
         cpu_state_dict = training.gather_cpu_state_dict(
-            state_dict,
+            self._model,
             self._is_rank_zero,
             device=self._device,
+            adapter_weights_only=self._save_adapter_weights_only,
         )
         if self._is_rank_zero:
             log.info(
@@ -717,6 +713,7 @@ def save_checkpoint(
         if self._is_rank_zero:
             log.info("Retrieving optimizer state dict...")
             opt_state_dict = training.get_full_optimizer_state_dict(
+                self._model,
                 self._optimizer,
                 self._is_rank_zero,
                 device=self._device,
2 changes: 0 additions & 2 deletions tests/torchtune/modules/peft/test_dora.py
@@ -347,7 +347,6 @@ def _test_dora_distributed_init(self, load_dora_weights):
                 ffn,
                 adapter_state_dict,
                 device,
-                is_rank_zero,
             )
         if is_rank_zero:
             for dora_linear in [ffn.w1, ffn.w2, ffn.w3]:
@@ -377,7 +376,6 @@ def _test_dora_distributed_init(self, load_dora_weights):
             ffn,
             base_model_state_dict,
             device,
-            is_rank_zero,
         )

         # After this, everything should be off meta device
13 changes: 5 additions & 8 deletions tests/torchtune/training/test_distributed.py
@@ -169,10 +169,9 @@ def test_lora_state_dict(self):
         fsdp_optim_to_save.zero_grad()
         expected_model_sd = base_model.state_dict()
         expected_optim_sd = base_optim.state_dict()
-        model_full_sd = training.gather_cpu_state_dict(
-            fsdp_model_to_save.state_dict(), is_rank_zero
-        )
+        model_full_sd = training.gather_cpu_state_dict(fsdp_model_to_save, is_rank_zero)
         optim_full_sd = training.get_full_optimizer_state_dict(
+            fsdp_model_to_save,
             fsdp_optim_to_save,
             is_rank_zero,
         )
@@ -222,12 +221,12 @@ def test_lora_state_dict(self):
             fsdp_model_to_load,
             copy.deepcopy(base_model.state_dict()),
             torch.device("cuda"),
-            is_rank_zero,
         )
         fsdp_optim_to_load = torch.optim.Adam(
             fsdp_model_to_load.parameters(), weight_decay=0.01, lr=0.01
         )
         training.load_from_full_optimizer_state_dict(
+            fsdp_model_to_load,
             fsdp_optim_to_load,
             # mimic mmap=True where every rank see full SD
             copy.deepcopy(self._broadcast_full_state_dict(optim_full_sd)),
@@ -324,9 +323,7 @@ def _test_qlora_state_dict(self, enable_activation_checkpointing: bool):
         fsdp_model_to_save(inp)

         expected_model_sd = {k: v.cpu() for k, v in base_model.state_dict().items()}
-        model_full_sd = training.gather_cpu_state_dict(
-            fsdp_model_to_save.state_dict(), is_rank_zero
-        )
+        model_full_sd = training.gather_cpu_state_dict(fsdp_model_to_save, is_rank_zero)
         if is_rank_zero:
             self.assertEqual(set(model_full_sd.keys()), set(expected_model_sd.keys()))
             for key, value in model_full_sd.items():
@@ -357,7 +354,7 @@ def _test_qlora_state_dict(self, enable_activation_checkpointing: bool):
                 fully_shard(m)
             fully_shard(fsdp_model_to_load)
             training.load_from_full_model_state_dict(
-                fsdp_model_to_load, expected_model_sd, torch.device("cuda"), is_rank_zero
+                fsdp_model_to_load, expected_model_sd, torch.device("cuda")
             )
             fsdp_model_to_load(inp)
             sharded_model_sd = fsdp_model_to_load.state_dict()
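
The updated tests double as a usage reference for the migrated API. A condensed round trip, assuming an initialized process group, a fully_shard-wrapped fsdp_model, and its optimizer optim, as set up earlier in these tests; the gathered copies materialize on rank 0, and the test broadcasts the optimizer state dict before reloading:

    # Gather full, CPU-resident state dicts (materialized on rank 0).
    full_model_sd = training.gather_cpu_state_dict(fsdp_model, is_rank_zero)
    full_optim_sd = training.get_full_optimizer_state_dict(
        fsdp_model, optim, is_rank_zero
    )

    # Reload into a sharded model/optimizer under the new signatures.
    training.load_from_full_model_state_dict(
        fsdp_model, full_model_sd, torch.device("cuda")
    )
    training.load_from_full_optimizer_state_dict(
        fsdp_model, optim, full_optim_sd, torch.device("cuda")
    )
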
(2 more changed files in this commit are not shown.)