Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
I8dNLo committed Jan 15, 2025
1 parent c6facad commit e41197d
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions fastembed/text/onnx_embedding.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Iterable, Optional, Sequence, Type, Union
from typing import Any, Iterable, Optional, Sequence, Type, Union

import numpy as np

Expand Down Expand Up @@ -185,12 +185,12 @@ class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[np.ndarray]):
"""Implementation of the Flag Embedding model."""

@classmethod
def list_supported_models(cls) -> list[Dict[str, Any]]:
def list_supported_models(cls) -> list[dict[str, Any]]:
"""
Lists the supported models.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the model information.
list[dict[str, Any]]: A list of dictionaries containing the model information.
"""
return supported_onnx_models

Expand Down Expand Up @@ -299,7 +299,13 @@ def _preprocess_onnx_input(

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
embeddings = output.model_output
return normalize(embeddings[:, 0]).astype(np.float32)
if embeddings.ndim == 3: # (batch_size, seq_len, embedding_dim)
processed_embeddings = embeddings[:, 0]
elif embeddings.ndim == 2: # (batch_size, embedding_dim)
processed_embeddings = embeddings
else:
raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
return normalize(processed_embeddings).astype(np.float32)

def load_onnx_model(self) -> None:
self._load_onnx_model(
Expand Down

0 comments on commit e41197d

Please sign in to comment.