diff --git a/Dockerfile b/Dockerfile index 1ec26c2..0fb30f5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.11-slim +FROM python:3.11-slim AS base_image WORKDIR /app @@ -8,13 +8,23 @@ RUN pip install --upgrade pip setuptools COPY requirements.txt . RUN pip3 install -r requirements.txt +FROM base_image AS download_model + +WORKDIR /app + ARG TARGETARCH ARG MODEL_NAME ARG ONNX_RUNTIME ENV ONNX_CPU=${TARGETARCH} +RUN mkdir nltk_data COPY download.py . RUN ./download.py +FROM base_image AS t2v_transformers + +WORKDIR /app +COPY --from=download_model /app/models /app/models +COPY --from=download_model /app/nltk_data /app/nltk_data COPY . . ENTRYPOINT ["/bin/sh", "-c"] diff --git a/download.py b/download.py index cd6bc5c..7dc8617 100755 --- a/download.py +++ b/download.py @@ -16,6 +16,7 @@ model_dir = './models/model' +nltk_dir = './nltk_data' model_name = os.getenv('MODEL_NAME', None) force_automodel = os.getenv('FORCE_AUTOMODEL', False) if not model_name: @@ -108,7 +109,7 @@ def download_model(model_name: str, model_dir: str): model.save_pretrained(model_dir) tokenizer.save_pretrained(model_dir) - nltk.download('punkt', download_dir='./nltk_data') + nltk.download('punkt', download_dir=nltk_dir) if onnx_runtime == "true": download_onnx_model(model_name, model_dir)