From 2436eeb3787208de9f9c0b30230127a60b414521 Mon Sep 17 00:00:00 2001 From: Marcin Antas Date: Sun, 29 Oct 2023 11:01:39 +0100 Subject: [PATCH 1/2] Add support for ONNX AI models --- .dockerignore | 8 +++ .github/workflows/main.yaml | 40 +++++++++++- Dockerfile | 5 +- app.py | 19 ++++-- cicd/build.sh | 6 +- cicd/docker_push.sh | 21 +++--- cicd/markdown_table_from_api.py | 18 ------ cicd/travis_yml_to_markdown_table.py | 24 ------- custom.Dockerfile | 2 +- download.py | 96 +++++++++++++++++++++------- meta.py | 2 +- requirements-test.txt | 3 + requirements.txt | 5 +- vectorizer.py | 42 ++++++++++-- 14 files changed, 200 insertions(+), 91 deletions(-) delete mode 100755 cicd/markdown_table_from_api.py delete mode 100755 cicd/travis_yml_to_markdown_table.py diff --git a/.dockerignore b/.dockerignore index 4b32cea..66ccb58 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,2 +1,10 @@ +__pycache__ +.github +.venv +.vscode +cicd models nltk_data +smoke_test.py +test_app.py +requirements-test.txt diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index cceb21a..3b43c2d 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -20,54 +20,90 @@ jobs: include: - model_name: distilbert-base-uncased model_tag_name: distilbert-base-uncased + onnx_runtime: false - model_name: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 model_tag_name: sentence-transformers-paraphrase-multilingual-MiniLM-L12-v2 + onnx_runtime: false - model_name: sentence-transformers/multi-qa-MiniLM-L6-cos-v1 model_tag_name: sentence-transformers-multi-qa-MiniLM-L6-cos-v1 + onnx_runtime: false - model_name: sentence-transformers/multi-qa-mpnet-base-cos-v1 model_tag_name: sentence-transformers-multi-qa-mpnet-base-cos-v1 + onnx_runtime: false - model_name: sentence-transformers/all-mpnet-base-v2 model_tag_name: sentence-transformers-all-mpnet-base-v2 + onnx_runtime: false - model_name: sentence-transformers/all-MiniLM-L12-v2 model_tag_name: sentence-transformers-all-MiniLM-L12-v2 + onnx_runtime: false - model_name: sentence-transformers/paraphrase-multilingual-mpnet-base-v2 model_tag_name: sentence-transformers-paraphrase-multilingual-mpnet-base-v2 + onnx_runtime: false - model_name: sentence-transformers/all-MiniLM-L6-v2 model_tag_name: sentence-transformers-all-MiniLM-L6-v2 + onnx_runtime: false - model_name: sentence-transformers/multi-qa-distilbert-cos-v1 model_tag_name: sentence-transformers-multi-qa-distilbert-cos-v1 + onnx_runtime: false - model_name: sentence-transformers/gtr-t5-base model_tag_name: sentence-transformers-gtr-t5-base + onnx_runtime: false - model_name: sentence-transformers/gtr-t5-large model_tag_name: sentence-transformers-gtr-t5-large + onnx_runtime: false - model_name: sentence-transformers/sentence-t5-base model_tag_name: sentence-transformers-sentence-t5-base + onnx_runtime: false - model_name: vblagoje/dpr-ctx_encoder-single-lfqa-wiki model_tag_name: vblagoje-dpr-ctx_encoder-single-lfqa-wiki + onnx_runtime: false - model_name: vblagoje/dpr-question_encoder-single-lfqa-wiki model_tag_name: vblagoje-dpr-question_encoder-single-lfqa-wiki + onnx_runtime: false - model_name: facebook/dpr-ctx_encoder-single-nq-base model_tag_name: facebook-dpr-ctx_encoder-single-nq-base + onnx_runtime: false - model_name: facebook/dpr-question_encoder-single-nq-base model_tag_name: facebook-dpr-question_encoder-single-nq-base + onnx_runtime: false - model_name: google/flan-t5-base model_tag_name: google-flan-t5-base + onnx_runtime: false - model_name: google/flan-t5-large model_tag_name: google-flan-t5-large + onnx_runtime: false - model_name: biu-nlp/abstract-sim-sentence model_tag_name: biu-nlp-abstract-sim-sentence + onnx_runtime: false - model_name: biu-nlp/abstract-sim-query model_tag_name: biu-nlp-abstract-sim-query + onnx_runtime: false + - model_name: BAAI/bge-small-en + model_tag_name: baai-bge-small-en + onnx_runtime: true + - model_name: BAAI/bge-small-en-v1.5 + model_tag_name: baai-bge-small-en-v1.5 + onnx_runtime: true + - model_name: BAAI/bge-base-en + model_tag_name: baai-bge-base-en + onnx_runtime: true + - model_name: BAAI/bge-base-en-v1.5 + model_tag_name: baai-bge-base-en-v1.5 + onnx_runtime: true + - model_name: sentence-transformers/all-MiniLM-L6-v2 + model_tag_name: sentence-transformers-all-MiniLM-L6-v2 + onnx_runtime: true env: LOCAL_REPO: transformers-inference REMOTE_REPO: semitechnologies/transformers-inference MODEL_NAME: ${{matrix.model_name}} MODEL_TAG_NAME: ${{matrix.model_tag_name}} + ONNX_RUNTIME: ${{matrix.onnx_runtime}} steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: "3.10" + python-version: "3.11" cache: 'pip' # caching pip dependencies - name: Login to Docker Hub if: ${{ !github.event.pull_request.head.repo.fork }} # no PRs from fork @@ -96,7 +132,7 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: "3.10" + python-version: "3.11" - name: Login to Docker Hub if: ${{ !github.event.pull_request.head.repo.fork }} # no PRs from fork uses: docker/login-action@v2 diff --git a/Dockerfile b/Dockerfile index b6e8c57..1ec26c2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.10-slim +FROM python:3.11-slim WORKDIR /app @@ -8,7 +8,10 @@ RUN pip install --upgrade pip setuptools COPY requirements.txt . RUN pip3 install -r requirements.txt +ARG TARGETARCH ARG MODEL_NAME +ARG ONNX_RUNTIME +ENV ONNX_CPU=${TARGETARCH} COPY download.py . RUN ./download.py diff --git a/app.py b/app.py index 6650949..1fb3ab6 100644 --- a/app.py +++ b/app.py @@ -43,16 +43,25 @@ def startup_event(): if transformers_direct_tokenize is not None and transformers_direct_tokenize == "true" or transformers_direct_tokenize == "1": direct_tokenize = True + model_dir = "./models/model" def get_model_directory() -> str: - if os.path.exists("./models/model/model_name"): - with open("./models/model/model_name", "r") as f: + if os.path.exists(f"{model_dir}/model_name"): + with open(f"{model_dir}/model_name", "r") as f: model_name = f.read() - return f"./models/model/{model_name}" - return "./models/model" + return f"{model_dir}/{model_name}" + return model_dir + + def get_onnx_runtime() -> bool: + if os.path.exists(f"{model_dir}/onnx_runtime"): + with open(f"{model_dir}/onnx_runtime", "r") as f: + onnx_runtime = f.read() + return onnx_runtime == "true" + return False meta_config = Meta(get_model_directory()) vec = Vectorizer(get_model_directory(), cuda_support, cuda_core, cuda_per_process_memory_fraction, - meta_config.getModelType(), meta_config.get_architecture(), direct_tokenize) + meta_config.get_model_type(), meta_config.get_architecture(), + direct_tokenize, get_onnx_runtime()) @app.get("/.well-known/live", response_class=Response) diff --git a/cicd/build.sh b/cicd/build.sh index 713b698..c705643 100755 --- a/cicd/build.sh +++ b/cicd/build.sh @@ -4,5 +4,9 @@ set -eou pipefail local_repo=${LOCAL_REPO?Variable LOCAL_REPO is required} model_name=${MODEL_NAME?Variable MODEL_NAME is required} +onnx_runtime=${ONNX_RUNTIME?Variable ONNX_RUNTIME is required} -docker build --build-arg "MODEL_NAME=$model_name" -t "$local_repo" . +docker build \ + --build-arg "MODEL_NAME=$model_name" \ + --build-arg "ONNX_RUNTIME=$onnx_runtime" \ + -t "$local_repo" . diff --git a/cicd/docker_push.sh b/cicd/docker_push.sh index 1e28017..48535c6 100755 --- a/cicd/docker_push.sh +++ b/cicd/docker_push.sh @@ -2,16 +2,11 @@ set -eou pipefail -# Docker push rules -# If on tag (e.g. 1.0.0) -# - any commit is pushed as :- -# - any commit is pushed as :-latest -# - any commit is pushed as : -git_hash= remote_repo=${REMOTE_REPO?Variable REMOTE_REPO is required} model_name=${MODEL_NAME?Variable MODEL_NAME is required} docker_username=${DOCKER_USERNAME?Variable DOCKER_USERNAME is required} docker_password=${DOCKER_PASSWORD?Variable DOCKER_PASSWORD is required} +onnx_runtime=${ONNX_RUNTIME?Variable ONNX_RUNTIME is required} original_model_name=$model_name git_tag=$GITHUB_REF_NAME @@ -20,6 +15,7 @@ function main() { echo "git ref type is $GITHUB_REF_TYPE" echo "git ref name is $GITHUB_REF_NAME" echo "git tag is $git_tag" + echo "onnx_runtime is $onnx_runtime" push_tag } @@ -31,8 +27,6 @@ function init() { model_name="$MODEL_TAG_NAME" fi - git_hash="$(git rev-parse HEAD | head -c 7)" - docker run --rm --privileged multiarch/qemu-user-static --reset -p yes docker buildx create --use echo "$docker_password" | docker login -u "$docker_username" --password-stdin @@ -40,13 +34,18 @@ function init() { function push_tag() { if [ ! -z "$git_tag" ] && [ "$GITHUB_REF_TYPE" == "tag" ]; then - tag_git="$remote_repo:$model_name-$git_tag" - tag_latest="$remote_repo:$model_name-latest" - tag="$remote_repo:$model_name" + model_name_part=$model_name + if [ "$onnx_runtime" == "true" ]; then + model_name_part="$model_name-onnx" + fi + tag_git="$remote_repo:$model_name_part-$git_tag" + tag_latest="$remote_repo:$model_name_part-latest" + tag="$remote_repo:$model_name_part" echo "Tag & Push $tag, $tag_latest, $tag_git" docker buildx build --platform=linux/arm64,linux/amd64 \ --build-arg "MODEL_NAME=$original_model_name" \ + --build-arg "ONNX_RUNTIME=$onnx_runtime" \ --push \ --tag "$tag_git" \ --tag "$tag_latest" \ diff --git a/cicd/markdown_table_from_api.py b/cicd/markdown_table_from_api.py deleted file mode 100755 index 21f779c..0000000 --- a/cicd/markdown_table_from_api.py +++ /dev/null @@ -1,18 +0,0 @@ -#!/usr/bin/env python3 - -import requests - - -print("|Model Name|Description|Image Name|") -print("|---|---|---|") - -res = requests.get("https://configuration.semi.technology/v2/parameters/transformers_model?media_type=text&weaviate_version=v1.4.1&text_module=text2vec-transformers") -asJSON = res.json() - -for opt in asJSON["options"]: - name=opt["displayName"] - description=opt["description"].replace('\n', '') - image='semitechnologies/transformers-inference:' + opt["name"] - if opt["name"] == "_custom": - continue - print(f"|{name}|{description}|{image}|") diff --git a/cicd/travis_yml_to_markdown_table.py b/cicd/travis_yml_to_markdown_table.py deleted file mode 100755 index b4c2a1b..0000000 --- a/cicd/travis_yml_to_markdown_table.py +++ /dev/null @@ -1,24 +0,0 @@ -#!/usr/bin/env python3 - -import yaml - -print("|Model Name|Image Name|") -print("|---|---|") - -with open(".travis.yml", 'r') as stream: - try: - travis = yaml.safe_load(stream) - for model in travis['jobs']['include']: - if model['stage'] != "buildanddeploy": - continue - - model_name = model['env']['MODEL_NAME'] - tag = model_name - if 'MODEL_TAG_NAME' in model['env']: - tag = model['env']['MODEL_TAG_NAME'] - - image_name = 'semitechnologies/transformers-inference:' + tag - link = 'https://huggingface.co/' + model_name - print("|`{}` ([Info]({}))|`{}`|".format(model_name, link, image_name)) - except yaml.YAMLError as exc: - print(exc) diff --git a/custom.Dockerfile b/custom.Dockerfile index a020a4d..dea87a0 100644 --- a/custom.Dockerfile +++ b/custom.Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.10-slim +FROM python:3.11-slim WORKDIR /app diff --git a/download.py b/download.py index 4d53d50..791fb09 100755 --- a/download.py +++ b/download.py @@ -9,6 +9,10 @@ AutoConfig, ) from sentence_transformers import SentenceTransformer +from optimum.onnxruntime import ORTModelForFeatureExtraction +from optimum.onnxruntime.configuration import AutoQuantizationConfig +from optimum.onnxruntime import ORTQuantizer +from pathlib import Path model_dir = './models/model' @@ -22,30 +26,78 @@ if force_automodel: print(f"Using AutoModel for {model_name} to instantiate model") -print(f"Downloading model {model_name} from huggingface model hub") -config = AutoConfig.from_pretrained(model_name) -model_type = config.to_dict()['model_type'] +onnx_runtime = os.getenv('ONNX_RUNTIME') +if not onnx_runtime: + onnx_runtime = "false" -if model_type is not None and model_type == "t5": - SentenceTransformer(model_name, cache_folder=model_dir) - with open(f"{model_dir}/model_name", "w") as f: - f.write(model_name.replace("/", "_")) -else: - if config.architectures and not force_automodel: - print(f"Using class {config.architectures[0]} to load model weights") - mod = __import__('transformers', fromlist=[config.architectures[0]]) - try: - klass_architecture = getattr(mod, config.architectures[0]) - model = klass_architecture.from_pretrained(model_name) - except AttributeError: - print(f"{config.architectures[0]} not found in transformers, fallback to AutoModel") - model = AutoModel.from_pretrained(model_name) - else: - model = AutoModel.from_pretrained(model_name) +onnx_cpu_arch = os.getenv('ONNX_CPU') +if not onnx_cpu_arch: + onnx_cpu_arch = "arm64" + +print(f"Downloading MODEL_NAME={model_name} with FORCE_AUTOMODEL={force_automodel} ONNX_RUNTIME={onnx_runtime} ONNX_CPU={onnx_cpu_arch}") +def download_onnx_model(model_name: str, model_dir: str): + # Download model and tokenizer + onnx_path = Path(model_dir) + ort_model = ORTModelForFeatureExtraction.from_pretrained(model_name, from_transformers=True) + # Save model + ort_model.save_pretrained(onnx_path) + + def quantization_config(onnx_cpu_arch: str): + if onnx_cpu_arch.lower() == "avx512_vnni": + print("Quantize Model for x86_64 (amd64) (avx512_vnni)") + return AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False) + if onnx_cpu_arch.lower() == "arm64": + print(f"Quantize Model for ARM64") + return AutoQuantizationConfig.arm64(is_static=False, per_channel=False) + # default is AMD64 + print(f"Quantize Model for x86_64 (amd64) (AVX2)") + return AutoQuantizationConfig.avx2(is_static=False, per_channel=False) + + # Quantize the model / convert to ONNX + qconfig = quantization_config(onnx_cpu_arch) + quantizer = ORTQuantizer.from_pretrained(ort_model) + # Apply dynamic quantization on the model + quantizer.quantize(save_dir=onnx_path, quantization_config=qconfig) + # Remove model.onnx file, leave only model_quantized.onnx + if os.path.isfile(f"{model_dir}/model.onnx"): + os.remove(f"{model_dir}/model.onnx") + # Save information about ONNX runtime + with open(f"{model_dir}/onnx_runtime", "w") as f: + f.write(onnx_runtime) tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.save_pretrained(onnx_path) - model.save_pretrained(model_dir) - tokenizer.save_pretrained(model_dir) +def download_model(model_name: str, model_dir: str): + print(f"Downloading model {model_name} from huggingface model hub") + config = AutoConfig.from_pretrained(model_name) + model_type = config.to_dict()['model_type'] -nltk.download('punkt', download_dir='./nltk_data') + if model_type is not None and model_type == "t5": + SentenceTransformer(model_name, cache_folder=model_dir) + with open(f"{model_dir}/model_name", "w") as f: + f.write(model_name.replace("/", "_")) + else: + if config.architectures and not force_automodel: + print(f"Using class {config.architectures[0]} to load model weights") + mod = __import__('transformers', fromlist=[config.architectures[0]]) + try: + klass_architecture = getattr(mod, config.architectures[0]) + model = klass_architecture.from_pretrained(model_name) + except AttributeError: + print(f"{config.architectures[0]} not found in transformers, fallback to AutoModel") + model = AutoModel.from_pretrained(model_name) + else: + model = AutoModel.from_pretrained(model_name) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + + model.save_pretrained(model_dir) + tokenizer.save_pretrained(model_dir) + + nltk.download('punkt', download_dir='./nltk_data') + +if onnx_runtime == "true": + download_onnx_model(model_name, model_dir) +else: + download_model(model_name, model_dir) diff --git a/meta.py b/meta.py index 3887d67..1574b03 100644 --- a/meta.py +++ b/meta.py @@ -12,7 +12,7 @@ def get(self): 'model': self.config.to_dict() } - def getModelType(self): + def get_model_type(self): return self.config.to_dict()['model_type'] def get_architecture(self): diff --git a/requirements-test.txt b/requirements-test.txt index 6dc578d..87b827e 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -6,4 +6,7 @@ nltk==3.8.1 torch==2.0.1 sentencepiece==0.1.97 sentence-transformers==2.2.2 +optimum==1.13.2 +onnxruntime==1.16.1 +onnx==1.14.1 pytest diff --git a/requirements.txt b/requirements.txt index 5091cd0..012885e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,10 @@ -transformers==4.34.0 +transformers==4.34.1 fastapi==0.103.2 uvicorn==0.23.2 nltk==3.8.1 torch==2.0.1 sentencepiece==0.1.99 sentence-transformers==2.2.2 +optimum==1.13.2 +onnxruntime==1.16.1 +onnx==1.14.1 diff --git a/vectorizer.py b/vectorizer.py index 3051609..4385dd8 100644 --- a/vectorizer.py +++ b/vectorizer.py @@ -3,6 +3,8 @@ import math from typing import Optional import torch +import torch.nn.functional as F +from pathlib import Path import nltk from nltk.tokenize import sent_tokenize from pydantic import BaseModel @@ -15,6 +17,7 @@ DPRQuestionEncoder, ) from sentence_transformers import SentenceTransformer +from optimum.onnxruntime import ORTModelForFeatureExtraction # limit transformer batch size to limit parallel inference, otherwise we run @@ -34,12 +37,15 @@ class VectorInput(BaseModel): class Vectorizer: executor: ThreadPoolExecutor - def __init__(self, model_path: str, cuda_support: bool, cuda_core: str, cuda_per_process_memory_fraction: float, model_type: str, architecture: str, direct_tokenize: bool): + def __init__(self, model_path: str, cuda_support: bool, cuda_core: str, cuda_per_process_memory_fraction: float, model_type: str, architecture: str, direct_tokenize: bool, onnx_runtime: bool): self.executor = ThreadPoolExecutor() - if model_type == 't5': - self.vectorizer = SentenceTransformerVectorizer(model_path, cuda_core) + if onnx_runtime: + self.vectorizer = ONNXVectorizer(model_path) else: - self.vectorizer = HuggingFaceVectorizer(model_path, cuda_support, cuda_core, cuda_per_process_memory_fraction, model_type, architecture, direct_tokenize) + if model_type == 't5': + self.vectorizer = SentenceTransformerVectorizer(model_path, cuda_core) + else: + self.vectorizer = HuggingFaceVectorizer(model_path, cuda_support, cuda_core, cuda_per_process_memory_fraction, model_type, architecture, direct_tokenize) async def vectorize(self, text: str, config: VectorInputConfig): return await asyncio.wrap_future(self.executor.submit(self.vectorizer.vectorize, text, config)) @@ -64,6 +70,34 @@ def vectorize(self, text: str, config: VectorInputConfig): return embedding[0] +class ONNXVectorizer: + model: ORTModelForFeatureExtraction + tokenizer: AutoTokenizer + + def __init__(self, model_path) -> None: + onnx_path = Path(model_path) + self.model = ORTModelForFeatureExtraction.from_pretrained(onnx_path, file_name="model_quantized.onnx") + self.tokenizer = AutoTokenizer.from_pretrained(onnx_path) + + def mean_pooling(self, model_output, attention_mask): + token_embeddings = model_output[0] #First element of model_output contains all token embeddings + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) + + def vectorize(self, text: str, config: VectorInputConfig): + encoded_input = self.tokenizer([text], padding=True, truncation=True, return_tensors='pt') + # Compute token embeddings + with torch.no_grad(): + model_output = self.model(**encoded_input) + + # Perform pooling + sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask']) + + # Normalize embeddings + sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1) + return sentence_embeddings[0] + + class HuggingFaceVectorizer: model: AutoModel tokenizer: AutoTokenizer From f5bbf90e4b70b745514c6c1cadc2d30f797b5370 Mon Sep 17 00:00:00 2001 From: Marcin Antas Date: Fri, 1 Dec 2023 08:18:01 +0100 Subject: [PATCH 2/2] Add quantization info --- README.md | 6 ++++++ app.py | 13 ++++++++++++- download.py | 15 ++++++++++++--- vectorizer.py | 4 ++-- 4 files changed, 32 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 633a7bc..a3b24ad 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,12 @@ The pre-built models include: |Bar-Ilan University NLP Lab Models| |`biu-nlp/abstract-sim-sentence` ([Info](https://huggingface.co/biu-nlp/abstract-sim-sentence))|`semitechnologies/transformers-inference:biu-nlp-abstract-sim-sentence`| |`biu-nlp/abstract-sim-query` ([Info](https://huggingface.co/biu-nlp/abstract-sim-query))|`semitechnologies/transformers-inference:biu-nlp-abstract-sim-query`| +|ONNX Models| +|`BAAI/bge-small-en` ([Info](https://huggingface.co/BAAI/bge-small-en))|`semitechnologies/transformers-inference:baai-bge-small-en-onnx`| +|`BAAI/bge-small-en-v1.5` ([Info](https://huggingface.co/BAAI/bge-small-en-v1.5))|`semitechnologies/transformers-inference:baai-bge-small-en-v1.5-onnx`| +|`BAAI/bge-base-en` ([Info](https://huggingface.co/BAAI/bge-base-en))|`semitechnologies/transformers-inference:baai-bge-base-en-onnx`| +|`BAAI/bge-base-en-v1.5` ([Info](https://huggingface.co/BAAI/bge-base-en-v1.5))|`semitechnologies/transformers-inference:baai-bge-base-en-v1.5-onnx`| +|`sentence-transformers/all-MiniLM-L6-v2` ([Info](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2))|`semitechnologies/transformers-inference:sentence-transformers-all-MiniLM-L6-v2-onnx`| The above image names always point to the latest version of the inference diff --git a/app.py b/app.py index 1fb3ab6..b8c20c3 100644 --- a/app.py +++ b/app.py @@ -58,10 +58,21 @@ def get_onnx_runtime() -> bool: return onnx_runtime == "true" return False + def log_info_about_onnx(onnx_runtime: bool): + if onnx_runtime: + onnx_quantization_info = "missing" + if os.path.exists(f"{model_dir}/onnx_quantization_info"): + with open(f"{model_dir}/onnx_quantization_info", "r") as f: + onnx_quantization_info = f.read() + logger.info(f"Running ONNX vectorizer with quantized model for {onnx_quantization_info}") + + onnx_runtime = get_onnx_runtime() + log_info_about_onnx(onnx_runtime) + meta_config = Meta(get_model_directory()) vec = Vectorizer(get_model_directory(), cuda_support, cuda_core, cuda_per_process_memory_fraction, meta_config.get_model_type(), meta_config.get_architecture(), - direct_tokenize, get_onnx_runtime()) + direct_tokenize, onnx_runtime) @app.get("/.well-known/live", response_class=Response) diff --git a/download.py b/download.py index 791fb09..0b54bf5 100755 --- a/download.py +++ b/download.py @@ -43,15 +43,25 @@ def download_onnx_model(model_name: str, model_dir: str): # Save model ort_model.save_pretrained(onnx_path) + def save_to_file(filepath: str, content: str): + with open(filepath, "w") as f: + f.write(content) + + def save_quantization_info(arch: str): + save_to_file(f"{model_dir}/onnx_quantization_info", arch) + def quantization_config(onnx_cpu_arch: str): if onnx_cpu_arch.lower() == "avx512_vnni": print("Quantize Model for x86_64 (amd64) (avx512_vnni)") + save_quantization_info("AVX-512") return AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False) if onnx_cpu_arch.lower() == "arm64": print(f"Quantize Model for ARM64") + save_quantization_info("ARM64") return AutoQuantizationConfig.arm64(is_static=False, per_channel=False) - # default is AMD64 + # default is AMD64 (AVX2) print(f"Quantize Model for x86_64 (amd64) (AVX2)") + save_quantization_info("amd64 (AVX2)") return AutoQuantizationConfig.avx2(is_static=False, per_channel=False) # Quantize the model / convert to ONNX @@ -63,8 +73,7 @@ def quantization_config(onnx_cpu_arch: str): if os.path.isfile(f"{model_dir}/model.onnx"): os.remove(f"{model_dir}/model.onnx") # Save information about ONNX runtime - with open(f"{model_dir}/onnx_runtime", "w") as f: - f.write(onnx_runtime) + save_to_file(f"{model_dir}/onnx_runtime", onnx_runtime) tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.save_pretrained(onnx_path) diff --git a/vectorizer.py b/vectorizer.py index 4385dd8..8d13f34 100644 --- a/vectorizer.py +++ b/vectorizer.py @@ -78,12 +78,12 @@ def __init__(self, model_path) -> None: onnx_path = Path(model_path) self.model = ORTModelForFeatureExtraction.from_pretrained(onnx_path, file_name="model_quantized.onnx") self.tokenizer = AutoTokenizer.from_pretrained(onnx_path) - + def mean_pooling(self, model_output, attention_mask): token_embeddings = model_output[0] #First element of model_output contains all token embeddings input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) - + def vectorize(self, text: str, config: VectorInputConfig): encoded_input = self.tokenizer([text], padding=True, truncation=True, return_tensors='pt') # Compute token embeddings