Improve ergonomics of the PyTorch dataset and generate embeddings for Oxford Pet #157
Conversation
eddyxu commented Sep 12, 2022 (edited)
- Improve ergonomics of the PyTorch Dataset to load a local directory.
- Auto-tune the learning rate and use other hyperparameters from the torchvision site.
- Example code to generate embeddings.
Force-pushed from 76f1362 to 5ce1701
python/benchmarks/bench_utils.py
Outdated
@@ -39,7 +40,11 @@ def read_file(uri) -> bytes:
    return fs.open_input_file(key).read()


def download_uris(uris: Iterable[str], func=read_file) -> Iterable[bytes]:
def download_image(uri: str) -> Image:
    return ImageUri(uri).to_embedded()
Image.create(uri).to_embedded() should work
Ok, I will try it now.
As discussed in #157, use the downloaded binary for now.
    return ImageUri(uri).to_embedded()


def download_uris(uris: Iterable[str], func=read_file) -> Iterable[Image]:
Just double checking -- does pool.map return results in the same order as the input?
Yes, that's my understanding.
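(For reference, a quick way to convince ourselves of the ordering guarantee. This uses the standard-library multiprocessing.Pool as a stand-in; whether download_uris uses a process pool or a thread pool, the map contract is the same.)

from multiprocessing import Pool


def square(x: int) -> int:
    return x * x


if __name__ == "__main__":
    with Pool(processes=4) as pool:
        # Pool.map blocks until every task finishes and returns results in
        # the same order as the input iterable, regardless of which worker
        # finished first.
        print(pool.map(square, [3, 1, 2]))  # -> [9, 1, 4]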
import io
import os
import time
from typing import Callable, Optional
from typing import Optional, Callable
shouldn't this be alphabetical?
Fixed
import lance.pytorch.data
from PIL import Image
Should we just always call this PILImage or use a qualified PIL.Image?
Ok, lemme change it to PIL.Image
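(For illustration, the two spellings being discussed; the file path is a placeholder.)

# Option discussed first: import under an explicit alias.
from PIL import Image as PILImage

# Option settled on: keep the module qualified so it cannot shadow other
# Image types (e.g. lance's Image).
import PIL.Image

img = PIL.Image.open("sample.jpg")  # placeholder path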
        images.append(img)
        labels.append(label)
    return torch.stack(images), torch.tensor(labels)

NUM_CLASSES = 38
Is this something we can get from torchvision? Do we want to hard-code this to check against the dataset? Or do we want to just compute it from the dataset dictionary?
This is a number specific to this dataset. I feel it is overkill to calculate it dynamically from the dataset, especially since we need to support different formats (i.e., it takes some effort to compute this number dynamically for the raw format).
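(For reference, a sketch of the "get it from torchvision" option; it assumes torchvision's OxfordIIITPet wrapper is acceptable for a sanity check and would still need reconciling with the label encoding used in this benchmark.)

from torchvision.datasets import OxfordIIITPet

# The Oxford-IIIT Pet dataset defines 37 breed categories; torchvision
# exposes them on the dataset object, so the constant could be derived
# (or at least sanity-checked) instead of hard-coded.
pets = OxfordIIITPet(root="data", download=True)
print(len(pets.classes))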
        train_loader = torch.utils.data.DataLoader(
            dataset, num_workers=num_workers, batch_size=None, collate_fn=collate_fn
        )
    elif data_format == "raw":
do we also need to compare against parquet or is this sufficient?
will add support for parquet later. Maybe tfrecord as well.
extractor = create_feature_extractor(model.backbone, {"avgpool": "features"})
extractor = extractor.to("cuda")
with torch.no_grad():
    dfs = []
    for batch, pk in train_loader:
        batch = batch.to("cuda")
        features = extractor(batch)["features"].squeeze()
        df = pd.DataFrame(
            {
                "pk": pk,
                "features": features.tolist(),
            }
        )
        dfs.append(df)
does fast.ai / huggingface have any conveniences for this?
Since the rest uses PyTorch Lightning, using fast.ai / HF would seem to require converting things to raw PyTorch and adapting them to fastai/HF?
We could probably use https://pytorch-lightning.readthedocs.io/en/stable/deploy/production_basic.html
It still needs to match pk though.
I can make it use PyTorch Lightning's predict if desired.
Oh, actually, this needs to be wrapped into a separate module, since it uses the create_feature_extractor(model.backbone, {"avgpool": "features"}) feature extractor, while the original module should do basic predictions (i.e., just return the detected class).
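(For reference, a minimal sketch of what such a separate feature-extractor module could look like with PyTorch Lightning's predict. FeatureExtractorModule is a hypothetical name, and the (images, pk) batch layout is assumed from the loop above; this is not part of this PR.)

import pandas as pd
import pytorch_lightning as pl
import torch
from torchvision.models.feature_extraction import create_feature_extractor


class FeatureExtractorModule(pl.LightningModule):
    """Hypothetical wrapper exposing backbone features via Lightning's predict."""

    def __init__(self, backbone: torch.nn.Module):
        super().__init__()
        self.extractor = create_feature_extractor(backbone, {"avgpool": "features"})

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        images, pk = batch
        features = self.extractor(images)["features"].squeeze()
        # Keep pk alongside the embeddings so they can be joined back to rows.
        return pd.DataFrame({"pk": pk, "features": features.cpu().tolist()})


# Usage sketch: trainer.predict returns one result per batch.
# trainer = pl.Trainer(accelerator="gpu", devices=1)
# dfs = trainer.predict(FeatureExtractorModule(model.backbone), dataloaders=train_loader)
# embeddings = pd.concat(dfs)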
import lance
import lance.pytorch.data

NUM_CLASSES = 38
duplicate?
Fixed.
""" | ||
Image transform for training. | ||
|
||
Adding random argumentations. |
augmentation
Done
def __init__(
    self,
    crop_size: float,
    mean: tuple[float] = (0.485, 0.456, 0.406),
where do these defaults come from?
These numbers are used across torchvision and in the Python code referenced above.
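(These are the standard ImageNet per-channel statistics; a minimal illustration of how they are typically paired with a std default via torchvision.transforms, not code from this PR.)

from torchvision import transforms

# Standard ImageNet channel mean/std used throughout torchvision's
# reference training recipes.
normalize = transforms.Normalize(
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)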
Force-pushed from f43f274 to 1eef4e7
LGTM