-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathreplicate.test.ts
108 lines (97 loc) · 3.49 KB
/
replicate.test.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import test from "ava";
import Replicate from "./replicate.js";
import { DefaultFetchHTTPClient, HTTPClient } from "./replicate.js";
/**
 * A canned reply for one HTTP verb.
 *
 * An array value is treated by `MockHttpClient` as a timeline of successive
 * responses; any other value is a single response.
 */
type MockedResponse = string | Record<string, any> | any[];

/** Canned replies keyed by HTTP verb; omitted verbs are not mocked at all. */
interface HTTPResponseMocks {
  get?: MockedResponse;
  post?: MockedResponse;
}
/**
 * In-memory stand-in for `HTTPClient` that replays canned responses.
 *
 * For every verb present in `httpResponseMocks` (e.g. `get`, `post`) the
 * constructor installs a method of that name on the instance. Each call
 * resolves with the next entry of that verb's timeline:
 * - an array value is taken as the ordered list of responses;
 * - any other value becomes a one-entry timeline.
 * Once a timeline is used up, further calls resolve with `undefined`.
 *
 * Per-verb state lives on the instance:
 * - `<verb>Response` holds the timeline;
 * - `<verb>Index` holds the position of the next reply.
 * Both are read at call time, not captured.
 */
// Declaration merging: the empty interface lets instances be used wherever an
// `HTTPClient` is expected without the class implementing every member.
interface MockHttpClient extends HTTPClient {}
class MockHttpClient {
  constructor(httpResponseMocks: HTTPResponseMocks) {
    Object.entries(httpResponseMocks).forEach(([verb, mocked]) => {
      const responsesKey = `${verb}Response`;
      const indexKey = `${verb}Index`;

      this[verb] = () => {
        const timeline = this[responsesKey];
        const position = this[indexKey]++;
        return Promise.resolve(timeline[position]);
      };
      this[indexKey] = 0;
      if (mocked instanceof Array) {
        this[responsesKey] = mocked;
      } else {
        this[responsesKey] = [mocked];
      }
    });
  }
}
// Unit tests must not pick up a real REPLICATE_API_TOKEN (or any other
// variable) from the developer's shell, so the whole environment is replaced
// with an empty object at module load, before the tests below are registered.
const { process: nodeProcess } = globalThis;
nodeProcess.env = {};
test("complains if no token or proxy url provided", (t) => {
  // With neither a token nor a proxy URL available, construction must fail
  // with this exact message.
  t.throws(() => new Replicate(), { message: "Missing Replicate token" });
});
test("accepts a manual token", (t) => {
  // An explicitly passed token is enough to construct a client.
  const replicate = new Replicate({ token: "abctoken" });
  t.truthy(replicate);
});
test("accepts an environment variable token", (t) => {
  // Supply the token only via the environment; the constructor receives no options.
  const env = globalThis.process.env;
  env.REPLICATE_API_TOKEN = "abctoken";
  try {
    t.truthy(new Replicate());
  } finally {
    // Always clean up, even if the constructor throws, so the token cannot
    // leak into other tests. Delete the key rather than assigning `undefined`:
    // on Node's real `process.env`, an assigned value is coerced to a string,
    // so `undefined` would become "undefined".
    delete env.REPLICATE_API_TOKEN;
  }
});
test("accepts a proxy url in lieu of a token", (t) => {
  // A proxy URL alone (no token) must also be accepted by the constructor.
  const replicate = new Replicate({ proxyUrl: "http://localhost.com:3000" });
  t.truthy(replicate);
});
test("fetches details of a model", async (t) => {
  // One canned GET response; the assertion checks that the id from its
  // `results` entry surfaces as `modelDetails.id`.
  const httpClient = new MockHttpClient({ get: { results: [{ id: "1" }] } });
  const replicate = new Replicate({ httpClient, token: "abctoken" });
  const { modelDetails } = await replicate.models.get("kuprel/min-dalle");
  t.is(modelDetails.id, "1");
});
test("fetches details of a model of a specific version", async (t) => {
  // Same canned GET response as the unversioned case, but the lookup also
  // passes an explicit version argument ("1").
  const httpClient = new MockHttpClient({ get: { results: [{ id: "1" }] } });
  const replicate = new Replicate({ httpClient, token: "abctoken" });
  const { modelDetails } = await replicate.models.get("kuprel/min-dalle", "1");
  t.is(modelDetails.id, "1");
});
test("makes a prediction", async (t) => {
  // GET replies are served in order:
  // 1. the response for /versions;
  // 2. the response for predictions/{id}, which already reports success.
  // The single POST reply is the newly started prediction.
  const versionsReply = { results: [{ id: "1" }] };
  const pollReply = { status: "succeeded", output: "expectedoutput" };
  const httpClient = new MockHttpClient({
    get: [versionsReply, pollReply],
    post: { status: "starting" },
  });

  // A tiny polling interval (presumably milliseconds — confirm) keeps the
  // test from waiting between polls.
  const replicate = new Replicate({
    httpClient,
    token: "abctoken",
    pollingInterval: 1,
  });

  const model = await replicate.models.get("kuprel/min-dalle");
  t.is(await model.predict(), "expectedoutput");
});
test("built-in http client gets & posts", async (t) => {
  // Swap in a fake fetch that echoes its arguments back through `json()`, so
  // the test can inspect exactly what the client handed to fetch. The real
  // fetch is restored afterwards so nothing else in this process sees the fake.
  const originalFetch = globalThis.fetch;
  globalThis.fetch = (calledUrl, usedOptions) =>
    Promise.resolve({ json: () => [calledUrl, usedOptions] }) as any;
  try {
    const httpClient = new DefaultFetchHTTPClient("abctoken");

    const [getUrl, getOptions] = await httpClient.get({
      url: "https://api.replicate.com/v1/versions",
      token: "",
      method: "get",
      event: "getModel",
    });
    t.is(getUrl, "https://api.replicate.com/v1/versions");
    // The per-request `token` is empty, yet the header must carry the token
    // the client was constructed with.
    t.is(getOptions.headers.Authorization, "Token abctoken");

    const [, postOptions] = await httpClient.post({
      url: "/predictions",
      token: "",
      method: "post",
      event: "startPrediction",
      body: {},
    });
    // The request body reaches fetch JSON-serialized.
    t.is(postOptions.body, "{}");
  } finally {
    globalThis.fetch = originalFetch;
  }
});