Skip to content

Commit

Permalink
Overload embed decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
druzsan committed Dec 1, 2023
1 parent a70ecf4 commit ba1adfc
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 14 deletions.
71 changes: 66 additions & 5 deletions renumics/spotlight/embeddings/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,84 @@
"""

import functools
from typing import Callable, Dict, Optional, Any
from typing import (
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Any,
Union,
overload,
)

import PIL.Image
import numpy as np
from renumics.spotlight import dtypes
from renumics.spotlight.media.audio import Audio
from renumics.spotlight.media.embedding import Embedding

from renumics.spotlight.dtypes import create_dtype
from renumics.spotlight.embeddings.preprocessors import (
from renumics.spotlight.media.image import Image
from renumics.spotlight.media.sequence_1d import Sequence1D
from .preprocessors import (
preprocess_audio_batch,
preprocess_batch,
preprocess_image_batch,
)
from .typing import EmbedFunc, FunctionalEmbedder
from .registry import register_embedder
from .typing import EmbedFunc, FunctionalEmbedder


EmbedImageFunc = Callable[
[Iterable[List[PIL.Image.Image]]], Iterable[List[Optional[np.ndarray]]]
]
EmbedArrayFunc = Callable[
[Iterable[List[np.ndarray]]], Iterable[List[Optional[np.ndarray]]]
]


@overload
def embed(
dtype: Union[Literal["image", "Image"], Image], *, name: Optional[str] = None
) -> Callable[[EmbedImageFunc], EmbedImageFunc]:
...


@overload
def embed(
dtype: Union[Literal["audio", "Audio"], Audio],
*,
name: Optional[str] = None,
sampling_rate: int,
) -> Callable[[EmbedArrayFunc], EmbedArrayFunc]:
...


@overload
def embed(
dtype: Union[
Literal["embedding", "Embedding", "sequence1d", "Sequence1D"],
Embedding,
Sequence1D,
],
*,
name: Optional[str] = None,
) -> Callable[[EmbedArrayFunc], EmbedArrayFunc]:
...


@overload
def embed(
dtype: Any, *, name: Optional[str] = None, sampling_rate: Optional[int] = None
) -> Callable[[EmbedFunc], EmbedFunc]:
...


def embed(
dtype: Any, *, name: Optional[str] = None, sampling_rate: Optional[int] = None
) -> Callable[[EmbedFunc], EmbedFunc]:
dtype = create_dtype(dtype)
dtype = dtypes.create_dtype(dtype)

kwargs: Dict[str, Any] = {}

Expand Down
6 changes: 3 additions & 3 deletions renumics/spotlight/embeddings/embedders/gte.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, List
from typing import Iterable, List, Optional

import numpy as np
import sentence_transformers
Expand All @@ -9,11 +9,11 @@
try:
import torch
except ImportError:
logger.warning("`GTE Embedder` requires `pytorch` to be installed.")
logger.warning("GTE embedder requires `pytorch` to be installed.")
else:

@embed("str")
def gte(batches: Iterable[List[str]]) -> Iterable[List[np.ndarray]]:
def gte(batches: Iterable[List[str]]) -> Iterable[List[Optional[np.ndarray]]]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = sentence_transformers.SentenceTransformer(
"thenlper/gte-base", device=device
Expand Down
8 changes: 5 additions & 3 deletions renumics/spotlight/embeddings/embedders/vit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, List
from typing import Iterable, List, Optional

import PIL.Image
import numpy as np
Expand All @@ -10,11 +10,13 @@
try:
import torch
except ImportError:
logger.warning("`ViTEmbedder` requires `pytorch` to be installed.")
logger.warning("ViT embedder requires `pytorch` to be installed.")
else:

@embed("image")
def vit(batches: Iterable[List[PIL.Image.Image]]) -> Iterable[List[np.ndarray]]:
def vit(
batches: Iterable[List[PIL.Image.Image]],
) -> Iterable[List[Optional[np.ndarray]]]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_name = "google/vit-base-patch16-224"
processor = transformers.AutoImageProcessor.from_pretrained(model_name)
Expand Down
8 changes: 5 additions & 3 deletions renumics/spotlight/embeddings/embedders/wav2vec2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, List
from typing import Iterable, List, Optional

import numpy as np
import transformers
Expand All @@ -9,11 +9,13 @@
try:
import torch
except ImportError:
logger.warning("`Wav2Vec Embedder` requires `pytorch` to be installed.")
logger.warning("Wav2Vec embedder requires `pytorch` to be installed.")
else:

@embed("audio", sampling_rate=16000)
def wav2vec2(batches: Iterable[List[np.ndarray]]) -> Iterable[List[np.ndarray]]:
def wav2vec2(
batches: Iterable[List[np.ndarray]],
) -> Iterable[List[Optional[np.ndarray]]]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_name = "facebook/wav2vec2-base-960h"
processor = transformers.AutoFeatureExtractor.from_pretrained(model_name)
Expand Down

0 comments on commit ba1adfc

Please sign in to comment.