From 4a1e2b88321a91db5a46b3c9b609ed22627c4464 Mon Sep 17 00:00:00 2001 From: Divya Rawal Date: Mon, 8 Jul 2024 23:33:26 +0000 Subject: [PATCH 01/10] Added benchmark folder with edits and flags to rerun benchmarks for dataflux lightning checkcpoints --- .../benchmark/checkpoint_lightning.py | 78 +++++++++++++++++ .../lightning_checkpoint_benchmark.py | 83 +++++++++++++++++++ .../dataflux_lightning_checkpoint.py | 2 +- 3 files changed, 162 insertions(+), 1 deletion(-) create mode 100644 dataflux_pytorch/benchmark/checkpoint_lightning.py create mode 100644 dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py diff --git a/dataflux_pytorch/benchmark/checkpoint_lightning.py b/dataflux_pytorch/benchmark/checkpoint_lightning.py new file mode 100644 index 00000000..e493fbfd --- /dev/null +++ b/dataflux_pytorch/benchmark/checkpoint_lightning.py @@ -0,0 +1,78 @@ +import os +import time +import torch + +from typing import Tuple +from torch.utils.data import DataLoader +from torch import Tensor +from torch.utils.data import DataLoader +from lightning import Trainer +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.demos import WikiText2, Transformer +from lightning.pytorch import LightningModule + +from dataflux_pytorch.lightning import DatafluxLightningCheckpoint + +class LightningTransformer(LightningModule): + def __init__(self, vocab_size: int = 33278, nlayers: int = 100) -> None: + super().__init__() + self.model = Transformer(vocab_size=vocab_size, nlayers=nlayers) + + def forward(self, inputs: Tensor, target: Tensor) -> Tensor: + return self.model(inputs, target) + + def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: + inputs, target = batch + output = self(inputs, target) + loss = torch.nn.functional.nll_loss(output, target.view(-1)) + return loss + + def configure_optimizers(self) -> torch.optim.Optimizer: + return torch.optim.SGD(self.model.parameters(), lr=0.1) + + def prepare_data(self) -> None: + 
WikiText2(download=True) + + def train_dataloader(self) -> DataLoader: + dataset = WikiText2() + return DataLoader(dataset) + + +def main(project: str, bucket: str, ckpt_dir_path: str, save_only_latest: bool, model_size: int = 100): + dataset = WikiText2() + dataloader = DataLoader(dataset, num_workers=1) + model = LightningTransformer(vocab_size=dataset.vocab_size, nlayers=model_size) + dataflux_ckpt = DatafluxLightningCheckpoint(project_name=project,bucket_name=bucket) + + # Save once per step, and if `save_only_latest`, replace the last checkpoint each time. + # Replacing is implemented by saving the new checkpoint, and then deleting the previous one. + # If `save_only_latest` is False, a new checkpoint is created for each step. + checkpoint_callback = ModelCheckpoint( + save_top_k=1 if save_only_latest else -1, + every_n_train_steps=1, + filename="checkpoint-{epoch:02d}-{step:02d}", + enable_version_counter=True, + ) + trainer = Trainer( + default_root_dir=ckpt_dir_path, + plugins=[dataflux_ckpt], + callbacks=[checkpoint_callback], + min_epochs=4, + max_epochs=5, + max_steps=5, + accelerator="cpu", + ) + start = time.time() + trainer.fit(model, dataloader) + end = time.time() + print(end-start) + +if __name__ == "__main__": + + main( + os.getenv("PROJECT"), + os.getenv("BUCKET"), + os.getenv("CKPT_DIR_PATH"), + os.getenv("SAVE_ONLY_LATEST") == "1", + int(os.getenv("CHECKPOINT_SIZE")), + ) diff --git a/dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py b/dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py new file mode 100644 index 00000000..340eccb9 --- /dev/null +++ b/dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py @@ -0,0 +1,83 @@ +import os +import time +import torch + +from typing import Tuple +from torch.utils.data import DataLoader +from torch import Tensor +from torch.utils.data import DataLoader +from lightning import Trainer +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.demos import 
WikiText2, Transformer +from lightning.pytorch import LightningModule + +from dataflux_pytorch.lightning import DatafluxLightningCheckpoint + +class LightningTransformer(LightningModule): + def __init__(self, vocab_size: int = 33278, nlayers: int = 100) -> None: + super().__init__() + self.model = Transformer(vocab_size=vocab_size, nlayers=nlayers) + + def forward(self, inputs: Tensor, target: Tensor) -> Tensor: + return self.model(inputs, target) + + def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: + inputs, target = batch + output = self(inputs, target) + loss = torch.nn.functional.nll_loss(output, target.view(-1)) + return loss + + def configure_optimizers(self) -> torch.optim.Optimizer: + return torch.optim.SGD(self.model.parameters(), lr=0.1) + + def prepare_data(self) -> None: + WikiText2(download=True) + + def train_dataloader(self) -> DataLoader: + dataset = WikiText2() + return DataLoader(dataset) + + +def main(project: str, bucket: str, ckpt_dir_path: str, save_only_latest: bool, model_size: int = 100, steps: int = 5): + dataset = WikiText2() + dataloader = DataLoader(dataset, num_workers=1) + model = LightningTransformer(vocab_size=dataset.vocab_size, nlayers=model_size) + dataflux_ckpt = DatafluxLightningCheckpoint(project_name=project,bucket_name=bucket) + # Save once per step, and if `save_only_latest`, replace the last checkpoint each time. + # Replacing is implemented by saving the new checkpoint, and then deleting the previous one. + # If `save_only_latest` is False, a new checkpoint is created for each step. 
+ checkpoint_callback = ModelCheckpoint( + save_top_k=1 if save_only_latest else -1, + every_n_train_steps=1, + filename="checkpoint-{epoch:02d}-{step:02d}", + enable_version_counter=True, + ) + trainer = Trainer( + default_root_dir=ckpt_dir_path, + plugins=[dataflux_ckpt], + callbacks=[checkpoint_callback], + min_epochs=4, + max_epochs=5, + max_steps=5, + accelerator="cpu", + ) + start = time.time() + trainer.fit(model, dataloader) + end = time.time() + print(end-start) + +if __name__ == "__main__": + + DEFAULT_SIZE = 100 + DEFAULT_STEPS = 5 + size = int(os.getenv("CHECKPOINT_SIZE'",DEFAULT_SIZE)) + steps = int(os.getenv("STEPS'",DEFAULT_STEPS)) + + main( + os.getenv("PROJECT"), + os.getenv("BUCKET"), + os.getenv("CKPT_DIR_PATH"), + os.getenv("SAVE_ONLY_LATEST") == "1", + size, + steps, + ) diff --git a/dataflux_pytorch/lightning/dataflux_lightning_checkpoint.py b/dataflux_pytorch/lightning/dataflux_lightning_checkpoint.py index 3deb20e3..d2356d0d 100644 --- a/dataflux_pytorch/lightning/dataflux_lightning_checkpoint.py +++ b/dataflux_pytorch/lightning/dataflux_lightning_checkpoint.py @@ -51,7 +51,7 @@ def save_checkpoint( key = self._parse_gcs_path(path) blob = self.bucket.blob(key) with blob.open("wb", ignore_flush=True) as blobwriter: - torch.save(checkpoint, blob.open("wb", ignore_flush=True)) + torch.save(checkpoint, blobwriter) def load_checkpoint( self, From e5c7037f32e7c20ffac4f1bb28c431311a574281 Mon Sep 17 00:00:00 2001 From: Divya Rawal Date: Mon, 8 Jul 2024 23:34:02 +0000 Subject: [PATCH 02/10] Pulled changes from client --- dataflux_client_python | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataflux_client_python b/dataflux_client_python index ac12adba..8b9cb11f 160000 --- a/dataflux_client_python +++ b/dataflux_client_python @@ -1 +1 @@ -Subproject commit ac12adba3aa3f4a6d27d3703223aac4b2f8b850c +Subproject commit 8b9cb11f62f7e7494ddbdd26d32d6e942d06b033 From ba91a0fa55d24fdcb697db2b3432a6e4a7c64900 Mon Sep 17 00:00:00 2001 From: 
Divya Rawal Date: Tue, 9 Jul 2024 18:43:44 +0000 Subject: [PATCH 03/10] Allow for testing on non dataflux lightning checkpoints --- dataflux_client_python | 2 +- .../benchmark/checkpoint_lightning.py | 78 ------------------- .../lightning_checkpoint_benchmark.py | 16 ++-- 3 files changed, 11 insertions(+), 85 deletions(-) delete mode 100644 dataflux_pytorch/benchmark/checkpoint_lightning.py diff --git a/dataflux_client_python b/dataflux_client_python index 8b9cb11f..4a4acb55 160000 --- a/dataflux_client_python +++ b/dataflux_client_python @@ -1 +1 @@ -Subproject commit 8b9cb11f62f7e7494ddbdd26d32d6e942d06b033 +Subproject commit 4a4acb5543758bc9910ac74a656564921dc3e225 diff --git a/dataflux_pytorch/benchmark/checkpoint_lightning.py b/dataflux_pytorch/benchmark/checkpoint_lightning.py deleted file mode 100644 index e493fbfd..00000000 --- a/dataflux_pytorch/benchmark/checkpoint_lightning.py +++ /dev/null @@ -1,78 +0,0 @@ -import os -import time -import torch - -from typing import Tuple -from torch.utils.data import DataLoader -from torch import Tensor -from torch.utils.data import DataLoader -from lightning import Trainer -from lightning.pytorch.callbacks import ModelCheckpoint -from lightning.pytorch.demos import WikiText2, Transformer -from lightning.pytorch import LightningModule - -from dataflux_pytorch.lightning import DatafluxLightningCheckpoint - -class LightningTransformer(LightningModule): - def __init__(self, vocab_size: int = 33278, nlayers: int = 100) -> None: - super().__init__() - self.model = Transformer(vocab_size=vocab_size, nlayers=nlayers) - - def forward(self, inputs: Tensor, target: Tensor) -> Tensor: - return self.model(inputs, target) - - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: - inputs, target = batch - output = self(inputs, target) - loss = torch.nn.functional.nll_loss(output, target.view(-1)) - return loss - - def configure_optimizers(self) -> torch.optim.Optimizer: - return 
torch.optim.SGD(self.model.parameters(), lr=0.1) - - def prepare_data(self) -> None: - WikiText2(download=True) - - def train_dataloader(self) -> DataLoader: - dataset = WikiText2() - return DataLoader(dataset) - - -def main(project: str, bucket: str, ckpt_dir_path: str, save_only_latest: bool, model_size: int = 100): - dataset = WikiText2() - dataloader = DataLoader(dataset, num_workers=1) - model = LightningTransformer(vocab_size=dataset.vocab_size, nlayers=model_size) - dataflux_ckpt = DatafluxLightningCheckpoint(project_name=project,bucket_name=bucket) - - # Save once per step, and if `save_only_latest`, replace the last checkpoint each time. - # Replacing is implemented by saving the new checkpoint, and then deleting the previous one. - # If `save_only_latest` is False, a new checkpoint is created for each step. - checkpoint_callback = ModelCheckpoint( - save_top_k=1 if save_only_latest else -1, - every_n_train_steps=1, - filename="checkpoint-{epoch:02d}-{step:02d}", - enable_version_counter=True, - ) - trainer = Trainer( - default_root_dir=ckpt_dir_path, - plugins=[dataflux_ckpt], - callbacks=[checkpoint_callback], - min_epochs=4, - max_epochs=5, - max_steps=5, - accelerator="cpu", - ) - start = time.time() - trainer.fit(model, dataloader) - end = time.time() - print(end-start) - -if __name__ == "__main__": - - main( - os.getenv("PROJECT"), - os.getenv("BUCKET"), - os.getenv("CKPT_DIR_PATH"), - os.getenv("SAVE_ONLY_LATEST") == "1", - int(os.getenv("CHECKPOINT_SIZE")), - ) diff --git a/dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py b/dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py index 340eccb9..66708287 100644 --- a/dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py +++ b/dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py @@ -10,6 +10,7 @@ from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos import WikiText2, Transformer from lightning.pytorch import LightningModule +from 
lightning.pytorch.plugins.io import TorchCheckpointIO from dataflux_pytorch.lightning import DatafluxLightningCheckpoint @@ -38,11 +39,13 @@ def train_dataloader(self) -> DataLoader: return DataLoader(dataset) -def main(project: str, bucket: str, ckpt_dir_path: str, save_only_latest: bool, model_size: int = 100, steps: int = 5): +def main(project: str, bucket: str, ckpt_dir_path: str, save_only_latest: bool, dataflux_ckpt: bool, model_size: int = 100, steps: int = 5): dataset = WikiText2() dataloader = DataLoader(dataset, num_workers=1) model = LightningTransformer(vocab_size=dataset.vocab_size, nlayers=model_size) - dataflux_ckpt = DatafluxLightningCheckpoint(project_name=project,bucket_name=bucket) + ckpt = TorchCheckpointIO() + if dataflux_ckpt: + ckpt = DatafluxLightningCheckpoint(project_name=project,bucket_name=bucket) # Save once per step, and if `save_only_latest`, replace the last checkpoint each time. # Replacing is implemented by saving the new checkpoint, and then deleting the previous one. # If `save_only_latest` is False, a new checkpoint is created for each step. 
@@ -54,11 +57,11 @@ def main(project: str, bucket: str, ckpt_dir_path: str, save_only_latest: bool, ) trainer = Trainer( default_root_dir=ckpt_dir_path, - plugins=[dataflux_ckpt], + plugins=[ckpt], callbacks=[checkpoint_callback], min_epochs=4, max_epochs=5, - max_steps=5, + max_steps=steps, accelerator="cpu", ) start = time.time() @@ -70,14 +73,15 @@ def main(project: str, bucket: str, ckpt_dir_path: str, save_only_latest: bool, DEFAULT_SIZE = 100 DEFAULT_STEPS = 5 - size = int(os.getenv("CHECKPOINT_SIZE'",DEFAULT_SIZE)) - steps = int(os.getenv("STEPS'",DEFAULT_STEPS)) + size = int(os.getenv("CHECKPOINT_SIZE",DEFAULT_SIZE)) + steps = int(os.getenv("STEPS",DEFAULT_STEPS)) main( os.getenv("PROJECT"), os.getenv("BUCKET"), os.getenv("CKPT_DIR_PATH"), os.getenv("SAVE_ONLY_LATEST") == "1", + os.getenv("DATAFLUX_CKPT") == "1", size, steps, ) From 3dd7b0bee1dcdbd3b7c78ef540eca2f14d697c9c Mon Sep 17 00:00:00 2001 From: Divya Rawal Date: Tue, 9 Jul 2024 22:27:45 +0000 Subject: [PATCH 04/10] README file walking through how to run the benchmarking --- dataflux_pytorch/benchmark/README.md | 34 ++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 dataflux_pytorch/benchmark/README.md diff --git a/dataflux_pytorch/benchmark/README.md b/dataflux_pytorch/benchmark/README.md new file mode 100644 index 00000000..0494210e --- /dev/null +++ b/dataflux_pytorch/benchmark/README.md @@ -0,0 +1,34 @@ +# Benchmarking PyTorch Lightning Checkpoints with Google Cloud Storage + +This benchmarking script will allow you to run and benchmark the performance of the PyTorch Lightning Checkpoint save function. This script does not rely on GPUs or CPU Clusters and can be run directly on your machine. The script runs the `WikiText2` PyTorch Lightining demo code with some modifications. + +First ensure you are running within a virtual python enviroment, then set the enviroment variables. + +`CKPT_DIR_PATH` is the location of where to save the checkpoints. 
`STEPS` is the number of steps the model will take (the number of checkpoints created will be the same). The default value for `STEPS` is 5. + +```shell +export CKPT_DIR_PATH=`gs://path/to/directory/` +export STEPS=5 +``` + +You can also optionally change the size of the model. The size variable will be passed into `nn.Transformer` for `num_encoder_layers` and `num_decoder_layers`. The default value for size is 100. + +```shell +export CHECKPOINT_SIZE=`1000` +``` + +If you are benchmarking Dataflux Lightning Checkpoint, save information regarding your project and bucket, and make sure to enable the flag by setting it to `1`. + +```shell +export PROJECT=`YOUR_PROJECT_NAME` +export BUCKET=`YOUR_BUCKET_NAME` +export DATAFLUX_CKPT=1 +``` + +Run the script. + +```shell +python dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py +``` + +The time will print out and the checkpoints can be viewed in GCS at the location passed in. From 7ef91a1b73d8960ade54727ae55fd9b22373dabb Mon Sep 17 00:00:00 2001 From: Divya Rawal Date: Wed, 10 Jul 2024 00:05:29 +0000 Subject: [PATCH 05/10] Added tables with information regarding dataflux checkpointing performance as it is today --- dataflux_pytorch/benchmark/README.md | 105 +++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/dataflux_pytorch/benchmark/README.md b/dataflux_pytorch/benchmark/README.md index 0494210e..a8244624 100644 --- a/dataflux_pytorch/benchmark/README.md +++ b/dataflux_pytorch/benchmark/README.md @@ -2,6 +2,8 @@ This benchmarking script will allow you to run and benchmark the performance of the PyTorch Lightning Checkpoint save function. This script does not rely on GPUs or CPU Clusters and can be run directly on your machine. The script runs the `WikiText2` PyTorch Lightining demo code with some modifications. +### Getting started + First ensure you are running within a virtual python enviroment, then set the enviroment variables. 
`CKPT_DIR_PATH` is the location of where to save the checkpoints. `STEPS` is the number of steps the model will take (the number of checkpoints created will be the same). The default value for `STEPS` is 5. @@ -17,6 +19,8 @@ You can also optionally change the size of the model. The size variable will be export CHECKPOINT_SIZE=`1000` ``` +### Dataflux Lightning Checkpoint + If you are benchmarking Dataflux Lightning Checkpoint, save information regarding your project and bucket, and make sure to enable the flag by setting it to `1`. ```shell @@ -25,6 +29,8 @@ export BUCKET=`YOUR_BUCKET_NAME` export DATAFLUX_CKPT=1 ``` +### Running + Run the script. ```shell @@ -32,3 +38,102 @@ python dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py ``` The time will print out and the checkpoints can be viewed in GCS at the location passed in. + +### Results + +Type Size Steps Time +Default +Dataflux + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Checkpoint Type + Size + Checkpoint Size (MB) per step + Steps + Time (s) +
Default + 10 + 75.6 + 5 + 13.25 +
Dataflux + 10 + 75.6 + 5 + 14.08 +
Default + 100 + 298 + 5 + 36.55 +
Dataflux + 100 + 298 + 5 + 44.07 +
Default + 1000 + 2500 + 5 + 266.16 +
Dataflux + 1000 + 2500 + 349.19 + +
From 3d72a0372b703c871b542afbff83395cf7c9d7aa Mon Sep 17 00:00:00 2001 From: Divya Rawal Date: Wed, 10 Jul 2024 18:33:54 +0000 Subject: [PATCH 06/10] Table format fix --- dataflux_pytorch/benchmark/README.md | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/dataflux_pytorch/benchmark/README.md b/dataflux_pytorch/benchmark/README.md index a8244624..ac06fe0f 100644 --- a/dataflux_pytorch/benchmark/README.md +++ b/dataflux_pytorch/benchmark/README.md @@ -40,11 +40,6 @@ python dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py The time will print out and the checkpoints can be viewed in GCS at the location passed in. ### Results - -Type Size Steps Time -Default -Dataflux - + - -
Checkpoint Type @@ -129,11 +124,9 @@ Dataflux 2500 5 349.19 -
From ad7b1c6b15c81c7da81c0a4450e829297eaaa477 Mon Sep 17 00:00:00 2001 From: Divya Rawal Date: Wed, 10 Jul 2024 20:55:40 +0000 Subject: [PATCH 07/10] Included description about the results table --- dataflux_pytorch/benchmark/README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dataflux_pytorch/benchmark/README.md b/dataflux_pytorch/benchmark/README.md index ac06fe0f..34790758 100644 --- a/dataflux_pytorch/benchmark/README.md +++ b/dataflux_pytorch/benchmark/README.md @@ -40,6 +40,11 @@ python dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py The time will print out and the checkpoints can be viewed in GCS at the location passed in. ### Results + +The table below contains benchmarking times to run trainer.fit() with writing checkpoints to GCS. + +Dataflux's implementation of CheckpointIO for PyTorch Lightning is undergoing active development. The numbers below will be continuously updated to reflect the current state and performance of Dataflux's PyTorch Lightning checkpoint utility. These values are compared to `Default`, which refers to fsspec. + - - + + @@ -69,6 +103,10 @@ Dataflux's implementation of CheckpointIO for PyTorch Lightning is undergoing ac + + - - + + + + - - + + + + + +
Checkpoint Type From 05a11cb535fcc1f35f2ce418698bef12ec0e5150 Mon Sep 17 00:00:00 2001 From: Divya Rawal Date: Wed, 10 Jul 2024 21:10:27 +0000 Subject: [PATCH 08/10] Typos --- dataflux_pytorch/benchmark/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dataflux_pytorch/benchmark/README.md b/dataflux_pytorch/benchmark/README.md index 34790758..620d54d8 100644 --- a/dataflux_pytorch/benchmark/README.md +++ b/dataflux_pytorch/benchmark/README.md @@ -1,6 +1,6 @@ # Benchmarking PyTorch Lightning Checkpoints with Google Cloud Storage -This benchmarking script will allow you to run and benchmark the performance of the PyTorch Lightning Checkpoint save function. This script does not rely on GPUs or CPU Clusters and can be run directly on your machine. The script runs the `WikiText2` PyTorch Lightining demo code with some modifications. +This benchmarking script will allow you to run and benchmark the performance of the PyTorch Lightning Checkpoint save function. This script does not rely on GPUs, TPUs or CPU Clusters and can be run directly on your machine. The script runs the `WikiText2` PyTorch Lightning demo code with some modifications. ### Getting started @@ -16,7 +16,7 @@ export STEPS=5 You can also optionally change the size of the model. The size variable will be passed into `nn.Transformer` for `num_encoder_layers` and `num_decoder_layers`. The default value for size is 100. 
```shell -export CHECKPOINT_SIZE=`1000` +export CHECKPOINT_SIZE=1000 ``` ### Dataflux Lightning Checkpoint From 3f0479a38bdada5fa2429ca9a4e932de783dd1e7 Mon Sep 17 00:00:00 2001 From: Divya Rawal Date: Wed, 10 Jul 2024 22:28:09 +0000 Subject: [PATCH 09/10] Updated Size var to Layers, added example output, and added throughput column to the results table --- dataflux_pytorch/benchmark/README.md | 78 ++++++++++++++++--- .../lightning_checkpoint_benchmark.py | 17 ++-- 2 files changed, 77 insertions(+), 18 deletions(-) diff --git a/dataflux_pytorch/benchmark/README.md b/dataflux_pytorch/benchmark/README.md index 620d54d8..9a50ab2b 100644 --- a/dataflux_pytorch/benchmark/README.md +++ b/dataflux_pytorch/benchmark/README.md @@ -2,7 +2,15 @@ This benchmarking script will allow you to run and benchmark the performance of the PyTorch Lightning Checkpoint save function. This script does not rely on GPUs, TPUs or CPU Clusters and can be run directly on your machine. The script runs the `WikiText2` PyTorch Lightning demo code with some modifications. -### Getting started +## Getting started + +### Installation + +```shell +pip install gcs-torch-dataflux gcsfs +``` + +### Configuration First ensure you are running within a virtual python enviroment, then set the enviroment variables. @@ -13,10 +21,10 @@ export CKPT_DIR_PATH=`gs://path/to/directory/` export STEPS=5 ``` -You can also optionally change the size of the model. The size variable will be passed into `nn.Transformer` for `num_encoder_layers` and `num_decoder_layers`. The default value for size is 100. +You can also optionally change the size of the model. The `LAYERS` variable will be passed into `nn.Transformer` for `num_encoder_layers` and `num_decoder_layers`. The default value for `LAYERS` is 100. ```shell -export CHECKPOINT_SIZE=1000 +export LAYERS=1000 ``` ### Dataflux Lightning Checkpoint @@ -37,11 +45,33 @@ Run the script. 
python dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py ``` -The time will print out and the checkpoints can be viewed in GCS at the location passed in. +The time will print out and the checkpoints can be viewed in GCS at the location passed in. A sample output is shown below. -### Results +```shell +$ python dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py +GPU available: False, used: False +TPU available: False, using: 0 TPU cores +HPU available: False, using: 0 HPUs +/usr/local/google/home/divyarawal/dataflux_dev/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default + + | Name | Type | Params | Mode +---------------------------------------------- +0 | model | Transformer | 19.8 M | train +---------------------------------------------- +19.8 M Trainable params +0 Non-trainable params +19.8 M Total params +79.189 Total estimated model params size (MB) +/usr/local/google/home/divyarawal/dataflux_dev/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance. +Epoch 0: 0%| | 5/59674 [00:11<39:37:59, 0.42it/s, v_num=2]`Trainer.fit` stopped: `max_steps=5` reached. 
+Epoch 0: 0%| | 5/59674 [00:11<39:38:05, 0.42it/s, v_num=2] +Time to train over 5 steps: 14.197517395019531 seconds +Time to save one checkpoint: 2.075364589691162 seconds +``` -The table below contains benchmarking times to run trainer.fit() with writing checkpoints to GCS. +## Results + +The table below contains benchmarking times writing checkpoints to GCS. Dataflux's implementation of CheckpointIO for PyTorch Lightning is undergoing active development. The numbers below will be continuously updated to reflect the current state and performance of Dataflux's PyTorch Lightning checkpoint utility. These values are compared to `Default`, which refers to fsspec. @@ -49,13 +79,17 @@ Dataflux's implementation of CheckpointIO for PyTorch Lightning is undergoing ac
Checkpoint Type Size + Layers Checkpoint Size (MB) per step Steps Time (s) + Train Time (s) + Single Checkpoint Save Time (s) + Write Throughput (MB/s)
13.25 1.64 + 46.09 +
Dataflux @@ -77,12 +115,14 @@ Dataflux's implementation of CheckpointIO for PyTorch Lightning is undergoing ac 75.6 5 14.08 2.07 + 36.52 +
Default @@ -95,6 +135,10 @@ Dataflux's implementation of CheckpointIO for PyTorch Lightning is undergoing ac 36.55 5.21 + 57.20 +
Dataflux @@ -103,12 +147,14 @@ Dataflux's implementation of CheckpointIO for PyTorch Lightning is undergoing ac 298 5 44.07 7.04 + 42.32 +
Default @@ -121,6 +167,10 @@ Dataflux's implementation of CheckpointIO for PyTorch Lightning is undergoing ac 266.16 39.14 + 63.87 +
Dataflux @@ -133,5 +183,9 @@ Dataflux's implementation of CheckpointIO for PyTorch Lightning is undergoing ac 349.19 53.71 + 46.55 +
diff --git a/dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py b/dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py index 66708287..8e0d5156 100644 --- a/dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py +++ b/dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py @@ -39,10 +39,10 @@ def train_dataloader(self) -> DataLoader: return DataLoader(dataset) -def main(project: str, bucket: str, ckpt_dir_path: str, save_only_latest: bool, dataflux_ckpt: bool, model_size: int = 100, steps: int = 5): +def main(project: str, bucket: str, ckpt_dir_path: str, save_only_latest: bool, dataflux_ckpt: bool, layers: int = 100, steps: int = 5): dataset = WikiText2() dataloader = DataLoader(dataset, num_workers=1) - model = LightningTransformer(vocab_size=dataset.vocab_size, nlayers=model_size) + model = LightningTransformer(vocab_size=dataset.vocab_size, nlayers=layers) ckpt = TorchCheckpointIO() if dataflux_ckpt: ckpt = DatafluxLightningCheckpoint(project_name=project,bucket_name=bucket) @@ -67,13 +67,18 @@ def main(project: str, bucket: str, ckpt_dir_path: str, save_only_latest: bool, start = time.time() trainer.fit(model, dataloader) end = time.time() - print(end-start) + print(f"Time to train over {steps} steps: " + str(end-start) + " seconds") + + start = time.time() + trainer.save_checkpoint(ckpt_dir_path) + end = time.time() + print("Time to save one checkpoint: " + str(end-start) + " seconds") if __name__ == "__main__": - DEFAULT_SIZE = 100 + DEFAULT_LAYERS = 100 DEFAULT_STEPS = 5 - size = int(os.getenv("CHECKPOINT_SIZE",DEFAULT_SIZE)) + layers = int(os.getenv("LAYERS",DEFAULT_LAYERS)) steps = int(os.getenv("STEPS",DEFAULT_STEPS)) main( @@ -82,6 +87,6 @@ def main(project: str, bucket: str, ckpt_dir_path: str, save_only_latest: bool, os.getenv("CKPT_DIR_PATH"), os.getenv("SAVE_ONLY_LATEST") == "1", os.getenv("DATAFLUX_CKPT") == "1", - size, + layers, steps, ) From 42fad571153642a52403a40caeb9b0f5b70ca319 Mon Sep 17 00:00:00 2001 From: 
Divya Rawal Date: Wed, 10 Jul 2024 22:59:04 +0000 Subject: [PATCH 10/10] Added description for main function --- .../lightning_checkpoint_benchmark.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py b/dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py index 8e0d5156..a25e268b 100644 --- a/dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py +++ b/dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py @@ -39,6 +39,41 @@ def train_dataloader(self) -> DataLoader: return DataLoader(dataset) +"""Checkpoints a PyTorch Lightning demo model to GCS using gcsfs or DatafluxLightningCheckpoint. + +This function utilizes PyTorch Lightning to checkpoint a demo model trained on the WikiText2 dataset. It +takes in information regarding the GCS location to save the checkpoints, the type of +checkpoint, and other configuration variables. By default this function uses +TorchCheckpointIO to write PyTorch Lightning checkpoints through gcsfs. If dataflux_ckpt +is enabled the Trainer will be passed a DatafluxLightningCheckpoint, which is an +implementation of the CheckpointIO interface, as a plugin. 
+ +Typical usage example: + + Run DatafluxLightningCheckpoint over 10 steps: + + project = 'test-project' + bucket = 'test-bucket' + ckpt_dir_path = 'gs://path/to/dir/' + save_only_latest = False + dataflux_ckpt = True + layers = 1000 + steps = 10 + + main(project=project, bucket=bucket, ckpt_dir_path=ckpt_dir_path, + save_only_latest=save_only_latest, dataflux_ckpt=dataflux_ckpt, layers=layers, steps=steps) + + Run gcsfs over 10 steps: + + ckpt_dir_path = 'gs://path/to/dir/' + save_only_latest = False + dataflux_ckpt = False + layers = 1000 + steps = 10 + + main(project=project, bucket=bucket, ckpt_dir_path=ckpt_dir_path, + save_only_latest=save_only_latest, dataflux_ckpt=dataflux_ckpt, layers=layers, steps=steps) +""" def main(project: str, bucket: str, ckpt_dir_path: str, save_only_latest: bool, dataflux_ckpt: bool, layers: int = 100, steps: int = 5): dataset = WikiText2() dataloader = DataLoader(dataset, num_workers=1)