Skip to content

Commit

Permalink
Use huggingface-cache if HF_HUB_OFFLINE is set
Browse files Browse the repository at this point in the history
You need to be online at least once to download the README.md, which is used to fill in the model metadata.
  • Loading branch information
TheMasterFX authored and fedirz committed Feb 14, 2025
1 parent 5e3780d commit 870831a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
19 changes: 15 additions & 4 deletions src/speaches/hf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def list_whisper_models() -> Generator[Model, None, None]:
yield transformed_model


def list_local_whisper_models() -> Generator[
tuple[huggingface_hub.CachedRepoInfo, huggingface_hub.ModelCardData], None, None
]:
def list_local_whisper_models() -> Generator[Model, None, None]:
hf_cache = huggingface_hub.scan_cache_dir()
hf_models = [repo for repo in list(hf_cache.repos) if repo.repo_type == "model"]
for model in hf_models:
Expand All @@ -76,7 +74,20 @@ def list_local_whisper_models() -> Generator[
and model_card_data.tags is not None
and TASK_NAME in model_card_data.tags
):
yield model, model_card_data
if model_card_data.language is None:
language = []
elif isinstance(model_card_data.language, str):
language = [model_card_data.language]
else:
language = model_card_data.language
transformed_model = Model(
id=model.repo_id,
created=int(model.last_modified),
object_="model",
owned_by=model.repo_id.split("/")[0],
language=language,
)
yield transformed_model


def model_id_from_path(repo_path: Path) -> str:
Expand Down
8 changes: 6 additions & 2 deletions src/speaches/routers/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Annotated

from fastapi import (
Expand All @@ -13,7 +14,7 @@
ListModelsResponse,
Model,
)
from speaches.hf_utils import list_whisper_models
from speaches.hf_utils import list_local_whisper_models, list_whisper_models

if TYPE_CHECKING:
from huggingface_hub.hf_api import ModelInfo
Expand All @@ -23,7 +24,10 @@

@router.get("/v1/models")
def get_models() -> ListModelsResponse:
    """Return the list of available Whisper models.

    When the ``HF_HUB_OFFLINE`` environment variable is set to a truthy
    value, only models already present in the local Hugging Face cache are
    listed (no network access). Otherwise the Hugging Face Hub is queried.
    """
    # Mirror huggingface_hub's own env-var parsing: only "1", "ON", "YES",
    # "TRUE" (case-insensitive) enable offline mode. A bare `is not None`
    # check would treat HF_HUB_OFFLINE=0 — the conventional way to
    # explicitly DISABLE offline mode — as if it were enabled.
    offline = os.getenv("HF_HUB_OFFLINE", "").upper() in {"1", "ON", "YES", "TRUE"}
    if offline:
        whisper_models = list(list_local_whisper_models())
    else:
        whisper_models = list(list_whisper_models())
    return ListModelsResponse(data=whisper_models)


Expand Down

0 comments on commit 870831a

Please sign in to comment.