Add support to benchmark multi-node checkpointing with default FSDP strategy (#151)

* Add option to select regular FSDP strategy

* Raise TypeError if strategy is none of the expected values

* Set torch version

* Undo changes to requirements.txt

* Add a function that constructs appropriate strategy object

* Pass use_orig_params to default FSDP strategy

* Raise ValueError instead of TypeError
abhibyreddi authored Oct 15, 2024
1 parent da087ea commit f4f9c5b
Showing 1 changed file with 41 additions and 41 deletions.
82 changes: 41 additions & 41 deletions dataflux_pytorch/benchmark/checkpointing/multinode/train.py
@@ -8,6 +8,8 @@
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos import WikiText2
from lightning.pytorch.strategies import FSDPStrategy
import torch.distributed
from torch.utils.data import DataLoader

from demo.lightning.checkpoint.multinode.fsspecfsdp import FSSpecFSDPStrategy
@@ -17,16 +19,46 @@

DF_FSDP_STRATEGY = "dataflux_fsdp"
FSSPEC_FSDP_STRATEGY = "fsspec_fsdp"
FSDP_STRATEGY = "fsdp"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=[DF_FSDP_STRATEGY, FSSPEC_FSDP_STRATEGY],
                        default=DF_FSDP_STRATEGY)
    parser.add_argument(
        '--strategy',
        choices=[DF_FSDP_STRATEGY, FSSPEC_FSDP_STRATEGY, FSDP_STRATEGY],
        default=DF_FSDP_STRATEGY)
    return parser.parse_args()


def get_strategy(choice, model, ckpt_dir_path):
    project = os.getenv("PROJECT")
    strategy = None
    if choice == DF_FSDP_STRATEGY:
        print("Using DatafluxFSDPStrategy")
        strategy = DatafluxFSDPStrategy(
            path=ckpt_dir_path,
            project_name=project,
            storage_client=None,
            model=model,
            state_dict_type="sharded",
            use_orig_params=False,
        )
    elif choice == FSSPEC_FSDP_STRATEGY:
        print("Using FSSpecFSDPStrategy")
        strategy = FSSpecFSDPStrategy(path=ckpt_dir_path,
                                      model=model,
                                      state_dict_type="sharded",
                                      use_orig_params=False)
    elif choice == FSDP_STRATEGY:
        print("Using FSDPStrategy.")
        strategy = FSDPStrategy(state_dict_type="sharded",
                                use_orig_params=False)
    else:
        raise ValueError("Invalid strategy.")
    return strategy


def main(project: str,
         ckpt_dir_path: str,
         save_only_latest: bool,
@@ -50,23 +82,7 @@ def main(project: str,
        enable_version_counter=True,
    )

    strategy = None
    if args.strategy == DF_FSDP_STRATEGY:
        print("Using DatafluxFSDPStrategy")
        strategy = DatafluxFSDPStrategy(
            path=ckpt_dir_path,
            project_name=project,
            storage_client=None,
            model=model,
            state_dict_type="sharded",
            use_orig_params=False,
        )
    else:
        print("Using FSSpecFSDPStrategy")
        strategy = FSSpecFSDPStrategy(path=ckpt_dir_path,
                                      model=model,
                                      state_dict_type="sharded",
                                      use_orig_params=False)
    strategy = get_strategy(args.strategy, model, ckpt_dir_path)
    min_epochs_save = int(os.environ.get("MIN_EPOCHS_SAVE", 4))
    max_epochs_save = int(os.environ.get("MAX_EPOCHS_SAVE", 5))
    max_steps_save = int(os.environ.get("MAX_STEPS_SAVE", 3))
@@ -107,24 +123,8 @@ def main(project: str,
        )
        model = DemoTransformer(vocab_size=dataset.vocab_size,
                                nlayers=int(os.environ.get("NUM_LAYERS", 10)))
        new_path = os.path.join(ckpt_restore_path, f'ckpt_{i}.ckpt/')
        strategy = None
        if args.strategy == DF_FSDP_STRATEGY:
            print("Using DatafluxFSDPStrategy")
            strategy = DatafluxFSDPStrategy(
                path=new_path,
                project_name=project,
                storage_client=None,
                model=model,
                state_dict_type="sharded",
                use_orig_params=False,
            )
        else:
            print("Using FSSpecFSDPStrategy")
            strategy = FSSpecFSDPStrategy(path=new_path,
                                          model=model,
                                          state_dict_type="sharded",
                                          use_orig_params=False)
        new_ckpt_dir_path = os.path.join(ckpt_restore_path, f'ckpt_{i}.ckpt/')
        strategy = get_strategy(args.strategy, model, new_ckpt_dir_path)
        trainer = Trainer(
            default_root_dir=ckpt_dir_path,
            plugins=[],
@@ -137,13 +137,13 @@
            devices=os.environ.get("NUM_DEVICES", 'auto'),
            num_nodes=num_nodes,
        )
        trainer.fit(model, dataloader, ckpt_path=new_path)
        trainer.fit(model, dataloader, ckpt_path=new_ckpt_dir_path)
        start = time.time()
        trainer.strategy.load_checkpoint(new_path)
        trainer.strategy.load_checkpoint(new_ckpt_dir_path)
        end = time.time()

        if torch.distributed.get_rank() == 0:
            print(f"Loaded checkpoint from {new_path}.")
            print(f"Loaded checkpoint from {new_ckpt_dir_path}.")
        load_checkpoint_times.append(end - start)

    if torch.distributed.get_rank() == 0:
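For illustration only (not part of the commit): the sketch below mirrors the plain-FSDP branch of get_strategy above in a self-contained form. It assumes only that lightning with FSDP support is installed; the pick_strategy helper and the launch command in the comment are hypothetical, and the Dataflux/FSSpec branches are omitted because they need a GCS project and bucket.

# Hypothetical per-node launch of the benchmark with the new option:
#   PROJECT=<gcp-project> python3 dataflux_pytorch/benchmark/checkpointing/multinode/train.py --strategy fsdp

from lightning.pytorch.strategies import FSDPStrategy

FSDP_STRATEGY = "fsdp"  # same constant the commit introduces


def pick_strategy(choice: str) -> FSDPStrategy:
    # Mirrors the FSDP_STRATEGY branch of get_strategy(); invalid values
    # raise ValueError, matching the behavior added in this commit.
    if choice == FSDP_STRATEGY:
        return FSDPStrategy(state_dict_type="sharded", use_orig_params=False)
    raise ValueError("Invalid strategy.")


strategy = pick_strategy("fsdp")
print(type(strategy).__name__)  # FSDPStrategy, ready to pass to Trainer(strategy=...)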