
Benchmark checkpoint #45

Merged · 11 commits · Jul 10, 2024
139 changes: 139 additions & 0 deletions dataflux_pytorch/benchmark/README.md
# Benchmarking PyTorch Lightning Checkpoints with Google Cloud Storage

This benchmarking script lets you run and measure the performance of the PyTorch Lightning checkpoint save function. It does not rely on GPUs or CPU clusters and can be run directly on your machine. The script runs the `WikiText2` PyTorch Lightning demo code with some modifications.

### Getting started

First, ensure you are running inside a Python virtual environment, then set the environment variables.

`CKPT_DIR_PATH` is the location where the checkpoints will be saved. `STEPS` is the number of training steps the model will take (one checkpoint is created per step, so the same number of checkpoints is produced). The default value for `STEPS` is 5.

```shell
export CKPT_DIR_PATH="gs://path/to/directory/"
export STEPS=5
```

You can also optionally change the size of the model. The size value is passed to `nn.Transformer` as both `num_encoder_layers` and `num_decoder_layers`. The default value is 100.

```shell
export CHECKPOINT_SIZE=1000
```

### Dataflux Lightning Checkpoint

If you are benchmarking the Dataflux Lightning Checkpoint, export your project and bucket names, and enable the flag by setting it to `1`.

```shell
export PROJECT="YOUR_PROJECT_NAME"
export BUCKET="YOUR_BUCKET_NAME"
export DATAFLUX_CKPT=1
```
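As a quick illustration (this sketch is not the benchmark script itself, and the helper name is hypothetical), numeric variables such as `STEPS` are parsed with a fallback default, while `DATAFLUX_CKPT` is an exact string comparison against `"1"`:

```python
import os

# Hypothetical helper: parse a numeric environment variable with a default.
def parse_int_env(name: str, default: int) -> int:
    return int(os.getenv(name, default))

os.environ["STEPS"] = "5"
os.environ["DATAFLUX_CKPT"] = "1"

steps = parse_int_env("STEPS", 5)
use_dataflux = os.getenv("DATAFLUX_CKPT") == "1"  # exact match, so "true" or "yes" would not enable it
print(steps, use_dataflux)  # 5 True
```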

### Running

Run the script.

```shell
python dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py
```

The elapsed time is printed to the console, and the checkpoints can be viewed in GCS at the location set in `CKPT_DIR_PATH`.

### Results


<table>
  <tr>
    <td style="background-color: #d9d2e9"><strong>Checkpoint Type</strong></td>
    <td style="background-color: #d9d2e9"><strong>Model Size</strong></td>
    <td style="background-color: #d9d2e9"><strong>Checkpoint Size (MB) per Step</strong></td>
    <td style="background-color: #d9d2e9"><strong>Steps</strong></td>
    <td style="background-color: #d9d2e9"><strong>Time (s)</strong></td>
  </tr>
  <tr>
    <td style="background-color: #d9d9d9">Default</td>
    <td style="background-color: #d9d9d9">10</td>
    <td style="background-color: #d9d9d9">75.6</td>
    <td style="background-color: #d9d9d9">5</td>
    <td style="background-color: #d9d9d9">13.25</td>
  </tr>
  <tr>
    <td style="background-color: #f3f3f3">Dataflux</td>
    <td style="background-color: #f3f3f3">10</td>
    <td style="background-color: #f3f3f3">75.6</td>
    <td style="background-color: #f3f3f3">5</td>
    <td style="background-color: #f3f3f3">14.08</td>
  </tr>
  <tr>
    <td style="background-color: #d9d9d9">Default</td>
    <td style="background-color: #d9d9d9">100</td>
    <td style="background-color: #d9d9d9">298</td>
    <td style="background-color: #d9d9d9">5</td>
    <td style="background-color: #d9d9d9">36.55</td>
  </tr>
  <tr>
    <td style="background-color: #f3f3f3">Dataflux</td>
    <td style="background-color: #f3f3f3">100</td>
    <td style="background-color: #f3f3f3">298</td>
    <td style="background-color: #f3f3f3">5</td>
    <td style="background-color: #f3f3f3">44.07</td>
  </tr>
  <tr>
    <td style="background-color: #d9d9d9">Default</td>
    <td style="background-color: #d9d9d9">1000</td>
    <td style="background-color: #d9d9d9">2500</td>
    <td style="background-color: #d9d9d9">5</td>
    <td style="background-color: #d9d9d9">266.16</td>
  </tr>
  <tr>
    <td style="background-color: #f3f3f3">Dataflux</td>
    <td style="background-color: #f3f3f3">1000</td>
    <td style="background-color: #f3f3f3">2500</td>
    <td style="background-color: #f3f3f3">5</td>
    <td style="background-color: #f3f3f3">349.19</td>
  </tr>
</table>
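As a rough sanity check on the rows above, effective write throughput can be derived from the checkpoint size per step, the step count, and the total time. Note that the timer wraps the whole `trainer.fit` call, so the times include training compute and these figures are lower bounds:

```python
def throughput_mb_s(mb_per_step: float, steps: int, seconds: float) -> float:
    """Total megabytes written divided by wall-clock time, in MB/s."""
    return mb_per_step * steps / seconds

# Rows from the table above (size-100 model, 5 steps).
default_tp = throughput_mb_s(298, 5, 36.55)
dataflux_tp = throughput_mb_s(298, 5, 44.07)
print(round(default_tp, 1), round(dataflux_tp, 1))  # 40.8 33.8
```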
87 changes: 87 additions & 0 deletions dataflux_pytorch/benchmark/lightning_checkpoint_benchmark.py
```python
import os
import time
from typing import Tuple

import torch
from torch import Tensor
from torch.utils.data import DataLoader
from lightning import Trainer
from lightning.pytorch import LightningModule
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos import WikiText2, Transformer
from lightning.pytorch.plugins.io import TorchCheckpointIO

from dataflux_pytorch.lightning import DatafluxLightningCheckpoint


class LightningTransformer(LightningModule):
    def __init__(self, vocab_size: int = 33278, nlayers: int = 100) -> None:
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size, nlayers=nlayers)

    def forward(self, inputs: Tensor, target: Tensor) -> Tensor:
        return self.model(inputs, target)

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        inputs, target = batch
        output = self(inputs, target)
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
        return loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.SGD(self.model.parameters(), lr=0.1)

    def prepare_data(self) -> None:
        WikiText2(download=True)

    def train_dataloader(self) -> DataLoader:
        dataset = WikiText2()
        return DataLoader(dataset)


def main(project: str, bucket: str, ckpt_dir_path: str, save_only_latest: bool,
         dataflux_ckpt: bool, model_size: int = 100, steps: int = 5):
    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=1)
    model = LightningTransformer(vocab_size=dataset.vocab_size, nlayers=model_size)
    # Use the default PyTorch checkpoint IO unless the Dataflux plugin is requested.
    ckpt = TorchCheckpointIO()
    if dataflux_ckpt:
        ckpt = DatafluxLightningCheckpoint(project_name=project, bucket_name=bucket)
    # Save once per step, and if `save_only_latest`, replace the last checkpoint each time.
    # Replacing is implemented by saving the new checkpoint, then deleting the previous one.
    # If `save_only_latest` is False, a new checkpoint is created for each step.
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1 if save_only_latest else -1,
        every_n_train_steps=1,
        filename="checkpoint-{epoch:02d}-{step:02d}",
        enable_version_counter=True,
    )
    trainer = Trainer(
        default_root_dir=ckpt_dir_path,
        plugins=[ckpt],
        callbacks=[checkpoint_callback],
        min_epochs=4,
        max_epochs=5,
        max_steps=steps,
        accelerator="cpu",
    )
    start = time.time()
    trainer.fit(model, dataloader)
    end = time.time()
    print(end - start)


if __name__ == "__main__":
    DEFAULT_SIZE = 100
    DEFAULT_STEPS = 5
    size = int(os.getenv("CHECKPOINT_SIZE", DEFAULT_SIZE))
    steps = int(os.getenv("STEPS", DEFAULT_STEPS))

    main(
        os.getenv("PROJECT"),
        os.getenv("BUCKET"),
        os.getenv("CKPT_DIR_PATH"),
        os.getenv("SAVE_ONLY_LATEST") == "1",
        os.getenv("DATAFLUX_CKPT") == "1",
        size,
        steps,
    )
```
The PR also fixes `save_checkpoint`, which previously opened the blob writer twice; the fix passes the already-open `blobwriter` to `torch.save`:

```diff
@@ -51,7 +51,7 @@ def save_checkpoint(
         key = self._parse_gcs_path(path)
         blob = self.bucket.blob(key)
         with blob.open("wb", ignore_flush=True) as blobwriter:
-            torch.save(checkpoint, blob.open("wb", ignore_flush=True))
+            torch.save(checkpoint, blobwriter)

     def load_checkpoint(
         self,
```
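The one-line change above matters because the `with` statement has already opened a writer; a second `blob.open(...)` inside the `torch.save` call creates a separate stream that is never closed. A minimal sketch of the difference, using a hypothetical stand-in for the blob object (not the real `google-cloud-storage` client):

```python
class FakeBlob:
    """Hypothetical stand-in for a GCS blob that counts open writers."""

    def __init__(self):
        self.open_writers = 0

    def open(self, mode="wb", ignore_flush=True):
        self.open_writers += 1
        blob = self

        class Writer:
            def __enter__(self):
                return self

            def __exit__(self, *exc):
                blob.open_writers -= 1  # only closed when used as a context manager
                return False

            def write(self, data):
                return len(data)

        return Writer()


buggy = FakeBlob()
with buggy.open("wb") as w:
    buggy.open("wb")  # old pattern: a second writer is opened and never closed
print(buggy.open_writers)  # 1 leaked writer

fixed = FakeBlob()
with fixed.open("wb") as w:
    w.write(b"checkpoint bytes")  # fixed pattern: reuse the writer from the with-statement
print(fixed.open_writers)  # 0
```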