Skip to content

Commit

Permalink
chore: exclude keys list_supported_models
Browse files: browse the repository at this point in the history
  • Loading branch information
Anush008 committed Jan 18, 2024
1 parent d1969c5 commit 7a64189
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions fastembed/embedding.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -202,12 +202,15 @@ def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: int = Non
raise NotImplementedError

@classmethod
def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
def list_supported_models(cls, exclude: List[str] = []) -> List[Dict[str, Union[str, Union[int, float]]]]:
"""
Lists the supported models.
"""
models_file_path = Path(__file__).with_name("models.json")
models = json.load(open(str(models_file_path)))
with open(models_file_path, "r") as file:
models = json.load(file)

models = [{k: v for k, v in model.items() if k not in exclude} for model in models]

return models

Expand Down Expand Up @@ -264,7 +267,7 @@ def download_files_from_huggingface(cls, model_name: str, cache_dir: Optional[st
Returns:
Path: The path to the model directory.
"""
models = cls.list_supported_models()
models = cls.list_supported_models(exclude=["gcs_sources"])

hf_sources = [item for model in models if model["model"] == model_name for item in model["hf_sources"]]

Expand Down Expand Up @@ -343,7 +346,7 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:

model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz"

models = self.list_supported_models()
models = self.list_supported_models(exclude=["hf_sources"])

gcs_sources = [item for model in models if model["model"] == model_name for item in model["gcs_sources"]]

Expand Down Expand Up @@ -520,12 +523,16 @@ def embed(
yield from normalize(embeddings[:, 0]).astype(np.float32)

@classmethod
def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
def list_supported_models(
cls, exclude: List[str] = ["gcs_sources", "hf_sources"]
) -> List[Dict[str, Union[str, Union[int, float]]]]:
"""
Lists the supported models.
"""
# jina models are not supported by this class
return [model for model in super().list_supported_models() if not model["model"].startswith("jinaai")]
return [
model for model in super().list_supported_models(exclude=exclude) if not model["model"].startswith("jinaai")
]


class DefaultEmbedding(FlagEmbedding):
Expand Down Expand Up @@ -638,12 +645,16 @@ def embed(
yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)

@classmethod
def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
def list_supported_models(
cls, exclude: List[str] = ["gcs_sources", "hf_sources"]
) -> List[Dict[str, Union[str, Union[int, float]]]]:
"""
Lists the supported models.
"""
# only jina models are supported by this class
return [model for model in Embedding.list_supported_models() if model["model"].startswith("jinaai")]
return [
model for model in Embedding.list_supported_models(exclude=exclude) if model["model"].startswith("jinaai")
]

@staticmethod
def mean_pooling(model_output, attention_mask):
Expand Down

0 comments on commit 7a64189

Please sign in to comment.