Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Conditional Router #539

Merged
merged 10 commits into from
Aug 26, 2024
9 changes: 9 additions & 0 deletions src/errors/RouterError.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
export class RouterError extends Error {
constructor(
message: string,
public cause?: Error
) {
super(message);
this.name = 'RouterError';
}
}
1 change: 1 addition & 0 deletions src/globals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export const HEADER_KEYS: Record<string, string> = {
PROVIDER: `x-${POWERED_BY}-provider`,
TRACE_ID: `x-${POWERED_BY}-trace-id`,
CACHE: `x-${POWERED_BY}-cache`,
METADATA: `x-${POWERED_BY}-metadata`,
FORWARD_HEADERS: `x-${POWERED_BY}-forward-headers`,
CUSTOM_HOST: `x-${POWERED_BY}-custom-host`,
REQUEST_TIMEOUT: `x-${POWERED_BY}-request-timeout`,
Expand Down
13 changes: 11 additions & 2 deletions src/handlers/chatCompletionsHandler.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { RouterError } from '../errors/RouterError';
import {
constructConfigFromRequestHeaders,
tryTargetsRecursively,
Expand Down Expand Up @@ -30,13 +31,21 @@ export async function chatCompletionsHandler(c: Context): Promise<Response> {
return tryTargetsResponse;
} catch (err: any) {
console.log('chatCompletion error', err.message);
let statusCode = 500;
let errorMessage = 'Something went wrong';

if (err instanceof RouterError) {
statusCode = 400;
errorMessage = err.message;
}

return new Response(
JSON.stringify({
status: 'failure',
message: 'Something went wrong',
message: errorMessage,
}),
{
status: 500,
status: statusCode,
headers: {
'content-type': 'application/json',
},
Expand Down
9 changes: 9 additions & 0 deletions src/handlers/completionsHandler.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { RouterError } from '../errors/RouterError';
import {
constructConfigFromRequestHeaders,
tryTargetsRecursively,
Expand Down Expand Up @@ -31,6 +32,14 @@ export async function completionsHandler(c: Context): Promise<Response> {
return tryTargetsResponse;
} catch (err: any) {
console.log('completion error', err.message);
let statusCode = 500;
let errorMessage = 'Something went wrong';

if (err instanceof RouterError) {
statusCode = 400;
errorMessage = err.message;
}

return new Response(
JSON.stringify({
status: 'failure',
Expand Down
11 changes: 10 additions & 1 deletion src/handlers/embeddingsHandler.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { RouterError } from '../errors/RouterError';
import {
constructConfigFromRequestHeaders,
tryTargetsRecursively,
Expand Down Expand Up @@ -30,7 +31,15 @@ export async function embeddingsHandler(c: Context): Promise<Response> {

return tryTargetsResponse;
} catch (err: any) {
console.log('completion error', err.message);
console.log('embeddings error', err.message);
let statusCode = 500;
let errorMessage = 'Something went wrong';

if (err instanceof RouterError) {
statusCode = 400;
errorMessage = err.message;
}

return new Response(
JSON.stringify({
status: 'failure',
Expand Down
38 changes: 35 additions & 3 deletions src/handlers/handlerUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ import {
Options,
Params,
ShortConfig,
StrategyModes,
Targets,
} from '../types/requestBody';
import { convertKeysToCamelCase } from '../utils';
import { retryRequest } from './retryHandler';
import { env, getRuntimeKey } from 'hono/adapter';
import { afterRequestHookHandler, responseHandler } from './responseHandlers';
import { HookSpan, HooksManager } from '../middlewares/hooks';
import { ConditionalRouter } from '../services/conditionalRouter';
import { RouterError } from '../errors/RouterError';

/**
* Constructs the request options for the API call.
Expand Down Expand Up @@ -794,7 +797,7 @@ export async function tryTargetsRecursively(
let response;

switch (strategyMode) {
case 'fallback':
case StrategyModes.FALLBACK:
for (let [index, target] of currentTarget.targets.entries()) {
response = await tryTargetsRecursively(
c,
Expand All @@ -815,7 +818,7 @@ export async function tryTargetsRecursively(
}
break;

case 'loadbalance':
case StrategyModes.LOADBALANCE:
currentTarget.targets.forEach((t: Options) => {
if (t.weight === undefined) {
t.weight = 1;
Expand Down Expand Up @@ -846,7 +849,35 @@ export async function tryTargetsRecursively(
}
break;

case 'single':
case StrategyModes.CONDITIONAL:
let metadata: Record<string, string>;
try {
metadata = JSON.parse(requestHeaders[HEADER_KEYS.METADATA]);
} catch (err) {
metadata = {};
}
let conditionalRouter: ConditionalRouter;
let finalTarget: Targets;
try {
conditionalRouter = new ConditionalRouter(currentTarget, { metadata });
finalTarget = conditionalRouter.resolveTarget();
} catch (conditionalRouter: any) {
throw new RouterError(conditionalRouter.message);
}

response = await tryTargetsRecursively(
c,
finalTarget,
request,
requestHeaders,
fn,
method,
`${currentJsonPath}.targets[${finalTarget.index}]`,
currentInheritedConfig
);
break;

case StrategyModes.SINGLE:
response = await tryTargetsRecursively(
c,
currentTarget.targets[0],
Expand Down Expand Up @@ -1016,6 +1047,7 @@ export function constructConfigFromRequestHeaders(
'params',
'checks',
'vertex_service_account_json',
'conditions',
]) as any;
}

Expand Down
8 changes: 8 additions & 0 deletions src/handlers/imageGenerationsHandler.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { RouterError } from '../errors/RouterError';
import {
constructConfigFromRequestHeaders,
tryTargetsRecursively,
Expand Down Expand Up @@ -31,6 +32,13 @@ export async function imageGenerationsHandler(c: Context): Promise<Response> {
return tryTargetsResponse;
} catch (err: any) {
console.log('imageGenerate error', err.message);
let statusCode = 500;
let errorMessage = 'Something went wrong';

if (err instanceof RouterError) {
statusCode = 400;
errorMessage = err.message;
}
return new Response(
JSON.stringify({
status: 'failure',
Expand Down
18 changes: 15 additions & 3 deletions src/middlewares/requestValidator/schema/config.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { z } from 'zod';
import { any, z } from 'zod';
import { OLLAMA, VALID_PROVIDERS, GOOGLE_VERTEX_AI } from '../../../globals';

export const configSchema: any = z
Expand All @@ -8,13 +8,25 @@ export const configSchema: any = z
mode: z
.string()
.refine(
(value) => ['single', 'loadbalance', 'fallback'].includes(value),
(value) =>
['single', 'loadbalance', 'fallback', 'conditional'].includes(
value
),
{
message:
"Invalid 'mode' value. Must be one of: single, loadbalance, fallback",
"Invalid 'mode' value. Must be one of: single, loadbalance, fallback, conditional",
}
),
on_status_codes: z.array(z.number()).optional(),
conditions: z
.array(
z.object({
query: z.object({}),
then: z.string(),
})
)
.optional(),
default: z.string().optional(),
})
.optional(),
provider: z
Expand Down
152 changes: 152 additions & 0 deletions src/services/conditionalRouter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import { StrategyModes, Targets } from '../types/requestBody';

type Query = {
[key: string]: any;
};

interface RouterContext {
metadata?: Record<string, string>;
}

enum Operator {
// Comparison Operators
Equal = '$eq',
NotEqual = '$ne',
GreaterThan = '$gt',
GreaterThanOrEqual = '$gte',
LessThan = '$lt',
LessThanOrEqual = '$lte',
In = '$in',
NotIn = '$nin',
Regex = '$regex',

// Logical Operators
And = '$and',
Or = '$or',
}

export class ConditionalRouter {
private config: Targets;
private context: RouterContext;

constructor(config: Targets, context: RouterContext) {
this.config = config;
this.context = context;
if (this.config.strategy?.mode !== StrategyModes.CONDITIONAL) {
throw new Error('Unsupported strategy mode');
}
}

resolveTarget(): Targets {
if (!this.config.strategy?.conditions) {
throw new Error('No conditions passed in the query router');
}

for (const condition of this.config.strategy.conditions) {
if (this.evaluateQuery(condition.query)) {
const targetName = condition.then;
return this.findTarget(targetName);
}
}

// If no conditions matched and a default is specified, return the default target
if (this.config.strategy.default) {
return this.findTarget(this.config.strategy.default);
}

throw new Error('Query router did not resolve to any valid target');
}

private evaluateQuery(query: Query): boolean {
for (const [key, value] of Object.entries(query)) {
if (key === Operator.Or && Array.isArray(value)) {
return value.some((subCondition: Query) =>
this.evaluateQuery(subCondition)
);
}

if (key === Operator.And && Array.isArray(value)) {
return value.every((subCondition: Query) =>
this.evaluateQuery(subCondition)
);
}

const metadataValue = this.getContextValue(key);

if (typeof value === 'object' && value !== null) {
if (!this.evaluateOperator(value, metadataValue)) {
return false;
}
} else if (metadataValue !== value) {
return false;
}
}

return true;
}

private evaluateOperator(operator: string, value: any): boolean {
for (const [op, compareValue] of Object.entries(operator)) {
switch (op) {
case Operator.Equal:
if (value !== compareValue) return false;
break;
case Operator.NotEqual:
if (value === compareValue) return false;
break;
case Operator.GreaterThan:
if (!(parseFloat(value) > parseFloat(compareValue))) return false;
break;
case Operator.GreaterThanOrEqual:
if (!(parseFloat(value) >= parseFloat(compareValue))) return false;
break;
case Operator.LessThan:
if (!(parseFloat(value) < parseFloat(compareValue))) return false;
break;
case Operator.LessThanOrEqual:
if (!(parseFloat(value) <= parseFloat(compareValue))) return false;
break;
case Operator.In:
if (!Array.isArray(compareValue) || !compareValue.includes(value))
return false;
break;
case Operator.NotIn:
if (!Array.isArray(compareValue) || compareValue.includes(value))
return false;
break;
case Operator.Regex:
try {
const regex = new RegExp(compareValue);
return value.test(regex);
} catch (e) {
return false;
}
default:
throw new Error(
`Unsupported operator used in the query router: ${op}`
);
}
}
return true;
}

private findTarget(name: string): Targets {
const index =
this.config.targets?.findIndex((target) => target.name === name) ?? -1;
if (index === -1) {
throw new Error(`Invalid target name found in the query router: ${name}`);
}

return {
...this.config.targets?.[index],
index,
};
}

private getContextValue(key: string): any {
const parts = key.split('.');
let value: any = this.context;
value = value[parts[0]]?.[parts[1]];
return value;
}
}
Loading
Loading