Skip to content

Commit

Permalink
training loop for classifcation
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu committed Sep 6, 2022
1 parent 8349486 commit e22cce3
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 23 deletions.
6 changes: 3 additions & 3 deletions python/benchmarks/parse_pet.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def read_metadata(self, check_quality=False) -> pd.DataFrame:
no_index = pd.Index(names.values).difference(df.filename)
self._data_quality_issues["missing_index"] = no_index

# TODO lance doesn't support writing booleans yet
with_xmls['segmented'] = with_xmls.segmented.astype(pd.Int8Dtype())
with_xmls['segmented'] = with_xmls.segmented.apply(
lambda x: pd.NA if pd.isnull(x) else bool(x)).astype(pd.BooleanDtype())
return with_xmls

def _get_index(self, name: str) -> pd.DataFrame:
Expand Down Expand Up @@ -152,7 +152,7 @@ def get_schema(self):
pa.string(),
source_schema,
size_schema,
pa.uint8(),
pa.bool_(),
object_schema
]
return pa.schema([pa.field(name, dtype)
Expand Down
68 changes: 57 additions & 11 deletions python/benchmarks/train_pet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,87 @@
"""

import io
from typing import Optional

import click
import pytorch_lightning as pl
import torch
import torchvision
import torchvision.transforms as T
from PIL import Image
from torch import optim
from torchvision.models.efficientnet import EfficientNet_B0_Weights

import lance
import lance.pytorch.data


def collate_fn(batch):
# TODO: Labels should be converted via torch.LanceDataset
labels = torch.randint(0, 31, size=(len(batch[1]),))
# TODO: Image conversion should in torch.LanceDataset
preprocessing = EfficientNet_B0_Weights.DEFAULT.transforms()
images = [
preprocessing(T.ToTensor()(Image.open(io.BytesIO(data)))) for data in batch[0]
]
return torch.stack(images), labels


class Classification(pl.LightningModule):
def __init__(self) -> None:
"""Classification model to train"""

def __init__(
self,
model: torch.nn.Module = torchvision.models.efficientnet_b0(),
benchmark: Optional[str] = None,
) -> None:
"""Build a PyTorch classification model."""
super().__init__()
self.model = torchvision.models.efficientnet_b1()
self.model = model
self.criterion = torch.nn.CrossEntropyLoss()
self.benchmark = benchmark

def training_step(self, batch, batch_idx):
print(batch, batch_idx)
x, y = batch
pass
"""
https://github.com/pytorch/vision/blob/main/references/classification/train.py
"""
images, labels = batch
if self.benchmark == "io":
# only test I/O
pass
else:
output = self.model(images)
loss = self.criterion(output, labels)
return loss

def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=1e-3)
return optimizer


@click.command()
@click.option("-b", "--batch_size", default=64, help="batch size", show_default=True)
@click.option("-b", "--batch_size", default=4, help="batch size", show_default=True)
@click.option("-e", "--epochs", default=10, help="set max ephochs", show_default=True)
@click.option("--benchmark", type=click.Choice(["io", "train"]), default="train")
@click.argument("dataset")
def train(dataset: str, batch_size: int):
def train(dataset: str, batch_size: int, epochs: int, benchmark: str):
print(f"Running benchmark: {benchmark}")
dataset = lance.pytorch.data.LanceDataset(
dataset, columns=["class", "image"], batch_size=batch_size
dataset,
columns=["image", "class"],
batch_size=batch_size,
)
train_loader = torch.utils.data.DataLoader(
dataset,
num_workers=8,
batch_size=None,
collate_fn=collate_fn,
)
model = Classification(benchmark=benchmark)
trainer = pl.Trainer(
limit_train_batches=100, max_epochs=epochs, accelerator="gpu", devices=-1
)
train_loader = torch.utils.data.DataLoader(dataset, num_workers=0, batch_size=None)
model = Classification()
trainer = pl.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=model, train_dataloaders=train_loader)


Expand Down
26 changes: 17 additions & 9 deletions python/lance/pytorch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,32 @@
from typing import List, Optional, Union

import pyarrow as pa
import pyarrow.dataset
import numpy as np

try:
import torch
from torch.utils.data import IterableDataset
except ImportError:
raise ImportError("Please install pytorch via pip install lance[pytorch]")
except ImportError as e:
raise ImportError("Please install pytorch first", e)

from lance import dataset, scanner

__all__ = ["LanceDataset"]


def to_numpy(arr: pa.Array):
"""Convert pyarrow array to numpy array"""
# TODO: arrow.to_numpy(writable=True) makes a new copy of data.
# Investigate how to directly perform zero-copy into Torch Tensor.
np_arr = arr.to_numpy(zero_copy_only=False, writable=True)
if pa.types.is_binary(arr.type) or pa.types.is_large_binary(arr.type):
return np_arr.astype(np.bytes_)
elif pa.types.is_string(arr.type) or pa.types.is_large_string(arr.type):
return np_arr.astype(np.str_)
else:
return np_arr


class LanceDataset(IterableDataset):
"""An PyTorch IterableDataset.
Expand All @@ -55,9 +68,4 @@ def __repr__(self):
def __iter__(self):
"""Yield dataset"""
for batch in self.scanner.to_reader():
# TODO: arrow.to_numpy(writable=True) makes a new copy of data.
# Investigate how to directly perform zero-copy into Torch Tensor.
yield [
arr.to_numpy(zero_copy_only=False, writable=True)
for arr in batch.columns
]
yield [to_numpy(arr) for arr in batch.columns]

0 comments on commit e22cce3

Please sign in to comment.