Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add boilerplate code for Dataflux-Pytorch Lightning demo #57

Merged
merged 9 commits into from
Jul 16, 2024
24 changes: 24 additions & 0 deletions demo/lightning/data.py
abhibyreddi marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch
import lightning.pytorch as pl

from dataflux_pytorch import dataflux_mapstyle_dataset


class Unet3DDataModule(pl.LighitningDataModule):
def __init__(self, gcs_bucket):
self.data_dir = gcs_bucket + "/images"
abhibyreddi marked this conversation as resolved.
Show resolved Hide resolved
abhibyreddi marked this conversation as resolved.
Show resolved Hide resolved
self.transform = 1

def prepare_data(self):

pass

def setup(self, state="train"):
# Init DatafluxPytTrain object
# Pass transform function as argument - modify DatafluxPyTrain constructor
pass

def train_dataloader(self):
# Init dataflux dataloader
# Wrap it in torch.DataLoader
pass
22 changes: 22 additions & 0 deletions demo/lightning/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import torch
import lightning.pytorch as pl


class Unet3D(pl.LightningModule):
def __init__(self):
pass

def forward(self, x):
pass

def configure_optimizers(self):
pass

def training_step(self, train_batch, batch_idx):
pass

def backward(self, trainer, loss, optimizer, optimizer_idx):
pass

def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx):
pass
14 changes: 14 additions & 0 deletions demo/lightning/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import lightning as pl

from model import Unet3D
from data import Unet3DDataModule

if __name__ == "__main__":
model = Unet3D()
train_data_loader = Unet3DDataModule()
trainer = pl.Trainer(
devices=2,
accelerator="gpu",
max_epochs=5
abhibyreddi marked this conversation as resolved.
Show resolved Hide resolved
)
trainer.fit(model=model, train_dataloaders=train_data_loader)