Skip to content

Commit

Permalink
Add local_files_only as env
Browse files Browse the repository at this point in the history
  • Loading branch information
Merk0ff committed Oct 30, 2024
1 parent 06d8448 commit 8921c76
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions vectorizer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import asyncio
import math
from concurrent.futures import ThreadPoolExecutor
Expand Down Expand Up @@ -83,9 +84,17 @@ class SentenceTransformerVectorizer:
cuda_core: str

def __init__(self, model_path: str, model_name: str, cuda_core: str):
    """Load a SentenceTransformer model and switch it to inference mode.

    Args:
        model_path: Passed to ``SentenceTransformer`` as ``cache_folder``.
        model_name: Passed to ``SentenceTransformer`` as the model
            identifier (first positional argument).
        cuda_core: Stored on the instance as ``self.cuda_core``.

    Environment:
        LOCAL_FILES_ONLY: Optional flag, read once at construction time.
            It is treated as enabled when its value, compared
            case-insensitively, is ``"true"``, ``"1"`` or ``"t"``; any
            other value, or an unset variable, leaves it disabled. The
            resulting bool is forwarded as ``local_files_only``.
    """
    # Parse the opt-in flag; the default "False" is not in the accepted
    # set, so an unset variable yields False.
    local_files_only = os.getenv("LOCAL_FILES_ONLY", "False").lower() in (
        "true",
        "1",
        "t",
    )
    # Assigned before get_device() is called below.
    # NOTE(review): presumably get_device() reads self.cuda_core -- confirm.
    self.cuda_core = cuda_core
    # Single, keyword-per-line call: the pre-change argument line
    # ("model_name, cache_folder=model_path, device=...") must not be kept
    # alongside these arguments, or the call is a syntax error.
    self.model = SentenceTransformer(
        model_name,
        cache_folder=model_path,
        device=self.get_device(),
        local_files_only=local_files_only,
    )
    self.model.eval()  # make sure we're in inference mode, not training

Expand Down Expand Up @@ -245,7 +254,6 @@ def vectorize(self, text: str, config: VectorInputConfig):


class HFModel:

def __init__(self, cuda_support: bool, cuda_core: str):
super().__init__()
self.model = None
Expand Down Expand Up @@ -317,7 +325,6 @@ def pool_sum(self, embeddings, attention_mask):


class DPRModel(HFModel):

def __init__(self, architecture: str, cuda_support: bool, cuda_core: str):
super().__init__(cuda_support, cuda_core)
self.model = None
Expand All @@ -343,7 +350,6 @@ def pool_embedding(self, batch_results, tokens, config: VectorInputConfig):


class T5Model(HFModel):

def __init__(self, cuda_support: bool, cuda_core: str):
super().__init__(cuda_support, cuda_core)
self.model = None
Expand Down Expand Up @@ -384,7 +390,6 @@ def get_batch_results(self, tokens, text):


class ModelFactory:

@staticmethod
def model(model_type, architecture, cuda_support: bool, cuda_core: str):
if model_type == "t5":
Expand Down

0 comments on commit 8921c76

Please sign in to comment.