From 870831a444634677f732ae4d86f28fa0def2ddb3 Mon Sep 17 00:00:00 2001
From: TheMasterFX
Date: Tue, 21 Jan 2025 20:39:48 +0100
Subject: [PATCH] Use the Hugging Face cache if HF_HUB_OFFLINE is set

You need to be online once to download the README.md, which is used to
fill in the metadata.
---
 src/speaches/hf_utils.py       | 19 +++++++++++++++----
 src/speaches/routers/models.py |  8 ++++++--
 2 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/src/speaches/hf_utils.py b/src/speaches/hf_utils.py
index 3b9dfde0..591b4a83 100644
--- a/src/speaches/hf_utils.py
+++ b/src/speaches/hf_utils.py
@@ -54,9 +54,7 @@ def list_whisper_models() -> Generator[Model, None, None]:
         yield transformed_model
 
 
-def list_local_whisper_models() -> Generator[
-    tuple[huggingface_hub.CachedRepoInfo, huggingface_hub.ModelCardData], None, None
-]:
+def list_local_whisper_models() -> Generator[Model, None, None]:
     hf_cache = huggingface_hub.scan_cache_dir()
     hf_models = [repo for repo in list(hf_cache.repos) if repo.repo_type == "model"]
     for model in hf_models:
@@ -76,7 +74,20 @@ def list_local_whisper_models() -> Generator[
             and model_card_data.tags is not None
             and TASK_NAME in model_card_data.tags
         ):
-            yield model, model_card_data
+            if model_card_data.language is None:
+                language = []
+            elif isinstance(model_card_data.language, str):
+                language = [model_card_data.language]
+            else:
+                language = model_card_data.language
+            transformed_model = Model(
+                id=model.repo_id,
+                created=int(model.last_modified),
+                object_="model",
+                owned_by=model.repo_id.split("/")[0],
+                language=language,
+            )
+            yield transformed_model
 
 
 def model_id_from_path(repo_path: Path) -> str:

diff --git a/src/speaches/routers/models.py b/src/speaches/routers/models.py
index 39b06a0f..bcca8e0f 100644
--- a/src/speaches/routers/models.py
+++ b/src/speaches/routers/models.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import os
 from typing import TYPE_CHECKING, Annotated
 
 from fastapi import (
@@ -13,7 +14,7 @@
     ListModelsResponse,
     Model,
 )
-from speaches.hf_utils import list_whisper_models
+from speaches.hf_utils import list_local_whisper_models, list_whisper_models
 
 if TYPE_CHECKING:
     from huggingface_hub.hf_api import ModelInfo
@@ -23,7 +24,10 @@
 
 @router.get("/v1/models")
 def get_models() -> ListModelsResponse:
-    whisper_models = list(list_whisper_models())
+    if os.getenv("HF_HUB_OFFLINE") is not None:
+        whisper_models = list(list_local_whisper_models())
+    else:
+        whisper_models = list(list_whisper_models())
     return ListModelsResponse(data=whisper_models)
 
 
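
For context, here is a minimal, self-contained sketch of the technique the new offline path relies on: scanning the local Hugging Face cache with huggingface_hub.scan_cache_dir() and filtering cached repos by their model card. The constant values and the helper name below are illustrative assumptions, not part of this patch; the actual implementation is list_local_whisper_models() in src/speaches/hf_utils.py.

# Sketch only: enumerate faster-whisper style models straight from the local
# Hugging Face cache instead of querying the Hub. LIBRARY_NAME / TASK_NAME
# mirror the constants used in src/speaches/hf_utils.py; their values here
# are assumptions for illustration.
from __future__ import annotations

import huggingface_hub

LIBRARY_NAME = "ctranslate2"                 # assumed value of hf_utils.LIBRARY_NAME
TASK_NAME = "automatic-speech-recognition"   # assumed value of hf_utils.TASK_NAME


def iter_cached_whisper_repo_ids() -> list[str]:
    """Return repo ids of cached models whose model card matches the Whisper filter."""
    repo_ids: list[str] = []
    for repo in huggingface_hub.scan_cache_dir().repos:
        if repo.repo_type != "model":
            continue
        for revision in repo.revisions:
            # The model card lives in the cached README.md of the revision.
            readme = next((f for f in revision.files if f.file_name == "README.md"), None)
            if readme is None:
                continue
            card_data = huggingface_hub.ModelCard.load(readme.file_path).data
            if (
                card_data.library_name == LIBRARY_NAME
                and card_data.tags is not None
                and TASK_NAME in card_data.tags
            ):
                repo_ids.append(repo.repo_id)
            break  # one cached revision is enough to read the card
    return repo_ids


if __name__ == "__main__":
    # Runs entirely against the local cache, which is why each model's
    # README.md must have been downloaded at least once while online.
    print(iter_cached_whisper_repo_ids())

Because the metadata comes from the cached README.md, a model only shows up in the offline listing after it has been fetched at least once with network access, as noted in the commit message.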