diff --git a/dataflux_client_python b/dataflux_client_python
index ac12adba..4a4acb55 160000
--- a/dataflux_client_python
+++ b/dataflux_client_python
@@ -1 +1 @@
-Subproject commit ac12adba3aa3f4a6d27d3703223aac4b2f8b850c
+Subproject commit 4a4acb5543758bc9910ac74a656564921dc3e225
diff --git a/dataflux_pytorch/benchmark/README.md b/dataflux_pytorch/benchmark/README.md
new file mode 100644
index 00000000..9a50ab2b
--- /dev/null
+++ b/dataflux_pytorch/benchmark/README.md
@@ -0,0 +1,191 @@
+# Benchmarking PyTorch Lightning Checkpoints with Google Cloud Storage
+
+This benchmarking script measures the performance of the PyTorch Lightning checkpoint save function. It does not rely on GPUs, TPUs, or CPU clusters and can be run directly on your machine. The script runs the `WikiText2` PyTorch Lightning demo code with some modifications.
+
+## Getting started
+
+### Installation
+
+```shell
+pip install gcs-torch-dataflux gcsfs
+```
+
+### Configuration
+
+First ensure you are running within a virtual Python environment, then set the environment variables.
+
+`CKPT_DIR_PATH` is the GCS location where the checkpoints will be saved. `STEPS` is the number of training steps the model will take; one checkpoint is created per step. The default value for `STEPS` is 5.
+
+```shell
+export CKPT_DIR_PATH="gs://path/to/directory/"
+export STEPS=5
+```
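+
+Optionally, set `SAVE_ONLY_LATEST=1` to keep only the most recent checkpoint: each new checkpoint is saved and the previous one is then deleted. Otherwise, a new checkpoint is created for every step.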
+
+You can also optionally change the size of the model. The `LAYERS` variable is passed to `nn.Transformer` as both `num_encoder_layers` and `num_decoder_layers`. The default value for `LAYERS` is 100.
+
+```shell
+export LAYERS=1000
+```
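+
+For reference, here is a minimal sketch of how `LAYERS` reaches the model; it mirrors the benchmark script, where the demo `Transformer`'s `nlayers` argument is what gets forwarded to `nn.Transformer`:
+
+```python
+import os
+
+from lightning.pytorch.demos import Transformer, WikiText2
+
+# LAYERS controls the depth of the demo Transformer; nlayers is forwarded to
+# nn.Transformer as num_encoder_layers and num_decoder_layers.
+layers = int(os.getenv("LAYERS", 100))
+dataset = WikiText2(download=True)
+model = Transformer(vocab_size=dataset.vocab_size, nlayers=layers)
+```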
+
+### Dataflux Lightning Checkpoint
+
+If you are benchmarking Dataflux Lightning Checkpoint, set the environment variables for your project and bucket, and enable the Dataflux checkpoint flag by setting it to `1`.
+
+```shell
+export PROJECT="YOUR_PROJECT_NAME"
+export BUCKET="YOUR_BUCKET_NAME"
+export DATAFLUX_CKPT=1
+```
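+
+A minimal sketch of how the flag is wired up, mirroring the benchmark script:
+
+```python
+import os
+
+from lightning import Trainer
+from lightning.pytorch.plugins.io import TorchCheckpointIO
+
+from dataflux_pytorch.lightning import DatafluxLightningCheckpoint
+
+# The default plugin is TorchCheckpointIO, where gcsfs handles gs:// paths;
+# DATAFLUX_CKPT=1 swaps in the Dataflux CheckpointIO implementation.
+ckpt = TorchCheckpointIO()
+if os.getenv("DATAFLUX_CKPT") == "1":
+    ckpt = DatafluxLightningCheckpoint(project_name=os.getenv("PROJECT"),
+                                       bucket_name=os.getenv("BUCKET"))
+
+trainer = Trainer(default_root_dir=os.getenv("CKPT_DIR_PATH"), plugins=[ckpt])
+```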
+
+### Running
+
+Run the script.
+
+```shell
+python dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py
+```
+
+The training and checkpoint save times are printed to the console, and the checkpoints can be viewed in GCS at the location passed in. Sample output is shown below.
+
+```shell
+$ python dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py
+GPU available: False, used: False
+TPU available: False, using: 0 TPU cores
+HPU available: False, using: 0 HPUs
+/usr/local/google/home/divyarawal/dataflux_dev/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
+
+ | Name | Type | Params | Mode
+----------------------------------------------
+0 | model | Transformer | 19.8 M | train
+----------------------------------------------
+19.8 M Trainable params
+0 Non-trainable params
+19.8 M Total params
+79.189 Total estimated model params size (MB)
+/usr/local/google/home/divyarawal/dataflux_dev/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
+Epoch 0: 0%| | 5/59674 [00:11<39:37:59, 0.42it/s, v_num=2]`Trainer.fit` stopped: `max_steps=5` reached.
+Epoch 0: 0%| | 5/59674 [00:11<39:38:05, 0.42it/s, v_num=2]
+Time to train over 5 steps: 14.197517395019531 seconds
+Time to save one checkpoint: 2.075364589691162 seconds
+```
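+
+To confirm the checkpoints were written, you can list the objects under your checkpoint path — a minimal sketch using the `google-cloud-storage` client, with placeholder bucket and prefix names:
+
+```python
+from google.cloud import storage
+
+# Lists every object under gs://my-bucket/path/to/directory/.
+client = storage.Client()
+for blob in client.list_blobs("my-bucket", prefix="path/to/directory/"):
+    print(blob.name, blob.size)
+```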
+
+## Results
+
+The table below contains benchmark results for writing checkpoints to GCS.
+
+Dataflux's implementation of `CheckpointIO` for PyTorch Lightning is under active development. The numbers below will be continuously updated to reflect the current state and performance of Dataflux's PyTorch Lightning checkpoint utility. These values are compared against `Default`, which refers to the fsspec/gcsfs-based `TorchCheckpointIO`.
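+
+Write throughput is the checkpoint size divided by the single checkpoint save time; for example, the first `Default` row gives 75.6 MB / 1.64 s ≈ 46.09 MB/s.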
+
+| Checkpoint Type | Layers | Checkpoint Size (MB) per step | Steps | Train Time (s) | Single Checkpoint Save Time (s) | Write Throughput (MB/s) |
+|-----------------|--------|-------------------------------|-------|----------------|---------------------------------|-------------------------|
+| Default         | 10     | 75.6                          | 5     | 13.25          | 1.64                            | 46.09                   |
+| Dataflux        | 10     | 75.6                          | 5     | 14.08          | 2.07                            | 36.52                   |
+| Default         | 100    | 298                           | 5     | 36.55          | 5.21                            | 57.20                   |
+| Dataflux        | 100    | 298                           | 5     | 44.07          | 7.04                            | 42.32                   |
+| Default         | 1000   | 2500                          | 5     | 266.16         | 39.14                           | 63.87                   |
+| Dataflux        | 1000   | 2500                          | 5     | 349.19         | 53.71                           | 46.55                   |
diff --git a/dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py b/dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py
new file mode 100644
index 00000000..a25e268b
--- /dev/null
+++ b/dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py
@@ -0,0 +1,127 @@
+import os
+import time
+import torch
+
+from typing import Tuple
+from torch import Tensor
+from torch.utils.data import DataLoader
+from lightning import Trainer
+from lightning.pytorch.callbacks import ModelCheckpoint
+from lightning.pytorch.demos import WikiText2, Transformer
+from lightning.pytorch import LightningModule
+from lightning.pytorch.plugins.io import TorchCheckpointIO
+
+from dataflux_pytorch.lightning import DatafluxLightningCheckpoint
+
+class LightningTransformer(LightningModule):
+ def __init__(self, vocab_size: int = 33278, nlayers: int = 100) -> None:
+ super().__init__()
+ self.model = Transformer(vocab_size=vocab_size, nlayers=nlayers)
+
+ def forward(self, inputs: Tensor, target: Tensor) -> Tensor:
+ return self.model(inputs, target)
+
+ def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
+ inputs, target = batch
+ output = self(inputs, target)
+ loss = torch.nn.functional.nll_loss(output, target.view(-1))
+ return loss
+
+ def configure_optimizers(self) -> torch.optim.Optimizer:
+ return torch.optim.SGD(self.model.parameters(), lr=0.1)
+
+ def prepare_data(self) -> None:
+ WikiText2(download=True)
+
+ def train_dataloader(self) -> DataLoader:
+ dataset = WikiText2()
+ return DataLoader(dataset)
+
+
+"""Checkpoints a PyTorch Ligthning demo model to GCS using gcsfs or DatafluxLightningCheckpoint.
+
+This function utilizes PyTorch Lightning to checkpoint the WikiText2 dataset. It
+takes in information regarding the gcs location to save the checkpoints, the type of
+checkpoint, and other configuration variables. Default this function runs on
+gcsfs to write PyTorch Ligthtning checkpoints, TorchCheckpointIO. If dataflux_ckpt
+is enabled the Trainer will be passed a DatafluxLightningCheckpoint, which is an
+implementation of the CheckpointIO interface, as a plugin.
+
+Typical usage example:
+
+ Run DatafluxLightningCheckpoint over 10 steps:
+
+ project = 'test-project'
+ bucket = 'test-bucket'
+ ckpt_dir_path = 'gs://path/to/dir/'
+ save_only_latest = False
+ dataflux_ckpt = True
+ layers = 1000
+ steps = 10
+
+ main(project=project, bucket=bucket, save_only_latest=save_onlylatest,
+ dataflux_ckpt=dataflux_ckpt, layers=layers, steps=steps)
+
+ Run gcsfs over 10 steps:
+
+ ckpt_dir_path = 'gs://path/to/dir/'
+ save_only_latest = False
+ dataflux_ckpt = False
+ layers = 1000
+ steps = 10
+
+ main(project=project, bucket=bucket, save_only_latest=save_onlylatest,
+ dataflux_ckpt=dataflux_ckpt, layers=layers, steps=steps)
+"""
+def main(project: str,
+         bucket: str,
+         ckpt_dir_path: str,
+         save_only_latest: bool,
+         dataflux_ckpt: bool,
+         layers: int = 100,
+         steps: int = 5):
+ dataset = WikiText2()
+ dataloader = DataLoader(dataset, num_workers=1)
+ model = LightningTransformer(vocab_size=dataset.vocab_size, nlayers=layers)
+ ckpt = TorchCheckpointIO()
+ if dataflux_ckpt:
+        ckpt = DatafluxLightningCheckpoint(project_name=project, bucket_name=bucket)
+ # Save once per step, and if `save_only_latest`, replace the last checkpoint each time.
+ # Replacing is implemented by saving the new checkpoint, and then deleting the previous one.
+ # If `save_only_latest` is False, a new checkpoint is created for each step.
+ checkpoint_callback = ModelCheckpoint(
+ save_top_k=1 if save_only_latest else -1,
+ every_n_train_steps=1,
+ filename="checkpoint-{epoch:02d}-{step:02d}",
+ enable_version_counter=True,
+ )
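+    # Train on CPU; `max_steps` stops training (and per-step checkpointing)
+    # after `steps` batches, regardless of the epoch bounds.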
+ trainer = Trainer(
+ default_root_dir=ckpt_dir_path,
+ plugins=[ckpt],
+ callbacks=[checkpoint_callback],
+ min_epochs=4,
+ max_epochs=5,
+ max_steps=steps,
+ accelerator="cpu",
+ )
+ start = time.time()
+ trainer.fit(model, dataloader)
+ end = time.time()
+ print(f"Time to train over {steps} steps: " + str(end-start) + " seconds")
+
+ start = time.time()
+ trainer.save_checkpoint(ckpt_dir_path)
+ end = time.time()
+ print("Time to save one checkpoint: " + str(end-start) + " seconds")
+
+if __name__ == "__main__":
+
+ DEFAULT_LAYERS = 100
+ DEFAULT_STEPS = 5
+    layers = int(os.getenv("LAYERS", DEFAULT_LAYERS))
+    steps = int(os.getenv("STEPS", DEFAULT_STEPS))
+
+ main(
+ os.getenv("PROJECT"),
+ os.getenv("BUCKET"),
+ os.getenv("CKPT_DIR_PATH"),
+ os.getenv("SAVE_ONLY_LATEST") == "1",
+ os.getenv("DATAFLUX_CKPT") == "1",
+ layers,
+ steps,
+ )
diff --git a/dataflux_pytorch/lightning/dataflux_lightning_checkpoint.py b/dataflux_pytorch/lightning/dataflux_lightning_checkpoint.py
index 3deb20e3..d2356d0d 100644
--- a/dataflux_pytorch/lightning/dataflux_lightning_checkpoint.py
+++ b/dataflux_pytorch/lightning/dataflux_lightning_checkpoint.py
@@ -51,7 +51,7 @@ def save_checkpoint(
key = self._parse_gcs_path(path)
blob = self.bucket.blob(key)
with blob.open("wb", ignore_flush=True) as blobwriter:
- torch.save(checkpoint, blob.open("wb", ignore_flush=True))
+ torch.save(checkpoint, blobwriter)
def load_checkpoint(
self,