diff --git a/download.py b/download.py index 7dc8617..52ec42f 100755 --- a/download.py +++ b/download.py @@ -15,36 +15,43 @@ from pathlib import Path -model_dir = './models/model' -nltk_dir = './nltk_data' -model_name = os.getenv('MODEL_NAME', None) -force_automodel = os.getenv('FORCE_AUTOMODEL', False) +model_dir = "./models/model" +nltk_dir = "./nltk_data" +model_name = os.getenv("MODEL_NAME", None) +force_automodel = os.getenv("FORCE_AUTOMODEL", False) if not model_name: print("Fatal: MODEL_NAME is required") - print("Please set environment variable MODEL_NAME to a HuggingFace model name, see https://huggingface.co/models") + print( + "Please set environment variable MODEL_NAME to a HuggingFace model name, see https://huggingface.co/models" + ) sys.exit(1) if force_automodel: print(f"Using AutoModel for {model_name} to instantiate model") -onnx_runtime = os.getenv('ONNX_RUNTIME') +onnx_runtime = os.getenv("ONNX_RUNTIME") if not onnx_runtime: onnx_runtime = "false" -onnx_cpu_arch = os.getenv('ONNX_CPU') +onnx_cpu_arch = os.getenv("ONNX_CPU") if not onnx_cpu_arch: onnx_cpu_arch = "arm64" -use_sentence_transformers_vectorizer = os.getenv('USE_SENTENCE_TRANSFORMERS_VECTORIZER') +use_sentence_transformers_vectorizer = os.getenv("USE_SENTENCE_TRANSFORMERS_VECTORIZER") if not use_sentence_transformers_vectorizer: use_sentence_transformers_vectorizer = "false" -print(f"Downloading MODEL_NAME={model_name} with FORCE_AUTOMODEL={force_automodel} ONNX_RUNTIME={onnx_runtime} ONNX_CPU={onnx_cpu_arch}") +print( + f"Downloading MODEL_NAME={model_name} with FORCE_AUTOMODEL={force_automodel} ONNX_RUNTIME={onnx_runtime} ONNX_CPU={onnx_cpu_arch}" +) + def download_onnx_model(model_name: str, model_dir: str): # Download model and tokenizer onnx_path = Path(model_dir) - ort_model = ORTModelForFeatureExtraction.from_pretrained(model_name, from_transformers=True) + ort_model = ORTModelForFeatureExtraction.from_pretrained( + model_name, from_transformers=True + ) # Save model ort_model.save_pretrained(onnx_path) @@ -59,7 +66,9 @@ def quantization_config(onnx_cpu_arch: str): if onnx_cpu_arch.lower() == "avx512_vnni": print("Quantize Model for x86_64 (amd64) (avx512_vnni)") save_quantization_info("AVX-512") - return AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False) + return AutoQuantizationConfig.avx512_vnni( + is_static=False, per_channel=False + ) if onnx_cpu_arch.lower() == "arm64": print(f"Quantize Model for ARM64") save_quantization_info("ARM64") @@ -82,24 +91,29 @@ def quantization_config(onnx_cpu_arch: str): tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.save_pretrained(onnx_path) + def download_model(model_name: str, model_dir: str): print(f"Downloading model {model_name} from huggingface model hub") config = AutoConfig.from_pretrained(model_name) - model_type = config.to_dict()['model_type'] + model_type = config.to_dict()["model_type"] - if (model_type is not None and model_type == "t5") or use_sentence_transformers_vectorizer.lower() == "true": + if ( + model_type is not None and model_type == "t5" + ) or use_sentence_transformers_vectorizer.lower() == "true": SentenceTransformer(model_name, cache_folder=model_dir) with open(f"{model_dir}/model_name", "w") as f: f.write(model_name) else: if config.architectures and not force_automodel: print(f"Using class {config.architectures[0]} to load model weights") - mod = __import__('transformers', fromlist=[config.architectures[0]]) + mod = __import__("transformers", fromlist=[config.architectures[0]]) try: klass_architecture = getattr(mod, config.architectures[0]) model = klass_architecture.from_pretrained(model_name) except AttributeError: - print(f"{config.architectures[0]} not found in transformers, fallback to AutoModel") + print( + f"{config.architectures[0]} not found in transformers, fallback to AutoModel" + ) model = AutoModel.from_pretrained(model_name) else: model = AutoModel.from_pretrained(model_name) @@ -109,7 +123,9 @@ def download_model(model_name: str, model_dir: str): model.save_pretrained(model_dir) tokenizer.save_pretrained(model_dir) - nltk.download('punkt', download_dir=nltk_dir) + nltk.download("punkt", download_dir=nltk_dir) + nltk.download("punkt_tab", download_dir=nltk_dir) + if onnx_runtime == "true": download_onnx_model(model_name, model_dir) diff --git a/requirements-test.txt b/requirements-test.txt index fb84e20..dfc5234 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -2,7 +2,7 @@ requests==2.32.3 transformers==4.42.4 fastapi==0.112.0 uvicorn==0.30.5 -nltk==3.8.1 +nltk==3.9.1 torch==2.4.0 sentencepiece==0.2.0 sentence-transformers==3.0.1 diff --git a/requirements.txt b/requirements.txt index 88fd174..d2c8e98 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ transformers==4.42.4 fastapi==0.112.0 uvicorn==0.30.5 -nltk==3.8.1 +nltk==3.9.1 torch==2.4.0 sentencepiece==0.2.0 sentence-transformers==3.0.1