Skip to content

Commit

Permalink
Merge pull request #73 from weaviate/fix-sentence-transformers-models-vector-dimensions
Browse files Browse the repository at this point in the history

Introduce experimental USE_SENTENCE_TRANSFORMERS_VECTORIZER environment setting
  • Loading branch information
antas-marcin authored Jan 29, 2024
2 parents 270d004 + d177d3a commit 55f01cd
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 9 deletions.
14 changes: 8 additions & 6 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@ def startup_event():
direct_tokenize = True

model_dir = "./models/model"
def get_model_directory() -> str:
def get_model_directory() -> (str, bool):
if os.path.exists(f"{model_dir}/model_name"):
with open(f"{model_dir}/model_name", "r") as f:
model_name = f.read()
return f"{model_dir}/{model_name}"
return model_dir
return f"{model_dir}/{model_name}", True
# Default model directory is ./models/model
return model_dir, False

def get_onnx_runtime() -> bool:
if os.path.exists(f"{model_dir}/onnx_runtime"):
Expand All @@ -66,13 +67,14 @@ def log_info_about_onnx(onnx_runtime: bool):
onnx_quantization_info = f.read()
logger.info(f"Running ONNX vectorizer with quantized model for {onnx_quantization_info}")

model_dir, use_sentence_transformer_vectorizer = get_model_directory()
onnx_runtime = get_onnx_runtime()
log_info_about_onnx(onnx_runtime)

meta_config = Meta(get_model_directory())
vec = Vectorizer(get_model_directory(), cuda_support, cuda_core, cuda_per_process_memory_fraction,
meta_config = Meta(model_dir)
vec = Vectorizer(model_dir, cuda_support, cuda_core, cuda_per_process_memory_fraction,
meta_config.get_model_type(), meta_config.get_architecture(),
direct_tokenize, onnx_runtime)
direct_tokenize, onnx_runtime, use_sentence_transformer_vectorizer)


@app.get("/.well-known/live", response_class=Response)
Expand Down
6 changes: 5 additions & 1 deletion download.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
if not onnx_cpu_arch:
onnx_cpu_arch = "arm64"

use_sentence_transformers_vectorizer = os.getenv('USE_SENTENCE_TRANSFORMERS_VECTORIZER')
if not use_sentence_transformers_vectorizer:
use_sentence_transformers_vectorizer = "false"

print(f"Downloading MODEL_NAME={model_name} with FORCE_AUTOMODEL={force_automodel} ONNX_RUNTIME={onnx_runtime} ONNX_CPU={onnx_cpu_arch}")

def download_onnx_model(model_name: str, model_dir: str):
Expand Down Expand Up @@ -82,7 +86,7 @@ def download_model(model_name: str, model_dir: str):
config = AutoConfig.from_pretrained(model_name)
model_type = config.to_dict()['model_type']

if model_type is not None and model_type == "t5":
if (model_type is not None and model_type == "t5") or use_sentence_transformers_vectorizer.lower() == "true":
SentenceTransformer(model_name, cache_folder=model_dir)
with open(f"{model_dir}/model_name", "w") as f:
f.write(model_name.replace("/", "_"))
Expand Down
5 changes: 3 additions & 2 deletions vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@ class VectorInput(BaseModel):
class Vectorizer:
executor: ThreadPoolExecutor

def __init__(self, model_path: str, cuda_support: bool, cuda_core: str, cuda_per_process_memory_fraction: float, model_type: str, architecture: str, direct_tokenize: bool, onnx_runtime: bool):
def __init__(self, model_path: str, cuda_support: bool, cuda_core: str, cuda_per_process_memory_fraction: float,
model_type: str, architecture: str, direct_tokenize: bool, onnx_runtime: bool, use_sentence_transformer_vectorizer: bool):
self.executor = ThreadPoolExecutor()
if onnx_runtime:
self.vectorizer = ONNXVectorizer(model_path)
else:
if model_type == 't5':
if model_type == 't5' or use_sentence_transformer_vectorizer:
self.vectorizer = SentenceTransformerVectorizer(model_path, cuda_core)
else:
self.vectorizer = HuggingFaceVectorizer(model_path, cuda_support, cuda_core, cuda_per_process_memory_fraction, model_type, architecture, direct_tokenize)
Expand Down

0 comments on commit 55f01cd

Please sign in to comment.