Improve ergonomic of the Pytorch dataset and Generate embeddings for oxford pet #157

Merged 19 commits on Sep 13, 2022
5 changes: 4 additions & 1 deletion .gitignore
@@ -52,4 +52,7 @@ python/lance/_lib.cpp
*.lance

python/thirdparty/arrow/
python/wheels
python/wheels

logs
*.ckpt
1 change: 0 additions & 1 deletion cpp/src/lance/arrow/file_lance.cc
@@ -60,7 +60,6 @@ ::arrow::Result<bool> LanceFileFormat::IsSupported(

::arrow::Result<std::shared_ptr<::arrow::Schema>> LanceFileFormat::Inspect(
const ::arrow::dataset::FileSource& source) const {
fmt::print("Inspect: File source={}\n", source.path());
if (impl_->manifest) {
return impl_->manifest->schema().ToArrow();
}
14 changes: 10 additions & 4 deletions python/benchmarks/bench_utils.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing as mp
import os
import pathlib
@@ -26,20 +27,24 @@
import pyarrow.dataset as ds
import pyarrow.fs
import pyarrow.parquet as pq
from urllib.parse import urlparse

import lance
from lance.types.image import Image, ImageBinaryType

__all__ = ["download_uris", "timeit", "get_dataset", "get_uri", "BenchmarkSuite"]

KNOWN_FORMATS = ["lance", "parquet", "raw"]


def read_file(uri) -> bytes:
if not urlparse(uri).scheme:
uri = pathlib.Path(uri)
fs, key = pyarrow.fs.FileSystem.from_uri(uri)
return fs.open_input_file(key).read()


def download_uris(uris: Iterable[str], func=read_file) -> Iterable[bytes]:
def download_uris(uris: Iterable[str], func=read_file) -> Iterable[Image]:
Contributor

Just double-checking: does pool.map return results in the same order as the input?

Contributor Author

Yes, that's my understanding.
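
For reference, `multiprocessing.Pool.map` is documented as a parallel equivalent of the built-in `map`, so it returns results in input order even when workers finish out of order (`imap_unordered` is the variant that gives up that ordering). A minimal sketch of the behavior, using `multiprocessing.pool.ThreadPool`, which exposes the same `map` interface and avoids pickling concerns; `fake_download` is a hypothetical stand-in for `read_file`, not code from this PR:

```python
import time
from multiprocessing.pool import ThreadPool


def fake_download(item):
    """Hypothetical stand-in for read_file: later items finish sooner."""
    index, delay = item
    time.sleep(delay)
    return index


def download_in_order(n=4):
    # Item 0 sleeps the longest, so the workers complete in reverse order.
    items = [(i, 0.05 * (n - i)) for i in range(n)]
    with ThreadPool(n) as pool:
        # map() still returns results aligned with the input sequence.
        return pool.map(fake_download, items)
```

Calling `download_in_order()` yields the indices in their original order, which is what lets `download_uris` pair each result with its URI by position.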

if isinstance(uris, pd.Series):
uris = uris.values
pool = mp.Pool(mp.cpu_count() - 1)
@@ -235,7 +240,7 @@ def image_uris(self, table):

def make_embedded_dataset(
self,
table: Union[pa.Table | pd.DataFrame],
table: Union[pa.Table, pd.DataFrame],
fmt="lance",
output_path=None,
**kwargs,
@@ -246,8 +251,9 @@ def make_embedded_dataset(
output_path = output_path or self.default_dataset_path(fmt)
uris = self.image_uris(table)
images = download_uris(pd.Series(uris))
arr = pa.BinaryArray.from_pandas(images)
embedded = table.append_column(pa.field("image", pa.binary()), arr)
# TODO: improve extension type ergonomics
arr = pa.ExtensionArray.from_storage(ImageBinaryType(), pa.array(images))
embedded = table.append_column(pa.field("image", ImageBinaryType()), arr)
if fmt == "parquet":
pq.write_table(embedded, output_path, **kwargs)
elif fmt == "lance":
134 changes: 40 additions & 94 deletions python/benchmarks/train_pet.py → python/benchmarks/oxford_pet/common.py
100755 → 100644
@@ -1,38 +1,19 @@
#!/usr/bin/env python

"""Train and evaluate models on Oxford pet dataset.

"""
#!/usr/bin/env python3

import io
import os
import time
from typing import Callable, Optional

import click
import pyarrow.compute as pc
import pyarrow.fs
import pyarrow
import pytorch_lightning as pl
import torch
import torchvision
import torchvision.transforms as T
from PIL import Image
Contributor

Should we always call this PILImage, or use the qualified PIL.Image?

Contributor Author

OK, I'll change it to PIL.Image.
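
A small sketch of the qualified form, assuming Pillow is installed. Note that a bare `import PIL` does not load the `PIL.Image` submodule on its own, so the submodule is imported explicitly here. `decode_rgb` and `make_png_bytes` are illustrative names, not part of this PR:

```python
import io

import PIL.Image  # explicit submodule import; a bare `import PIL` is not enough


def decode_rgb(data: bytes) -> PIL.Image.Image:
    """Decode encoded image bytes into an RGB PIL image."""
    return PIL.Image.open(io.BytesIO(data)).convert("RGB")


def make_png_bytes(size=(4, 3)) -> bytes:
    """Build a tiny in-memory grayscale PNG to exercise decode_rgb."""
    buf = io.BytesIO()
    PIL.Image.new("L", size, color=128).save(buf, format="PNG")
    return buf.getvalue()
```

The qualified name also keeps the PIL class visibly distinct from `lance.types.image.Image`, which `bench_utils.py` imports under the same short name.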

from torch import optim
from torchvision.models.efficientnet import EfficientNet_B0_Weights

import lance
import lance.pytorch.data

transform = T.Compose([EfficientNet_B0_Weights.DEFAULT.transforms()])


def raw_collate_fn(batch):
images = []
labels = []
for img, label in batch:
images.append(img)
labels.append(label)
return torch.stack(images), torch.tensor(labels)
NUM_CLASSES = 38
Contributor

Is this something we can get from torchvision? Do we want to hard-code it and check it against the dataset, or just compute it from the dataset dictionary?

Contributor Author

This number is specific to the dataset. I feel it is overkill to compute it dynamically from the dataset, especially since we need to support different formats (for example, computing it from the raw format takes some effort).
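
One possible middle ground is to keep the constant but sanity-check it against the labels whenever they are cheap to obtain. A minimal sketch in plain Python; `required_outputs` and `check_num_classes` are hypothetical helpers, not part of this PR. It assumes integer class ids are used directly as `CrossEntropyLoss` targets, so the model needs at least `max(label) + 1` outputs. If the labels are 1-based ids for the dataset's 37 breeds, that would account for the value 38, but that reading is an assumption:

```python
NUM_CLASSES = 38  # value hard-coded in common.py


def required_outputs(labels):
    """Smallest output size that can index every label as a class id."""
    labels = list(labels)
    if not labels:
        raise ValueError("no labels given")
    if min(labels) < 0:
        raise ValueError("class ids must be non-negative")
    return max(labels) + 1


def check_num_classes(labels, num_classes=NUM_CLASSES):
    """Raise if any label would fall outside the model's output range."""
    needed = required_outputs(labels)
    if needed > num_classes:
        raise ValueError(f"labels need {needed} outputs, model has {num_classes}")
    return needed
```

A check like this could run once per format at dataset-creation time, so training code can keep using the constant.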



class RawOxfordPetDataset(torch.utils.data.Dataset):
@@ -85,30 +66,31 @@ def __getitem__(self, idx):
return img, self.labels[idx]


def collate_fn(batch):
# TODO: Labels should be converted via torch.LanceDataset
labels = torch.randint(0, 31, size=(len(batch[1]),))
# TODO: Image conversion should happen in torch.LanceDataset
images = [
transform(Image.open(io.BytesIO(data)).convert("RGB")) for data in batch[0]
]
return torch.stack(images), labels


class Classification(pl.LightningModule):
"""Classification model to train"""

def __init__(
self,
model: torch.nn.Module = torchvision.models.efficientnet_b0(),
backbone: Optional[torch.nn.Module] = None,
learning_rate=0.1,
benchmark: Optional[str] = None,
) -> None:
"""Build a PyTorch classification model."""
super().__init__()
self.model = model
self.backbone = backbone if backbone is not None else torchvision.models.resnet50(num_classes=NUM_CLASSES)
self.criterion = torch.nn.CrossEntropyLoss()
self.benchmark = benchmark
self.fit_start_time = 0
self.learning_rate = learning_rate

@staticmethod
def get(name: str, **kwargs):
if name == "resnet":
return Classification(backbone=torchvision.models.resnet50(num_classes=NUM_CLASSES))
elif name == "efficientnet":
return Classification(backbone=torchvision.models.efficientnet_b0(num_classes=NUM_CLASSES))
else:
raise ValueError(f"Unsupported model: {name}")

def on_fit_start(self) -> None:
super().on_fit_start()
@@ -127,73 +109,37 @@ def training_step(self, batch, batch_idx):
# only test I/O
pass
else:
output = self.model(images)
output = self.backbone(images)
loss = self.criterion(output, labels)
self.log_dict({"loss": loss})
return loss

def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=1e-3)
# Use hyperparameters from https://github.com/pytorch/vision/tree/main/references/classification
#
optimizer = torch.optim.SGD(
self.parameters(), lr=(self.learning_rate), momentum=0.9, weight_decay=1e-4
)
return optimizer


@click.command()
@click.option("-b", "--batch_size", default=64, help="batch size", show_default=True)
@click.option("-e", "--epochs", default=10, help="set max epochs", show_default=True)
@click.option(
"-w",
"--num_workers",
default=os.cpu_count(),
help="set pytorch DataLoader number of workers",
show_default=True,
)
@click.option(
"--format",
"-F",
"data_format",
type=click.Choice(["lance", "raw", "parquet"]),
default="lance",
)
@click.option("--benchmark", type=click.Choice(["io", "train"]), default="train")
@click.argument("dataset")
def train(
dataset: str,
batch_size: int,
epochs: int,
benchmark: str,
num_workers: int,
data_format,
):
print(f"Running benchmark: {benchmark}")
if data_format == "lance":
dataset = lance.pytorch.data.LanceDataset(
dataset,
columns=["image", "class"],
batch_size=batch_size,
# filter=(pc.field("split") == "train")
)
train_loader = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
batch_size=None,
collate_fn=collate_fn,
)
elif data_format == "raw":
dataset = RawOxfordPetDataset(dataset, transform=transform)
train_loader = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
batch_size=batch_size,
collate_fn=raw_collate_fn,
)
else:
raise ValueError("Unsupported data format")
def collate_fn(transform):
def _collate_fn(batch):
# TODO: convert label to int64 from Dataset?
labels = torch.from_numpy(batch[1]).to(torch.int64)
# TODO: Image conversion should happen in torch.LanceDataset
images = [
transform(Image.open(io.BytesIO(data)).convert("RGB")) for data in batch[0]
]
return torch.stack(images), labels

model = Classification(benchmark=benchmark)
trainer = pl.Trainer(
limit_train_batches=100, max_epochs=epochs, accelerator="gpu", devices=-1
)
trainer.fit(model=model, train_dataloaders=train_loader)
return _collate_fn


if __name__ == "__main__":
train()
def raw_collate_fn(batch):
images = []
labels = []
for img, label in batch:
images.append(img)
labels.append(label)
return torch.stack(images), torch.tensor(labels)
Original file line number Diff line number Diff line change
@@ -15,12 +15,15 @@

import os
import pathlib
import sys
from urllib.parse import urlparse

sys.path.append("..")
Contributor

What's the purpose of this?

Contributor Author

@eddyxu, Sep 13, 2022

This is for importing bench_utils.
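
One caveat: a relative entry such as `".."` on `sys.path` is resolved against the current working directory, so the import only succeeds when the script is launched from its own directory. A hedged sketch of a variant anchored to the script file instead; `add_parent_to_path` is a hypothetical helper, and it assumes the script sits one directory below the one holding `bench_utils.py`:

```python
import pathlib
import sys


def add_parent_to_path(script_file: str) -> str:
    """Put the directory above the script's directory on sys.path.

    Unlike sys.path.append(".."), the result does not depend on the
    current working directory.
    """
    parent = str(pathlib.Path(script_file).resolve().parent.parent)
    if parent not in sys.path:
        sys.path.append(parent)
    return parent
```

A script would call `add_parent_to_path(__file__)` before `from bench_utils import ...`.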


import click
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.fs
import xmltodict

from bench_utils import DatasetConverter, download_uris
@@ -168,10 +171,10 @@ def get_schema(self):
]
types = [
pa.string(),
pa.dictionary(pa.uint8(), pa.string()),
pa.dictionary(pa.uint8(), pa.string()),
pa.dictionary(pa.int8(), pa.string()),
pa.dictionary(pa.int8(), pa.string()),
pa.int16(),
pa.dictionary(pa.uint8(), pa.string()),
pa.dictionary(pa.int8(), pa.string()),
pa.string(),
source_schema,
size_schema,
@@ -181,7 +184,10 @@ def get_schema(self):
return pa.schema([pa.field(name, dtype) for name, dtype in zip(names, types)])


def _get_xml(uri):
def _get_xml(uri: str):
if not urlparse(uri).scheme:
uri = pathlib.Path(uri)

fs, key = pa.fs.FileSystem.from_uri(uri)
try:
with fs.open_input_file(key) as fh:
92 changes: 92 additions & 0 deletions python/benchmarks/oxford_pet/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#!/usr/bin/env python3

"""Generate embeddings"""

import io
import os

import click
import pandas as pd
import pyarrow
import torch
import torchvision
from common import Classification, RawOxfordPetDataset, raw_collate_fn
import PIL.Image
from torchvision.models.feature_extraction import create_feature_extractor

import lance

transform = torchvision.models.ResNet50_Weights.DEFAULT.transforms()


def collate_fn(batch):
pk = batch[1]
images = [
transform(PIL.Image.open(io.BytesIO(data)).convert("RGB")) for data in batch[0]
]
return torch.stack(images), pk


@click.command()
@click.argument("checkpoint")
@click.argument("dataset")
@click.option(
"-f", "--format", "data_format", type=click.Choice(["lance", "raw", "parquet"])
)
@click.option("-b", "--batch_size", type=int, default=128)
@click.option(
"-w",
"--num_workers",
default=os.cpu_count(),
help="set pytorch DataLoader number of workers",
show_default=True,
)
@click.option(
"-o", "--output", default="embeddings.lance", help="Output path", show_default=True
)
def gen_embeddings(checkpoint, dataset, output, batch_size, num_workers, data_format):
model = Classification.load_from_checkpoint(checkpoint)

if data_format == "lance":
dataset = lance.pytorch.data.LanceDataset(
dataset,
columns=["image", "filename"],
batch_size=batch_size,
# filter=(pc.field("split") == "train")
)
train_loader = torch.utils.data.DataLoader(
dataset, num_workers=num_workers, batch_size=None, collate_fn=collate_fn
)
elif data_format == "raw":
Contributor

Do we also need to compare against Parquet, or is this sufficient?

Contributor Author

I will add support for Parquet later, and maybe TFRecord as well.

dataset = RawOxfordPetDataset(dataset, transform=transform)
train_loader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
num_workers=num_workers,
batch_size=batch_size,
collate_fn=raw_collate_fn,
)
else:
raise ValueError("Unsupported data format")
model.eval()

extractor = create_feature_extractor(model.backbone, {"avgpool": "features"})
extractor = extractor.to("cuda")
with torch.no_grad():
dfs = []
for batch, pk in train_loader:
batch = batch.to("cuda")
features = extractor(batch)["features"].squeeze()
df = pd.DataFrame(
{
"pk": pk,
"features": features.tolist(),
}
)
dfs.append(df)
Comment on lines +73 to +86
Contributor

Do fast.ai or Hugging Face have any conveniences for this?

Contributor Author

Since the rest uses PyTorch Lightning, using fast.ai or Hugging Face seems to require converting the code to raw PyTorch and then adapting it to fastai/HF?

We can probably use https://pytorch-lightning.readthedocs.io/en/stable/deploy/production_basic.html

It still needs to match pk, though.

I can switch it to PyTorch Lightning's predict if desired.

Contributor Author

Oh, actually, this needs to be wrapped in a separate module, because it uses the create_feature_extractor(model.backbone, {"avgpool": "features"}) feature extractor, while the original module should do basic predictions (i.e., just return the detected class).

df = pd.concat(dfs)
lance.write_table(pyarrow.Table.from_pandas(df, preserve_index=False), output)


if __name__ == "__main__":
gen_embeddings()