Update test to give consistent results
thomasantony committed Apr 7, 2023
1 parent 18917dd commit 440798f
Showing 2 changed files with 31 additions and 13 deletions.
20 changes: 15 additions & 5 deletions tests/test_llama_context.py
@@ -1,3 +1,5 @@
import os
import sys
import array
import llamacpp
import pytest
@@ -7,13 +9,17 @@
def llama_context():
params = llamacpp.LlamaContextParams()
params.seed = 19472
return llamacpp.LlamaContext("../models/7B/ggml-model-f16.bin", params)
# Get path to current file
current_file_path = os.path.dirname(os.path.realpath(__file__))
# Get path to the model
model_path = os.path.join(current_file_path, "../models/7B/ggml-model-f16.bin")
return llamacpp.LlamaContext(model_path, params)


def test_str_to_token(llama_context):
prompt = "Hello World"
prompt_tokens = llama_context.str_to_token(prompt, True)
assert prompt_tokens == [1, 10994, 2787]
assert all(prompt_tokens == [1, 10994, 2787])


def test_token_to_str(llama_context):
@@ -39,10 +45,10 @@ def test_eval(llama_context):
top_k = 40
top_p = 0.95
temp = 0.8
repeat_last_n = 64
repeat_penalty = 1.0

# sending an empty array for the last n tokens
id = llama_context.sample_top_p_top_k(array.array('i', []), top_k, top_p, temp, repeat_last_n)
id = llama_context.sample_top_p_top_k(array.array('i', []), top_k, top_p, temp, repeat_penalty)
# add it to the context
embd.append(id)
# decrement remaining sampling budget
@@ -55,4 +61,8 @@ def test_eval(llama_context):
n_consumed += 1

output += ''.join([llama_context.token_to_str(id) for id in embd])
assert output == " Llama is the newest member of our farm family"
assert output == " Llama is the newest member of our growing family"


if __name__=='__main__':
sys.exit(pytest.main(['-s', '-v', __file__]))
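Both fixtures in this commit resolve the model path relative to the test file instead of the current working directory, so the tests behave the same no matter where pytest is invoked from. A minimal sketch of that pattern follows; the helper name is illustrative and not part of the commit:

import os

def resolve_model_path(relative: str = "../models/7B/ggml-model-f16.bin") -> str:
    # Illustrative helper (not part of the commit).
    # Directory containing this test file, independent of the working directory.
    current_file_path = os.path.dirname(os.path.realpath(__file__))
    # Join with the relative location of the converted model file.
    return os.path.join(current_file_path, relative)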
24 changes: 16 additions & 8 deletions tests/test_llama_inference.py
@@ -1,16 +1,20 @@
import os
import sys
import pytest
import llamacpp


@pytest.fixture(scope="session")
# @pytest.fixture(scope="session")
@pytest.fixture
def llama_model():
params = llamacpp.InferenceParams()
params.path_model = '../models/7B/ggml-model-f16.bin'
# Get path to current file
current_file_path = os.path.dirname(os.path.realpath(__file__))
# Get path to the model
model_path = os.path.join(current_file_path, "../models/7B/ggml-model-f16.bin")
params.path_model = model_path
params.seed = 19472
params.top_k = 40
params.top_p = 0.95
params.repeat_last_n = 64
params.n_predict = 8
params.repeat_penalty = 1.0
return llamacpp.LlamaInference(params)


@@ -35,7 +39,7 @@ def test_token_to_str(llama_model):


def test_eval(llama_model):
prompt = "Llama is"
prompt = " Llama is"
prompt_tokens = llama_model.tokenize(prompt, True)
llama_model.update_input(prompt_tokens)
llama_model.ingest_all_pending_input()
@@ -45,4 +49,8 @@ def test_eval(llama_model):
token = llama_model.sample()
output += llama_model.token_to_str(token)

assert output == " Llama is the newest member of our farm family"
assert output == " Llama is the newest member of our growing family"


if __name__=='__main__':
sys.exit(pytest.main(['-s', '-v', __file__]))
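With the __main__ guard added to both files, each test module hands itself to pytest when executed directly. The sketch below shows an equivalent invocation from the repository root; it assumes pytest and llamacpp are installed and the model exists at models/7B/ggml-model-f16.bin, as the diff expects:

import sys
import pytest

# Usage sketch (assumed layout, not part of the commit): run both test modules
# with verbose output and stdout passthrough, mirroring the
# pytest.main(['-s', '-v', __file__]) call added in this commit.
sys.exit(pytest.main([
    "-s", "-v",
    "tests/test_llama_context.py",
    "tests/test_llama_inference.py",
]))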
