Improve ergonomics of the PyTorch dataset and generate embeddings for Oxford Pet #157
@@ -52,4 +52,7 @@ python/lance/_lib.cpp
*.lance
python/thirdparty/arrow/
python/wheels
logs
*.ckpt
@@ -1,38 +1,19 @@
#!/usr/bin/env python

"""Train and evaluate models on Oxford pet dataset.

"""
#!/usr/bin/env python3

import io
import os
import time
from typing import Callable, Optional

import click
import pyarrow.compute as pc
import pyarrow.fs
import pyarrow
import pytorch_lightning as pl
import torch
import torchvision
import torchvision.transforms as T
from PIL import Image
Review comment: Should we just always call this PILImage or use a qualified PIL.Image?
Reply: Ok, lemme change it to PIL.Image.
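For context, the two import spellings under discussion would look like this (illustrative sketch only, not part of the diff):

    import PIL.Image                    # qualified: PIL.Image.open(...)
    from PIL import Image as PILImage   # aliased:   PILImage.open(...)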
from torch import optim
from torchvision.models.efficientnet import EfficientNet_B0_Weights

import lance
import lance.pytorch.data

transform = T.Compose([EfficientNet_B0_Weights.DEFAULT.transforms()])


def raw_collate_fn(batch):
    images = []
    labels = []
    for img, label in batch:
        images.append(img)
        labels.append(label)
    return torch.stack(images), torch.tensor(labels)


NUM_CLASSES = 38
Review comment: Is this something we can get from torchvision? Do we want to hard-code this to check against the dataset, or just compute it from the dataset dictionary?
Reply: This number is specific to the dataset. Computing it dynamically from the dataset feels like overkill, especially since we need to support different formats (it takes real effort to compute it dynamically for the raw format).
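For reference, a minimal sketch of how the class count could be derived from the raw annotations instead of being hard-coded; the list.txt location and column layout are assumptions based on the standard Oxford-IIIT Pet distribution:

    import pathlib

    def count_pet_classes(annotations_dir: str) -> int:
        """Count distinct class ids in the Oxford-IIIT Pet list.txt file.

        Assumes each non-comment line looks like "Abyssinian_100 1 1 1",
        where the second column is the class id.
        """
        class_ids = set()
        for line in pathlib.Path(annotations_dir, "list.txt").read_text().splitlines():
            if not line.strip() or line.startswith("#"):
                continue
            class_ids.add(int(line.split()[1]))
        return len(class_ids)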
class RawOxfordPetDataset(torch.utils.data.Dataset):

@@ -85,30 +66,31 @@ def __getitem__(self, idx):
        return img, self.labels[idx]

def collate_fn(batch):
    # TODO: Labels should be converted via torch.LanceDataset
    labels = torch.randint(0, 31, size=(len(batch[1]),))
    # TODO: Image conversion should in torch.LanceDataset
    images = [
        transform(Image.open(io.BytesIO(data)).convert("RGB")) for data in batch[0]
    ]
    return torch.stack(images), labels

class Classification(pl.LightningModule):
    """Classification model to train"""

    def __init__(
        self,
        model: torch.nn.Module = torchvision.models.efficientnet_b0(),
        backbone: Optional[torch.nn.Module] = None,
        learning_rate=0.1,
        benchmark: Optional[str] = None,
    ) -> None:
        """Build a PyTorch classification model."""
        super().__init__()
        self.model = model
        self.backbone = torchvision.models.resnet50(num_classes=NUM_CLASSES)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.benchmark = benchmark
        self.fit_start_time = 0
        self.learning_rate = learning_rate

    @staticmethod
    def get(name: str, **kwargs):
        if name == "resnet":
            return Classification(backbone=torchvision.models.resnet50(num_classes=NUM_CLASSES))
        elif name == "efficientnet":
            return Classification(backbone=torchvision.models.efficientnet_b0(num_classes=NUM_CLASSES))
        else:
            raise ValueError(f"Unsupported model: {name}")

    def on_fit_start(self) -> None:
        super().on_fit_start()

@@ -127,73 +109,37 @@ def training_step(self, batch, batch_idx):
            # only test I/O
            pass
        else:
            output = self.model(images)
            output = self.backbone(images)
            loss = self.criterion(output, labels)
            self.log_dict({"loss": loss})
            return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        # Use hyperparameters from https://github.com/pytorch/vision/tree/main/references/classification
        #
        optimizer = torch.optim.SGD(
            self.parameters(), lr=(self.learning_rate), momentum=0.9, weight_decay=1e-4
        )
        return optimizer

@click.command()
@click.option("-b", "--batch_size", default=64, help="batch size", show_default=True)
@click.option("-e", "--epochs", default=10, help="set max ephochs", show_default=True)
@click.option(
    "-w",
    "--num_workers",
    default=os.cpu_count(),
    help="set pytorch DataLoader number of workers",
    show_default=True,
)
@click.option(
    "--format",
    "-F",
    "data_format",
    type=click.Choice(["lance", "raw", "parquet"]),
    default="lance",
)
@click.option("--benchmark", type=click.Choice(["io", "train"]), default="train")
@click.argument("dataset")
def train(
    dataset: str,
    batch_size: int,
    epochs: int,
    benchmark: str,
    num_workers: int,
    data_format,
):
    print(f"Running benchmark: {benchmark}")
    if data_format == "lance":
        dataset = lance.pytorch.data.LanceDataset(
            dataset,
            columns=["image", "class"],
            batch_size=batch_size,
            # filter=(pc.field("split") == "train")
        )
        train_loader = torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_size=None,
            collate_fn=collate_fn,
        )
    elif data_format == "raw":
        dataset = RawOxfordPetDataset(dataset, transform=transform)
        train_loader = torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_size=batch_size,
            collate_fn=raw_collate_fn,
        )
    else:
        raise ValueError("Unsupported data format")
def collate_fn(transform):
    def _collate_fn(batch):
        # TODO: convert label to int64 from Dataset?
        labels = torch.from_numpy(batch[1]).to(torch.int64)
        # TODO: Image conversion should in torch.LanceDataset
        images = [
            transform(Image.open(io.BytesIO(data)).convert("RGB")) for data in batch[0]
        ]
        return torch.stack(images), labels

    model = Classification(benchmark=benchmark)
    trainer = pl.Trainer(
        limit_train_batches=100, max_epochs=epochs, accelerator="gpu", devices=-1
    )
    trainer.fit(model=model, train_dataloaders=train_loader)
    return _collate_fn


if __name__ == "__main__":
    train()

def raw_collate_fn(batch):
    images = []
    labels = []
    for img, label in batch:
        images.append(img)
        labels.append(label)
    return torch.stack(images), labels
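To see how the refactored collate_fn factory is meant to be used, here is a rough wiring sketch (the dataset path is hypothetical and the LanceDataset/DataLoader arguments simply mirror the train() body above):

    transform = EfficientNet_B0_Weights.DEFAULT.transforms()
    dataset = lance.pytorch.data.LanceDataset(
        "oxford_pet.lance",  # hypothetical path to the converted dataset
        columns=["image", "class"],
        batch_size=64,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=None,  # LanceDataset already yields whole batches
        num_workers=os.cpu_count(),
        collate_fn=collate_fn(transform),  # call the factory to get the per-batch function
    )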
@@ -15,12 +15,15 @@

import os
import pathlib
import sys
from urllib.parse import urlparse

sys.path.append("..")
Review comment: What's the purpose of this?
Reply: This is for importing bench_utils from the parent directory.
import click
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.fs
import xmltodict

from bench_utils import DatasetConverter, download_uris

@@ -168,10 +171,10 @@ def get_schema(self):
        ]
        types = [
            pa.string(),
            pa.dictionary(pa.uint8(), pa.string()),
            pa.dictionary(pa.uint8(), pa.string()),
            pa.dictionary(pa.int8(), pa.string()),
            pa.dictionary(pa.int8(), pa.string()),
            pa.int16(),
            pa.dictionary(pa.uint8(), pa.string()),
            pa.dictionary(pa.int8(), pa.string()),
            pa.string(),
            source_schema,
            size_schema,
@@ -181,7 +184,10 @@ def get_schema(self):
        return pa.schema([pa.field(name, dtype) for name, dtype in zip(names, types)])


def _get_xml(uri):
def _get_xml(uri: str):
    if not urlparse(uri).scheme:
        uri = pathlib.Path(uri)

    fs, key = pa.fs.FileSystem.from_uri(uri)
    try:
        with fs.open_input_file(key) as fh:
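As a quick illustration of why the scheme check matters (a sketch, not from the diff): remote URIs carry a URL scheme while plain local paths do not, so schemeless inputs get routed through pathlib before pyarrow resolves a filesystem.

    from urllib.parse import urlparse

    # Remote URIs have a scheme; bare local paths do not.
    assert urlparse("s3://bucket/annotations/a.xml").scheme == "s3"
    assert urlparse("data/annotations/a.xml").scheme == ""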
@@ -0,0 +1,92 @@
#!/usr/bin/env python3

"""Generate embeddings"""

import io
import os

import click
import pandas as pd
import pyarrow
import torch
import torchvision
from common import Classification, RawOxfordPetDataset, raw_collate_fn
import PIL
from torchvision.models.feature_extraction import create_feature_extractor

import lance

transform = torchvision.models.ResNet50_Weights.DEFAULT.transforms()


def collate_fn(batch):
    pk = batch[1]
    images = [
        transform(PIL.Image.open(io.BytesIO(data)).convert("RGB")) for data in batch[0]
    ]
    return torch.stack(images), pk

@click.command()
@click.argument("checkpoint")
@click.argument("dataset")
@click.option(
    "-f", "--format", "data_format", type=click.Choice(["lance", "raw", "parquet"])
)
@click.option("-b", "--batch_size", type=int, default=128)
@click.option(
    "-w",
    "--num_workers",
    default=os.cpu_count(),
    help="set pytorch DataLoader number of workers",
    show_default=True,
)
@click.option(
    "-o", "--output", default="embeddings.lance", help="Output path", show_default=True
)
def gen_embeddings(checkpoint, dataset, output, batch_size, num_workers, data_format):
    model = Classification.load_from_checkpoint(checkpoint)

    if data_format == "lance":
        dataset = lance.pytorch.data.LanceDataset(
            dataset,
            columns=["image", "filename"],
            batch_size=batch_size,
            # filter=(pc.field("split") == "train")
        )
        train_loader = torch.utils.data.DataLoader(
            dataset, num_workers=num_workers, batch_size=None, collate_fn=collate_fn
        )
    elif data_format == "raw":
Review comment: Do we also need to compare against parquet, or is this sufficient?
Reply: Will add support for parquet later, maybe tfrecord as well.
        dataset = RawOxfordPetDataset(dataset, transform=transform)
        train_loader = torch.utils.data.DataLoader(
            dataset,
            shuffle=True,
            num_workers=num_workers,
            batch_size=batch_size,
            collate_fn=raw_collate_fn,
        )
    else:
        raise ValueError("Unsupported data format")
    model.eval()

    extractor = create_feature_extractor(model.backbone, {"avgpool": "features"})
    extractor = extractor.to("cuda")
    with torch.no_grad():
        dfs = []
        for batch, pk in train_loader:
            batch = batch.to("cuda")
            features = extractor(batch)["features"].squeeze()
            df = pd.DataFrame(
                {
                    "pk": pk,
                    "features": features.tolist(),
                }
            )
            dfs.append(df)
Comment on lines
+73
to
+86
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does fast.ai / huggingface have any conveniences for this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since the rest is using pytorch lightning, using fast.ai / hf seems need to convert them to raw pytorch and make it adapt fastai/hf? Prob we can use https://pytorch-lightning.readthedocs.io/en/stable/deploy/production_basic.html It still needs to match pk tho. I can make it to use pytorch lightning's predict if desired. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, actually, this need to wrap into a separate module, as it uses |
    df = pd.concat(dfs)
    lance.write_table(pyarrow.Table.from_pandas(df, preserve_index=False), output)


if __name__ == "__main__":
    gen_embeddings()
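Following up on the review thread above, a rough sketch of what the Lightning-predict variant could look like; the EmbeddingExtractor wrapper and its wiring are assumptions, not part of this PR:

    import pytorch_lightning as pl
    import torch
    from torchvision.models.feature_extraction import create_feature_extractor

    class EmbeddingExtractor(pl.LightningModule):
        """Hypothetical wrapper that emits (pk, embedding) pairs via Trainer.predict."""

        def __init__(self, backbone: torch.nn.Module):
            super().__init__()
            self.extractor = create_feature_extractor(backbone, {"avgpool": "features"})

        def predict_step(self, batch, batch_idx):
            images, pk = batch
            features = self.extractor(images)["features"].squeeze()
            return pk, features.cpu()

    # Usage (assumes `model` and `train_loader` from gen_embeddings above):
    # trainer = pl.Trainer(accelerator="gpu", devices=1)
    # results = trainer.predict(EmbeddingExtractor(model.backbone), train_loader)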
Review comment: Just double checking -- does pool.map return results in the same order as the input?
Reply: Yes, that's my understanding.
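For the record, Python's multiprocessing.Pool.map does preserve input order (unlike imap_unordered); a quick sketch that illustrates this:

    from multiprocessing import Pool

    def square(x: int) -> int:
        return x * x

    if __name__ == "__main__":
        with Pool(4) as pool:
            # map blocks until all results are ready and returns them in input order.
            assert pool.map(square, range(10)) == [x * x for x in range(10)]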