Local tokenizer and processor for more consistent CI #16

Merged: 14 commits, Jun 10, 2024
1 change: 1 addition & 0 deletions .gitattributes
@@ -0,0 +1 @@
+*.json filter=lfs diff=lfs merge=lfs -text
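This attributes rule is exactly what `git lfs track "*.json"` writes: every tracked *.json file goes through the LFS clean/smudge filters and is stored as a small pointer rather than a full blob, which is why the JSON assets below render as "Git LFS file not shown" instead of a diff.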
3 changes: 3 additions & 0 deletions ultravox/assets/hf/Meta-Llama-3-8B-Instruct/tokenizer.json
Git LFS file not shown
3 changes: 3 additions & 0 deletions ultravox/assets/hf/wav2vec2-base-960h/special_tokens_map.json
Git LFS file not shown
3 changes: 3 additions & 0 deletions ultravox/assets/hf/wav2vec2-base-960h/tokenizer_config.json
Git LFS file not shown
3 changes: 3 additions & 0 deletions ultravox/assets/hf/wav2vec2-base-960h/vocab.json
Git LFS file not shown
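The asset directories above were presumably generated once by saving the Hub copies locally and committing the resulting JSON files; a minimal sketch of such a one-time export, assuming transformers' standard save_pretrained flow and a repo-root working directory (the script itself is not part of this PR):

import transformers

# One-time export (hypothetical): fetch the tokenizer/processor configs from
# the Hub and write their small JSON files into the repo, where the new
# .gitattributes rule places them under Git LFS. No model weights are saved.
tok = transformers.AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
tok.save_pretrained("ultravox/assets/hf/Meta-Llama-3-8B-Instruct")

proc = transformers.AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
proc.save_pretrained("ultravox/assets/hf/wav2vec2-base-960h")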
21 changes: 7 additions & 14 deletions ultravox/inference/infer_test.py
@@ -1,5 +1,3 @@
-import logging
-import os
 from unittest import mock
 
 import numpy as np
@@ -12,23 +10,21 @@
 from ultravox.inference import infer
 from ultravox.model import ultravox_processing
 
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
 
+# We cache these files in our repo to make CI faster and also
+# work properly for external contributions (since Llama 3 is gated).
 @pytest.fixture(scope="module")
 def tokenizer():
-    logging.info("Loading tokenizer")
-    yield transformers.AutoTokenizer.from_pretrained(
-        "meta-llama/Meta-Llama-3-8B-Instruct"
+    return transformers.AutoTokenizer.from_pretrained(
+        "./assets/hf/Meta-Llama-3-8B-Instruct", local_files_only=True
     )
-    logging.info("Tearing down tokenizer")
 
 
 @pytest.fixture(scope="module")
 def audio_processor():
-    logging.info("Loading audio processor")
-    yield transformers.AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
-    logging.info("Tearing down audio processor")
+    return transformers.AutoProcessor.from_pretrained(
+        "./assets/hf/wav2vec2-base-960h", local_files_only=True
+    )
 
 
 class FakeInference(infer.LocalInference):
@@ -50,9 +46,6 @@ def __init__(
         self.model.device = "cpu"
         self.model.generate = mock.MagicMock(return_value=[range(25)])
 
-    def __del__(self):
-        logging.info("Tearing down inference")
-
 
 EXPECTED_TOKEN_IDS_START = [128000, 128006, 882, 128007]
 EXPECTED_TOKEN_IDS_END = [128009, 128006, 78191, 128007, 271]
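For reference, pytest injects these module-scoped fixtures into tests by parameter name. A minimal sketch of a consuming test (hypothetical, not part of the diff; only the fixture and token id 128000, Llama 3's begin-of-text token per EXPECTED_TOKEN_IDS_START above, come from this PR):

def test_tokenizer_loads_from_local_assets(tokenizer):
    # Runs fully offline: the fixture loaded its files from ./assets/hf with
    # local_files_only=True, so the gated Hub repo is never contacted.
    ids = tokenizer("hello world").input_ids
    # The Llama 3 tokenizer prepends <|begin_of_text|> (id 128000) by default.
    assert ids[0] == 128000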