From dc457c36d374de1ec3387be33cac75579e23261a Mon Sep 17 00:00:00 2001 From: John Flavin Date: Mon, 9 Sep 2024 13:01:30 -0500 Subject: [PATCH] Use TRUST_REMOTE_CODE env from config when downloading custom models --- download.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/download.py b/download.py index 7dc8617..f9a0fab 100755 --- a/download.py +++ b/download.py @@ -3,6 +3,7 @@ import os import sys import nltk +from config import TRUST_REMOTE_CODE from transformers import ( AutoModel, AutoTokenizer, @@ -82,9 +83,9 @@ def quantization_config(onnx_cpu_arch: str): tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.save_pretrained(onnx_path) -def download_model(model_name: str, model_dir: str): - print(f"Downloading model {model_name} from huggingface model hub") - config = AutoConfig.from_pretrained(model_name) +def download_model(model_name: str, model_dir: str, trust_remote_code: bool = False): + print(f"Downloading model {model_name} from huggingface model hub ({trust_remote_code=})") + config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code) model_type = config.to_dict()['model_type'] if (model_type is not None and model_type == "t5") or use_sentence_transformers_vectorizer.lower() == "true": @@ -100,11 +101,11 @@ def download_model(model_name: str, model_dir: str): model = klass_architecture.from_pretrained(model_name) except AttributeError: print(f"{config.architectures[0]} not found in transformers, fallback to AutoModel") - model = AutoModel.from_pretrained(model_name) + model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code) else: - model = AutoModel.from_pretrained(model_name) + model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code) - tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code) model.save_pretrained(model_dir) tokenizer.save_pretrained(model_dir) @@ -114,4 +115,4 @@ def download_model(model_name: str, model_dir: str): if onnx_runtime == "true": download_onnx_model(model_name, model_dir) else: - download_model(model_name, model_dir) + download_model(model_name, model_dir, TRUST_REMOTE_CODE)