Skip to content

Commit

Permalink
fix(js/core): correctly handle errors when streaming (#1477)
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelgj authored Dec 11, 2024
1 parent 7255cd6 commit afcd77b
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 38 deletions.
24 changes: 20 additions & 4 deletions genkit-tools/common/src/manager/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
RunActionResponseSchema,
} from '../types/action';
import * as apis from '../types/apis';
import { GenkitErrorData } from '../types/error';
import { TraceData } from '../types/trace';
import { logger } from '../utils/logger';
import {
Expand Down Expand Up @@ -198,9 +199,24 @@ export class RuntimeManager {
rejecter = reject;
});
stream.on('end', () => {
const actionResponse = RunActionResponseSchema.parse(
JSON.parse(buffer)
);
const parsedBuffer = JSON.parse(buffer);
if (parsedBuffer.error) {
const err = new GenkitToolsError(
`Error running action key='${input.key}'.`
);
// massage the error into a shape dev ui expects
err.data = {
...parsedBuffer.error,
stack: (parsedBuffer.error?.details as any).stack,
data: {
genkitErrorMessage: parsedBuffer.error?.message,
genkitErrorDetails: parsedBuffer.error?.details,
},
};
rejecter(err);
return;
}
const actionResponse = RunActionResponseSchema.parse(parsedBuffer);
if (genkitVersion) {
actionResponse.genkitVersion = genkitVersion;
}
Expand Down Expand Up @@ -392,7 +408,7 @@ export class RuntimeManager {
newError.message = (error.response?.data as any).message;
}
// we got a non-200 response; copy the payload and rethrow
newError.data = error.response.data as Record<string, unknown>;
newError.data = error.response.data as GenkitErrorData;
throw newError;
}

Expand Down
4 changes: 3 additions & 1 deletion genkit-tools/common/src/manager/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
* limitations under the License.
*/

import { GenkitErrorData } from '../types/error';

export type Runtime = 'nodejs' | 'go' | undefined;

export class GenkitToolsError extends Error {
public data?: Record<string, unknown>;
public data?: GenkitErrorData;

constructor(msg: string, options?: ErrorOptions) {
super(msg, options);
Expand Down
6 changes: 2 additions & 4 deletions genkit-tools/common/src/server/router.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,8 @@ const t = initTRPC.create({
...shape,
data: {
...shape.data,
genkitErrorMessage: (error.cause.data as Record<string, unknown>)
.message,
genkitErrorDetails: (error.cause.data as Record<string, unknown>)
.details,
genkitErrorMessage: error.cause.data.message,
genkitErrorDetails: error.cause.data.details,
},
};
}
Expand Down
16 changes: 12 additions & 4 deletions genkit-tools/common/src/server/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import express, { ErrorRequestHandler } from 'express';
import { Server } from 'http';
import os from 'os';
import path from 'path';
import { GenkitToolsError } from '../manager';
import { RuntimeManager } from '../manager/manager';
import { logger } from '../utils/logger';
import { toolsPackage } from '../utils/package';
Expand Down Expand Up @@ -72,10 +73,17 @@ export function startServer(manager: RuntimeManager, port: number) {
'Transfer-Encoding': 'chunked',
});

const result = await manager.runAction({ key, input, context }, (chunk) => {
res.write(JSON.stringify(chunk) + '\n');
});
res.write(JSON.stringify(result));
try {
const result = await manager.runAction(
{ key, input, context },
(chunk) => {
res.write(JSON.stringify(chunk) + '\n');
}
);
res.write(JSON.stringify(result));
} catch (err) {
res.write(JSON.stringify({ error: (err as GenkitToolsError).data }));
}
res.end();
});

Expand Down
28 changes: 28 additions & 0 deletions genkit-tools/common/src/types/error.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

export interface GenkitErrorData {
message: string;
stack?: string;
details?: any;
data?: {
genkitErrorMessage?: string;
genkitErrorDetails?: {
stack?: string;
traceId: string;
};
};
}
1 change: 1 addition & 0 deletions genkit-tools/common/src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

export { RuntimeEvent, RuntimeInfo } from '../manager/types';
export { GenkitErrorData } from '../types/error';
export * from './action';
export * from './analytics';
export * from './apis';
Expand Down
19 changes: 13 additions & 6 deletions js/core/src/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,20 @@ export function action<
metadata.name = actionName;
metadata.input = input;

const output = await fn(input, {
context: options?.context,
sendChunk: options?.onChunk ?? ((c) => {}),
});
try {
const output = await fn(input, {
context: options?.context,
sendChunk: options?.onChunk ?? ((c) => {}),
});

metadata.output = JSON.stringify(output);
return output;
metadata.output = JSON.stringify(output);
return output;
} catch (err) {
if (typeof err === 'object') {
(err as any).traceId = traceId;
}
throw err;
}
}
);
output = parseSchema(output, {
Expand Down
55 changes: 37 additions & 18 deletions js/core/src/reflection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,31 +156,50 @@ export class ReflectionServer {
const { key, input, context, telemetryLabels } = request.body;
const { stream } = request.query;
logger.debug(`Running action \`${key}\` with stream=${stream}...`);
let traceId;
try {
const action = await this.registry.lookupAction(key);
if (!action) {
response.status(404).send(`action ${key} not found`);
return;
}
if (stream === 'true') {
const callback = (chunk) => {
response.write(JSON.stringify(chunk) + '\n');
};
const result = await runWithStreamingCallback(
callback,
async () => await action.run(input, { context, onChunk: callback })
);
await flushTracing();
response.write(
JSON.stringify({
result: result.result,
telemetry: {
traceId: result.telemetry.traceId,
try {
const callback = (chunk) => {
response.write(JSON.stringify(chunk) + '\n');
};
const result = await runWithStreamingCallback(callback, () =>
action.run(input, { context, onChunk: callback })
);
await flushTracing();
response.write(
JSON.stringify({
result: result.result,
telemetry: {
traceId: result.telemetry.traceId,
},
} as RunActionResponse)
);
response.end();
} catch (err) {
const { message, stack } = err as Error;
// since we're streaming, we must do special error handling here -- the headers are already sent.
const errorResponse: Status = {
code: StatusCodes.INTERNAL,
message,
details: {
stack,
},
} as RunActionResponse)
);
response.end();
};
if ((err as any).traceId) {
errorResponse.details.traceId = (err as any).traceId;
}
response.write(
JSON.stringify({
error: errorResponse,
} as RunActionResponse)
);
response.end();
}
} else {
const result = await action.run(input, { context, telemetryLabels });
await flushTracing();
Expand All @@ -192,7 +211,7 @@ export class ReflectionServer {
} as RunActionResponse);
}
} catch (err) {
const { message, stack } = err as Error;
const { message, stack, traceId } = err as any;
next({ message, stack, traceId });
}
});
Expand Down
31 changes: 30 additions & 1 deletion js/testapps/flow-simple-ai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ import { GoogleAIFileManager } from '@google/generative-ai/server';
import { AlwaysOnSampler } from '@opentelemetry/sdk-trace-base';
import { initializeApp } from 'firebase-admin/app';
import { getFirestore } from 'firebase-admin/firestore';
import { MessageSchema, genkit, run, z } from 'genkit';
import { GenerateResponseData, MessageSchema, genkit, run, z } from 'genkit';
import { logger } from 'genkit/logging';
import { ModelMiddleware } from 'genkit/model';
import { PluginProvider } from 'genkit/plugin';
import { Allow, parse } from 'partial-json';

Expand Down Expand Up @@ -580,3 +581,31 @@ ai.defineFlow(
return text;
}
);

ai.defineModel(
{
name: 'hiModel',
},
async (request, streamingCallback) => {
return {
finishReason: 'stop',
message: { role: 'model', content: [{ text: 'hi' }] },
};
}
);

const blockingMiddleware: ModelMiddleware = async (req, next) => {
return {
finishReason: 'blocked',
finishMessage: `Model input violated policies: further processing blocked.`,
} as GenerateResponseData;
};

ai.defineFlow('blockingMiddleware', async () => {
const { text } = await ai.generate({
prompt: 'hi',
model: 'hiModel',
use: [blockingMiddleware],
});
return text;
});

0 comments on commit afcd77b

Please sign in to comment.