-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathreplicate.js
126 lines (126 loc) · 4.93 KB
/
replicate.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// Default configuration
const BASE_URL = "https://api.replicate.com/v1";
// Milliseconds to wait between prediction status checks when none is configured.
const DEFAULT_POLLING_INTERVAL = 5000;

// Utility functions

/**
 * Wait for roughly `ms` milliseconds.
 * @param {number} ms - Delay in milliseconds.
 * @returns {Promise<void>} Resolves (with no value) once the timer fires.
 */
const sleep = (ms) =>
  new Promise((resolve) => {
    setTimeout(resolve, ms);
  });

// True when running under Node.js, detected through `process.versions.node`.
const isNode = typeof process !== "undefined" && process.versions?.node != null;
/**
 * Minimal client for the Replicate HTTP API.
 *
 * All requests go through a pluggable `httpClient` exposing `get`/`post`
 * methods; when none is supplied a fetch-based DefaultFetchHTTPClient is used.
 */
export class Replicate {
  /**
   * @param {object} [options]
   * @param {string} [options.token] - API token. When falsy and running under Node,
   *   it is read from the REPLICATE_API_TOKEN environment variable.
   * @param {string} [options.proxyUrl] - Prefix placed (with a "/") in front of the
   *   full API base URL. When set, a missing token is tolerated.
   * @param {object} [options.httpClient] - Object with `get`/`post` methods; defaults to
   *   a DefaultFetchHTTPClient built with the resolved token.
   * @param {number} [options.pollingInterval] - Milliseconds between status polls;
   *   falsy values fall back to DEFAULT_POLLING_INTERVAL.
   * @throws {Error} "Missing Replicate token" when no token is found and no proxyUrl is given.
   */
  constructor({ token, proxyUrl, httpClient, pollingInterval } = {}) {
    const resolvedToken =
      token || (isNode ? process?.env?.REPLICATE_API_TOKEN : null);
    if (!resolvedToken && !proxyUrl) {
      throw new Error("Missing Replicate token");
    }

    this.token = resolvedToken;
    this.baseUrl = proxyUrl ? `${proxyUrl}/${BASE_URL}` : BASE_URL;
    this.httpClient = httpClient || new DefaultFetchHTTPClient(resolvedToken);
    this.pollingInterval = pollingInterval || DEFAULT_POLLING_INTERVAL;

    const replicate = this;
    this.models = {
      // Resolves to a ReplicateModel whose version details have already been loaded.
      get(path, version = null) {
        return ReplicateModel.fetch({ path, version, replicate });
      },
    };
  }

  /**
   * List the versions of a model.
   * @param {string} path - Model path as used in the API URL, e.g. "owner/name".
   * @returns {Promise<any>} Whatever the HTTP client returns for the request.
   */
  async getModel(path) {
    const request = { url: `/models/${path}/versions`, method: "get", event: "getModel" };
    return this.callHttpClient(request);
  }

  /**
   * Fetch the current state of a prediction.
   * @param {string} id - Prediction id.
   * @returns {Promise<any>} Whatever the HTTP client returns for the request.
   */
  async getPrediction(id) {
    const request = { url: `/predictions/${id}`, method: "get", event: "getPrediction" };
    return this.callHttpClient(request);
  }

  /**
   * Create a prediction for a model version.
   * @param {string} modelVersion - Version id to run.
   * @param {any} input - Model input, sent as the `input` field.
   * @param {string|null} [webhookCompleted=null] - Sent as `webhook_completed`.
   * @returns {Promise<any>} Whatever the HTTP client returns for the request.
   */
  async startPrediction(modelVersion, input, webhookCompleted = null) {
    const payload = {
      version: modelVersion,
      input,
      webhook_completed: webhookCompleted,
    };
    return this.callHttpClient({
      url: "/predictions",
      method: "post",
      event: "startPrediction",
      body: payload,
    });
  }

  /**
   * Dispatch a request to `httpClient[method]`, prefixing the URL with the base URL
   * and attaching the token.
   */
  async callHttpClient({ url, method, event, body }) {
    const handler = this.httpClient[method];
    const request = {
      url: `${this.baseUrl}${url}`,
      method,
      event,
      body,
      token: this.token,
    };
    return handler.call(this.httpClient, request);
  }
}
/**
 * Handle on one model version, able to run predictions against it.
 *
 * Create instances through `ReplicateModel.fetch` (or `replicate.models.get`),
 * which also loads the version details that `predictor`/`predict` rely on.
 */
export class ReplicateModel {
  /**
   * Construct a model handle and load its version details.
   * @param {{path: string, version?: string|null, replicate: object}} options
   * @returns {Promise<ReplicateModel>}
   * @throws {Error} when the model's versions cannot be loaded (see getModelDetails).
   */
  static async fetch(options) {
    const model = new ReplicateModel(options);
    await model.getModelDetails();
    return model;
  }

  /**
   * @param {object} options
   * @param {string} options.path - Model path passed to `replicate.getModel`.
   * @param {string|null} [options.version] - Requested version id; falls back to the first listed version.
   * @param {object} options.replicate - Client exposing getModel/startPrediction/getPrediction and pollingInterval.
   */
  constructor({ path, version, replicate }) {
    this.path = path;
    this.version = version;
    this.replicate = replicate;
  }

  /**
   * Load the model's version list and store the selected entry in `this.modelDetails`.
   *
   * The entry whose id equals `this.version` is used when present. Otherwise the
   * first entry is used (treated as the most recent), with a console warning if a
   * specific version was requested.
   * @throws {Error} when the response carries no versions (e.g. an API error payload
   *   such as `{ detail: "..." }`) instead of failing later with a TypeError.
   */
  async getModelDetails() {
    const response = await this.replicate.getModel(this.path);
    const modelVersions = response?.results;
    if (!Array.isArray(modelVersions) || modelVersions.length === 0) {
      const reason = response?.detail ?? "no versions returned";
      throw new Error(`Could not load versions for model ${this.path}: ${reason}`);
    }
    const [mostRecentVersion] = modelVersions;
    const explicitlySelectedVersion = modelVersions.find((m) => m.id === this.version);
    this.modelDetails = explicitlySelectedVersion ?? mostRecentVersion;
    if (this.version && this.version !== this.modelDetails.id) {
      console.warn(`Model (version:${this.version}) not found, defaulting to ${mostRecentVersion.id}`);
    }
  }

  /**
   * Start a prediction and poll it, yielding the `output` of every status check.
   *
   * Polling continues while the status is "starting" or "processing". It waits
   * `replicate.pollingInterval` ms between checks, but not after the final one, so
   * the last output is delivered as soon as it is known. Terminal statuses other
   * than success are not turned into errors here.
   * @param {any} input - Model input forwarded to `replicate.startPrediction`.
   * @yields {any} The prediction output after each status check.
   * @throws {Error} when starting the prediction does not return an id.
   */
  async *predictor(input) {
    const startResponse = await this.replicate.startPrediction(this.modelDetails.id, input);
    if (!startResponse?.id) {
      // Error payloads (e.g. bad token) have no id; polling "/predictions/undefined" would hide them.
      const reason = startResponse?.detail ?? JSON.stringify(startResponse);
      throw new Error(`Failed to start prediction: ${reason}`);
    }
    const runningStatuses = ["starting", "processing"];
    for (;;) {
      const checkResponse = await this.replicate.getPrediction(startResponse.id);
      // TODO: only yield if there is a new prediction
      yield checkResponse.output;
      if (!runningStatuses.includes(checkResponse.status)) {
        return;
      }
      await sleep(this.replicate.pollingInterval);
    }
  }

  /**
   * Run a prediction to completion.
   * @param {any} [input=""] - Model input.
   * @returns {Promise<any>} The output from the last status check (undefined if none was yielded).
   */
  async predict(input = "") {
    let lastOutput;
    for await (const output of this.predictor(input)) {
      lastOutput = output;
    }
    return lastOutput;
  }
}
// Thin wrapper that makes fetch a bit easier to call; its interface is similar to the axios library.
export class DefaultFetchHTTPClient {
  /**
   * @param {string|null|undefined} token - Replicate API token. When falsy (which the
   *   Replicate client allows if only a proxyUrl is configured), no Authorization
   *   header is sent rather than a bogus "Token null"/"Token undefined".
   */
  constructor(token) {
    this.headers = {
      ...(token ? { Authorization: `Token ${token}` } : {}),
      "Content-Type": "application/json",
      Accept: "application/json",
    };
  }

  /**
   * Make sure a global `fetch` exists. Node only ships a global fetch from v18
   * (experimental there), so on older Node versions the `node-fetch` package is
   * loaded and installed on `globalThis`.
   */
  async importFetch() {
    if (isNode && !globalThis.fetch) {
      globalThis.fetch = (await import("node-fetch")).default;
    }
  }

  /**
   * GET `url` and parse the response body as JSON.
   * NOTE: the HTTP status is not checked; API error bodies are returned like any other.
   * @param {{url: string}} request
   * @returns {Promise<any>} Parsed JSON body.
   */
  async get({ url }) {
    await this.importFetch();
    const response = await fetch(url, { headers: this.headers });
    return response.json();
  }

  /**
   * POST `body` (JSON-encoded) to `url` and parse the response body as JSON.
   * NOTE: the HTTP status is not checked; API error bodies are returned like any other.
   * @param {{url: string, body: any}} request
   * @returns {Promise<any>} Parsed JSON body.
   */
  async post({ url, body }) {
    await this.importFetch();
    const fetchOptions = {
      method: "POST",
      headers: this.headers,
      body: JSON.stringify(body),
    };
    const response = await fetch(url, fetchOptions);
    return response.json();
  }
}
// The Replicate client class is also the module's default export, alongside the named class exports above.
export default Replicate;