Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LLM server unit tests #90

Merged
merged 17 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
name: Run Unity Tests

on:
pull_request:
types: [closed]

jobs:
test:
name: Run Tests
runs-on: ubuntu-latest
if: startsWith(github.base_ref, 'release/') && github.event.pull_request.merged == true
steps:
- name: Checkout Repo
uses: actions/checkout@v4

- name: Setup Environment
run: |
mkdir Package~
mv Editor Package~/
mv Runtime Package~/
mv Tests Package~/
mv package.json Package~/package.json

- name: Run Tests
uses: game-ci/unity-test-runner@v4
env:
UNITY_EMAIL: ${{ secrets.UNITY_EMAIL }}
UNITY_PASSWORD: ${{ secrets.UNITY_PASSWORD }}
UNITY_LICENSE: ${{ secrets.UNITY_LICENSE }}
with:
packageMode: true
projectPath: Package~/
testMode: playmode
unityVersion: 2022.3.16f1
5 changes: 3 additions & 2 deletions Runtime/LLM.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ public class LLM : LLMClient
[ModelAdvanced] public int contextSize = 512;
[ModelAdvanced] public int batchSize = 512;

[HideInInspector] public readonly (string, string)[] modelOptions = new (string, string)[]{
[HideInInspector] public readonly (string, string)[] modelOptions = new(string, string)[]
{
("Download model", null),
("Phi 2 (small, best)", "https://huggingface.co/TheBloke/phi-2-GGUF/resolve/main/phi-2.Q4_K_M.gguf?download=true"),
("Mistral 7B Instruct v0.2 (medium, best)", "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q4_K_M.gguf?download=true")
Expand Down Expand Up @@ -144,7 +145,7 @@ private string SelectApeBinary()
return apeExe;
}

bool IsPortInUse()
public bool IsPortInUse()
{
try
{
Expand Down
13 changes: 8 additions & 5 deletions Runtime/LLMClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ public class LLMClient : MonoBehaviour
[Chat] public string AIName = "Assistant";
[TextArea(5, 10), Chat] public string prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.";

private string currentPrompt;
private List<ChatMessage> chat;
protected string currentPrompt;
protected List<ChatMessage> chat;
private List<(string, string)> requestHeaders;
public bool setNKeepToPrompt = true;

Expand Down Expand Up @@ -135,6 +135,7 @@ public async Task SetPrompt(string newPrompt, bool clearChat = true)
nKeep = -1;
await InitPrompt(clearChat);
}

private async Task InitNKeep()
{
if (setNKeepToPrompt && nKeep == -1)
Expand Down Expand Up @@ -171,15 +172,16 @@ public async void SetGrammar(string path)
{
grammar = await LLMUnitySetup.AddAsset(path, LLMUnitySetup.GetAssetPath());
}

#endif

private string RoleString(string role)
public string RoleString(string role)
{
// role as a delimited string for the model
return "\n### " + role + ":";
}

private string RoleMessageString(string role, string message)
public string RoleMessageString(string role, string message)
{
// role and the role message
return RoleString(role) + " " + message;
Expand Down Expand Up @@ -358,15 +360,16 @@ public async Task<Ret> PostRequest<Res, Ret>(string json, string endpoint, Conte
// Check if progress has changed
if (currentProgress != lastProgress && callback != null)
{
if (request.result != UnityWebRequest.Result.Success) throw new System.Exception(request.error);
callback?.Invoke(ConvertContent(request.downloadHandler.text, getContent));
lastProgress = currentProgress;
}
// Wait for the next frame
await Task.Yield();
}
if (request.result != UnityWebRequest.Result.Success) throw new System.Exception(request.error);
result = ConvertContent(request.downloadHandler.text, getContent);
callback?.Invoke(result);
if (request.result != UnityWebRequest.Result.Success) throw new System.Exception(request.error);
}
return result;
}
Expand Down
8 changes: 8 additions & 0 deletions Tests.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions Tests/Runtime.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

152 changes: 152 additions & 0 deletions Tests/Runtime/TestLLM.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
using NUnit.Framework;
using LLMUnity;
using UnityEngine;
using System.Threading.Tasks;
using System.Collections.Generic;
using System;
using System.Collections;
using UnityEngine.TestTools;

namespace LLMUnityTests
{
public class LLMNoAwake : LLM
{
public new void Awake() {}
public new void OnDestroy() {}

public void CallAwake()
{
base.Awake();
}

public void CallOnDestroy()
{
base.OnDestroy();
}

public List<ChatMessage> GetChat()
{
return chat;
}

public string GetCurrentPrompt()
{
return currentPrompt;
}
}


public class TestLLM
{
GameObject gameObject;
LLMNoAwake llm;
int port = 15555;
string AIReply = ":::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::";
Exception error = null;

public TestLLM()
{
Task task = Init();
task.Wait();
}

public async Task Init()
{
gameObject = new GameObject();
gameObject.SetActive(false);

llm = gameObject.AddComponent<LLMNoAwake>();

string modelUrl = "https://huggingface.co/aladar/tiny-random-BloomForCausalLM-GGUF/resolve/main/tiny-random-BloomForCausalLM.gguf?download=true";
string modelPath = "LLMUnityTests/tiny-random-BloomForCausalLM.gguf";
string fullModelPath = LLMUnitySetup.GetAssetPath(modelPath);
_ = LLMUnitySetup.DownloadFile(modelUrl, fullModelPath, false, false, null, null, false);
await llm.SetModel(fullModelPath);
Assert.AreEqual(llm.model, modelPath);

llm.port = port;
llm.prompt = "You";
llm.temperature = 0;
llm.seed = 0;
llm.stream = false;
llm.numPredict = 128;
llm.parallelPrompts = 1;

gameObject.SetActive(true);
}

public async Task RunTests()
{
error = null;
try
{
Assert.That(!llm.IsPortInUse());
llm.CallAwake();
TestAlive();
await llm.Tokenize("I", TestTokens);
await llm.Warmup();
TestInitParameters();
TestWarmup();
await llm.Chat("hi", TestChat);
TestPostChat();
await llm.SetPrompt("You are");
TestInitParameters();
}
catch (Exception e)
{
error = e;
}
}

[UnityTest]
public IEnumerator RunTestsWait()
{
Task task = RunTests();
while (!task.IsCompleted) yield return null;
llm.CallOnDestroy();
if (error != null)
{
Debug.LogError(error.ToString());
throw (error);
}
}

public void TestAlive()
{
Assert.That(llm.serverListening);
Assert.That(llm.IsPortInUse());
}

public async void TestInitParameters()
{
Assert.That(llm.nKeep == (await llm.Tokenize(llm.prompt)).Count);
Assert.That(llm.stop.Count > 0);
Assert.That(llm.GetCurrentPrompt() == llm.prompt);
Assert.That(llm.GetChat().Count == 1);
}

public void TestTokens(List<int> tokens)
{
Assert.AreEqual(tokens, new List<int> {44});
}

public void TestWarmup()
{
Assert.That(llm.GetChat().Count == 1);
Assert.That(llm.GetCurrentPrompt() == llm.prompt);
}

public void TestChat(string reply)
{
Assert.That(llm.GetCurrentPrompt() == llm.prompt);
Assert.That(reply == AIReply);
}

public void TestPostChat()
{
Assert.That(llm.GetChat().Count == 3);
string newPrompt = llm.prompt + llm.RoleMessageString(llm.playerName, "hi") + llm.RoleMessageString(llm.AIName, AIReply);
Assert.That(llm.GetCurrentPrompt() == newPrompt);
}
}
}
11 changes: 11 additions & 0 deletions Tests/Runtime/TestLLM.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions Tests/Runtime/undream.llmunity.Runtime.Tests.asmdef
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"name": "undream.llmunity.Runtime.Tests",
"references": [
"undream.llmunity.Runtime"
],
"optionalUnityReferences": [
"TestAssemblies"
]
}
7 changes: 7 additions & 0 deletions Tests/Runtime/undream.llmunity.Runtime.Tests.asmdef.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
"unity": "2022.3",
"unityRelease": "16f1",
"documentationUrl": "https://github.com/undreamai/LLMUnity",
"testables": [
"undream.llmunity.Runtime.Tests"
],
"keywords": [
"llm",
"large language model",
Expand Down
Loading