Skip to content

Commit

Permalink
fix validator lookup and test
Browse files Browse the repository at this point in the history
  • Loading branch information
mjh1 committed Oct 2, 2024
1 parent 110d060 commit 22345df
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
14 changes: 10 additions & 4 deletions packages/api/src/controllers/generate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ describe("controllers/generate", () => {
"image-to-video",
"upscale",
"segment-anything-2",
"llm",
];
for (const api of apis) {
aiGatewayServer.app.post(`/${api}`, async (req, res) => {
Expand Down Expand Up @@ -135,13 +136,18 @@ describe("controllers/generate", () => {
textFields: Record<string, any>,
multipartField = { name: "image", contentType: "image/png" },
) => {
const form = buildForm(textFields);
form.append(multipartField.name, "dummy", {
contentType: multipartField.contentType,
});
return form;
};

const buildForm = (textFields: Record<string, any>) => {
const form = new FormData();
for (const [k, v] of Object.entries(textFields)) {
form.append(k, v);
}
form.append(multipartField.name, "dummy", {
contentType: multipartField.contentType,
});
return form;
};

Expand Down Expand Up @@ -231,7 +237,7 @@ describe("controllers/generate", () => {
it("should call the AI Gateway for generate API /llm", async () => {
const res = await client.fetch("/beta/generate/llm", {
method: "POST",
body: buildMultipartBody({}),
body: buildForm({ prompt: "foo" }),
});
expect(res.status).toBe(200);
expect(await res.json()).toEqual({
Expand Down
4 changes: 4 additions & 0 deletions packages/api/src/controllers/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import { BadRequestError } from "../store/errors";
import { fetchWithTimeout, kebabToCamel } from "../util";
import { experimentSubjectsOnly } from "./experiment";
import { pathJoin2 } from "./helpers";
import validators from "../schema/validators";

const AI_GATEWAY_TIMEOUT = 10 * 60 * 1000; // 10 minutes
const RATE_LIMIT_WINDOW = 60 * 1000; // 1 minute
Expand Down Expand Up @@ -181,6 +182,9 @@ function registerGenerateHandler(
if (isJSONReq) {
payloadParsers = [validatePost(`${camelType}Params`)];
} else {
if (!validators[`Body_gen${camelType}`]) {
camelType = type.toUpperCase();
}
payloadParsers = [
multipart.any(),
validateFormData(`Body_gen${camelType}`),
Expand Down

0 comments on commit 22345df

Please sign in to comment.