From 4a7135785234d21d0f7eeee47387fb2f68c3f753 Mon Sep 17 00:00:00 2001 From: Sterne Lee Date: Wed, 2 Oct 2024 07:48:16 +0800 Subject: [PATCH 1/2] feat: pass provider by model name --- src/middlewares/requestValidator/index.ts | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/middlewares/requestValidator/index.ts b/src/middlewares/requestValidator/index.ts index 7707c94d0..6664587f0 100644 --- a/src/middlewares/requestValidator/index.ts +++ b/src/middlewares/requestValidator/index.ts @@ -1,9 +1,28 @@ import { Context } from 'hono'; import { CONTENT_TYPES, POWERED_BY, VALID_PROVIDERS } from '../../globals'; import { configSchema } from './schema/config'; +import Providers from '../../providers'; -export const requestValidator = (c: Context, next: any) => { +const providers = Object.keys(Providers); + +export const requestValidator = async (c: Context, next: any) => { const requestHeaders = Object.fromEntries(c.req.raw.headers); + const clonedReq = c.req.raw.clone(); + const originalBody = await clonedReq.text(); + const modifiedBody = JSON.parse(originalBody); + const [provider, ...modelNames] = modifiedBody.model.split(':'); + if (providers.includes(provider)) { + requestHeaders[`x-${POWERED_BY}-provider`] = provider; + modifiedBody.model = modelNames.join(':') || undefined; + const newHeaders = new Headers(requestHeaders); + const newRequest = new Request(c.req.raw.url, { + method: c.req.raw.method, + headers: newHeaders, + body: JSON.stringify(modifiedBody), + }); + + c.req.raw = newRequest; + } const contentType = requestHeaders['content-type']; if ( @@ -151,5 +170,5 @@ export const requestValidator = (c: Context, next: any) => { ); } } - return next(); + return await next(); }; From 2498e5eb4c3e8db182ef0368093fa430b6815891 Mon Sep 17 00:00:00 2001 From: Sterne Lee Date: Sun, 6 Oct 2024 10:49:34 +0800 Subject: [PATCH 2/2] feat: get provider by url query in get method --- src/middlewares/requestValidator/index.ts | 42 ++++++++++++++++------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/src/middlewares/requestValidator/index.ts b/src/middlewares/requestValidator/index.ts index 6664587f0..e44a7942b 100644 --- a/src/middlewares/requestValidator/index.ts +++ b/src/middlewares/requestValidator/index.ts @@ -8,20 +8,36 @@ const providers = Object.keys(Providers); export const requestValidator = async (c: Context, next: any) => { const requestHeaders = Object.fromEntries(c.req.raw.headers); const clonedReq = c.req.raw.clone(); - const originalBody = await clonedReq.text(); - const modifiedBody = JSON.parse(originalBody); - const [provider, ...modelNames] = modifiedBody.model.split(':'); - if (providers.includes(provider)) { - requestHeaders[`x-${POWERED_BY}-provider`] = provider; - modifiedBody.model = modelNames.join(':') || undefined; - const newHeaders = new Headers(requestHeaders); - const newRequest = new Request(c.req.raw.url, { - method: c.req.raw.method, - headers: newHeaders, - body: JSON.stringify(modifiedBody), - }); + if (clonedReq.method === 'GET') { + const provider = new URLSearchParams(new URL(clonedReq.url).search).get( + 'provider' + ); + if (provider && providers.includes(provider)) { + requestHeaders[`x-${POWERED_BY}-provider`] = provider; + const newHeaders = new Headers(requestHeaders); + const newRequest = new Request(c.req.raw.url, { + method: c.req.raw.method, + headers: newHeaders, + }); + + c.req.raw = newRequest; + } + } else { + const originalBody = await clonedReq.text(); + const modifiedBody = JSON.parse(originalBody); + const [provider, ...modelNames] = modifiedBody.model.split(':'); + if (providers.includes(provider)) { + requestHeaders[`x-${POWERED_BY}-provider`] = provider; + modifiedBody.model = modelNames.join(':') || undefined; + const newHeaders = new Headers(requestHeaders); + const newRequest = new Request(c.req.raw.url, { + method: c.req.raw.method, + headers: newHeaders, + body: JSON.stringify(modifiedBody), + }); - c.req.raw = newRequest; + c.req.raw = newRequest; + } } const contentType = requestHeaders['content-type'];