Skip to content

Commit

Permalink
make embedding model dim optional and remove pull_ollama_model
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyeh committed Jan 15, 2025
1 parent 4ba5eac commit 0a804a8
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 37 deletions.
2 changes: 1 addition & 1 deletion wren-ai-service/src/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def embedder_processor(entry: dict) -> dict:
returned[identifier] = {
"provider": entry["provider"],
"model": model["model"],
"dimension": model["dimension"],
"dimension": model.get("dimension"),
**others,
}

Expand Down
5 changes: 2 additions & 3 deletions wren-ai-service/src/providers/embedder/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,8 @@ def __init__(
dimension: int = (
int(os.getenv("EMBEDDING_MODEL_DIMENSION"))
if os.getenv("EMBEDDING_MODEL_DIMENSION")
else 0
)
or EMBEDDING_MODEL_DIMENSION,
else None,
),
timeout: Optional[float] = (
float(os.getenv("EMBEDDER_TIMEOUT"))
if os.getenv("EMBEDDER_TIMEOUT")
Expand Down
11 changes: 4 additions & 7 deletions wren-ai-service/src/providers/embedder/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tqdm import tqdm

from src.core.provider import EmbedderProvider
from src.providers.loader import provider, pull_ollama_model
from src.providers.loader import provider
from src.utils import remove_trailing_slash

logger = logging.getLogger("wren-ai-service")
Expand Down Expand Up @@ -159,12 +159,11 @@ def __init__(
self,
url: str = os.getenv("EMBEDDER_OLLAMA_URL") or EMBEDDER_OLLAMA_URL,
model: str = os.getenv("EMBEDDING_MODEL") or EMBEDDING_MODEL,
dimension: int = (
dimension: Optional[int] = (
int(os.getenv("EMBEDDING_MODEL_DIMENSION"))
if os.getenv("EMBEDDING_MODEL_DIMENSION")
else 0
)
or EMBEDDING_MODEL_DIMENSION,
else None
),
timeout: Optional[int] = (
int(os.getenv("EMBEDDER_TIMEOUT")) if os.getenv("EMBEDDER_TIMEOUT") else 120
),
Expand All @@ -175,8 +174,6 @@ def __init__(
self._embedding_model_dim = dimension
self._timeout = timeout

pull_ollama_model(self._url, self._embedding_model)

logger.info(f"Using Ollama Embedding Model: {self._embedding_model}")
logger.info(f"Using Ollama URL: {self._url}")

Expand Down
7 changes: 3 additions & 4 deletions wren-ai-service/src/providers/embedder/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,11 @@ def __init__(
api_base: str = os.getenv("EMBEDDER_OPENAI_API_BASE")
or EMBEDDER_OPENAI_API_BASE,
model: str = os.getenv("EMBEDDING_MODEL") or EMBEDDING_MODEL,
dimension: int = (
dimension: Optional[int] = (
int(os.getenv("EMBEDDING_MODEL_DIMENSION"))
if os.getenv("EMBEDDING_MODEL_DIMENSION")
else 0
)
or EMBEDDING_MODEL_DIMENSION,
else None
),
timeout: Optional[float] = (
float(os.getenv("EMBEDDER_TIMEOUT"))
if os.getenv("EMBEDDER_TIMEOUT")
Expand Down
4 changes: 1 addition & 3 deletions wren-ai-service/src/providers/llm/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from haystack_integrations.components.generators.ollama import OllamaGenerator

from src.core.provider import LLMProvider
from src.providers.loader import provider, pull_ollama_model
from src.providers.loader import provider
from src.utils import remove_trailing_slash

logger = logging.getLogger("wren-ai-service")
Expand Down Expand Up @@ -148,8 +148,6 @@ def __init__(
self._model_kwargs = kwargs
self._timeout = timeout

pull_ollama_model(self._url, self._model)

logger.info(f"Using Ollama LLM: {self._model}")
logger.info(f"Using Ollama URL: {self._url}")
logger.info(f"Using Ollama model kwargs: {self._model_kwargs}")
Expand Down
19 changes: 0 additions & 19 deletions wren-ai-service/src/providers/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import logging
import pkgutil

from ollama import Client

logger = logging.getLogger("wren-ai-service")


Expand Down Expand Up @@ -99,20 +97,3 @@ def get_default_embedding_model_dim(embedder_provider: str):
return importlib.import_module(
f"src.providers.embedder.{file_name}"
).EMBEDDING_MODEL_DIMENSION


# TODO: remove this function after litellm provider is stable; users should pull the model themselves. Make the solution simpler.
def pull_ollama_model(url: str, model_name: str):
client = Client(host=url)
models = [model["name"] for model in client.list()["models"]]
if model_name not in models:
logger.info(f"Pulling Ollama model {model_name}")
percentage = 0
for progress in client.pull(model_name, stream=True):
if "completed" in progress and "total" in progress:
new_percentage = int(progress["completed"] / progress["total"] * 100)
if new_percentage > percentage:
percentage = new_percentage
logger.info(f"Pulling Ollama model {model_name}: {percentage}%")
else:
logger.info(f"Ollama model {model_name} already exists")

0 comments on commit 0a804a8

Please sign in to comment.