Commit
refactor(js/core): manage als instances centrally in the registry (#1490)
pavelgj authored Dec 12, 2024
1 parent 6fbe98e commit 540d3b7
Showing 25 changed files with 425 additions and 230 deletions.
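The theme of the commit: instead of each module keeping its own `AsyncLocalStorage` instance (for the current session, the streaming callback, tracing metadata, and so on), that per-request context now lives on the `Registry`, and the helpers that read or write it (`runWithSession`, `getCurrentSession`, `runWithStreamingCallback`, `getStreamingCallback`, `runInNewSpan`, `setCustomMetadataAttributes`) take the registry as their first argument. The core `Registry` change itself is not among the hunks shown here; the sketch below is only a rough illustration of an `asyncStore` helper compatible with the `run(key, value, fn)` / `getStore(key)` call sites visible in the `session.ts` hunk further down, with all names other than `AsyncLocalStorage` assumed for illustration.

```ts
import { AsyncLocalStorage } from 'node:async_hooks';

// Rough sketch only: a keyed container of lazily created AsyncLocalStorage
// instances, shaped to match registry.asyncStore.run(...)/getStore(...) as
// used in js/ai/src/session.ts below. Not the actual @genkit-ai/core code.
class AsyncStore {
  private stores = new Map<string, AsyncLocalStorage<any>>();

  getStore<T>(key: string): T | undefined {
    return this.stores.get(key)?.getStore();
  }

  run<T, R>(key: string, value: T, callback: () => R): R {
    let als = this.stores.get(key);
    if (!als) {
      als = new AsyncLocalStorage<T>();
      this.stores.set(key, als);
    }
    return als.run(value, callback);
  }
}
```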
157 changes: 84 additions & 73 deletions js/ai/src/chat.ts
@@ -133,43 +133,47 @@ export class Chat {
   >(
     options: string | Part[] | ChatGenerateOptions<O, CustomOptions>
   ): Promise<GenerateResponse<z.infer<O>>> {
-    return runWithSession(this.session, () =>
-      runInNewSpan({ metadata: { name: 'send' } }, async () => {
-        let resolvedOptions;
-        let streamingCallback = undefined;
-
-        // string
-        if (typeof options === 'string') {
-          resolvedOptions = {
-            prompt: options,
-          } as ChatGenerateOptions<O, CustomOptions>;
-        } else if (Array.isArray(options)) {
-          // Part[]
-          resolvedOptions = {
-            prompt: options,
-          } as ChatGenerateOptions<O, CustomOptions>;
-        } else {
-          resolvedOptions = options as ChatGenerateOptions<O, CustomOptions>;
-          streamingCallback = resolvedOptions.streamingCallback;
-        }
-        let request: GenerateOptions = {
-          ...(await this.requestBase),
-          messages: this.messages,
-          ...resolvedOptions,
-        };
-        let response = await generate(this.session.registry, {
-          ...request,
-          streamingCallback,
-        });
-        this.requestBase = Promise.resolve({
-          ...(await this.requestBase),
-          // these things may get changed by tools calling within generate.
-          tools: response?.request?.tools,
-          config: response?.request?.config,
-        });
-        await this.updateMessages(response.messages);
-        return response;
-      })
+    return runWithSession(this.session.registry, this.session, () =>
+      runInNewSpan(
+        this.session.registry,
+        { metadata: { name: 'send' } },
+        async () => {
+          let resolvedOptions;
+          let streamingCallback = undefined;
+
+          // string
+          if (typeof options === 'string') {
+            resolvedOptions = {
+              prompt: options,
+            } as ChatGenerateOptions<O, CustomOptions>;
+          } else if (Array.isArray(options)) {
+            // Part[]
+            resolvedOptions = {
+              prompt: options,
+            } as ChatGenerateOptions<O, CustomOptions>;
+          } else {
+            resolvedOptions = options as ChatGenerateOptions<O, CustomOptions>;
+            streamingCallback = resolvedOptions.streamingCallback;
+          }
+          let request: GenerateOptions = {
+            ...(await this.requestBase),
+            messages: this.messages,
+            ...resolvedOptions,
+          };
+          let response = await generate(this.session.registry, {
+            ...request,
+            streamingCallback,
+          });
+          this.requestBase = Promise.resolve({
+            ...(await this.requestBase),
+            // these things may get changed by tools calling within generate.
+            tools: response?.request?.tools,
+            config: response?.request?.config,
+          });
+          await this.updateMessages(response.messages);
+          return response;
+        }
+      )
     );
   }

@@ -179,47 +183,54 @@ export class Chat {
   >(
     options: string | Part[] | GenerateStreamOptions<O, CustomOptions>
   ): Promise<GenerateStreamResponse<z.infer<O>>> {
-    return runWithSession(this.session, () =>
-      runInNewSpan({ metadata: { name: 'send' } }, async () => {
-        let resolvedOptions;
-
-        // string
-        if (typeof options === 'string') {
-          resolvedOptions = {
-            prompt: options,
-          } as GenerateStreamOptions<O, CustomOptions>;
-        } else if (Array.isArray(options)) {
-          // Part[]
-          resolvedOptions = {
-            prompt: options,
-          } as GenerateStreamOptions<O, CustomOptions>;
-        } else {
-          resolvedOptions = options as GenerateStreamOptions<O, CustomOptions>;
-        }
-
-        const { response, stream } = await generateStream(
-          this.session.registry,
-          {
-            ...(await this.requestBase),
-            messages: this.messages,
-            ...resolvedOptions,
-          }
-        );
-
-        return {
-          response: response.finally(async () => {
-            const resolvedResponse = await response;
-            this.requestBase = Promise.resolve({
-              ...(await this.requestBase),
-              // these things may get changed by tools calling within generate.
-              tools: resolvedResponse?.request?.tools,
-              config: resolvedResponse?.request?.config,
-            });
-            this.updateMessages(resolvedResponse.messages);
-          }),
-          stream,
-        };
-      })
+    return runWithSession(this.session.registry, this.session, () =>
+      runInNewSpan(
+        this.session.registry,
+        { metadata: { name: 'send' } },
+        async () => {
+          let resolvedOptions;
+
+          // string
+          if (typeof options === 'string') {
+            resolvedOptions = {
+              prompt: options,
+            } as GenerateStreamOptions<O, CustomOptions>;
+          } else if (Array.isArray(options)) {
+            // Part[]
+            resolvedOptions = {
+              prompt: options,
+            } as GenerateStreamOptions<O, CustomOptions>;
+          } else {
+            resolvedOptions = options as GenerateStreamOptions<
+              O,
+              CustomOptions
+            >;
+          }
+
+          const { response, stream } = await generateStream(
+            this.session.registry,
+            {
+              ...(await this.requestBase),
+              messages: this.messages,
+              ...resolvedOptions,
+            }
+          );
+
+          return {
+            response: response.finally(async () => {
+              const resolvedResponse = await response;
+              this.requestBase = Promise.resolve({
+                ...(await this.requestBase),
+                // these things may get changed by tools calling within generate.
+                tools: resolvedResponse?.request?.tools,
+                config: resolvedResponse?.request?.config,
+              });
+              this.updateMessages(resolvedResponse.messages);
+            }),
+            stream,
+          };
+        }
+      )
     );
   }

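Net effect of the two `chat.ts` hunks above: `Chat.send` and `Chat.sendStream` do the same work as before, but they now pass `this.session.registry` into both `runWithSession` and `runInNewSpan`, so the span and session context are stored on that registry's async store rather than in module-level state. A hedged sketch of the resulting call shape follows; it is not part of the diff.

```ts
// Sketch of the registry-first call shape used above. `registry` and `session`
// are placeholders for an existing Registry and Session; imports of
// runWithSession/runInNewSpan from the modules edited above are omitted.
const response = await runWithSession(registry, session, () =>
  runInNewSpan(registry, { metadata: { name: 'send' } }, async () => {
    // Anything in this async scope can resolve the session (and the streaming
    // callback) through the same `registry` instance.
    return 'done';
  })
);
```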
1 change: 1 addition & 0 deletions js/ai/src/evaluator.ts
@@ -175,6 +175,7 @@ export function defineEvaluator<
       };
       try {
         await runInNewSpan(
+          registry,
           {
             metadata: {
               name: `Test Case ${datapoint.testCaseId}`,
1 change: 1 addition & 0 deletions js/ai/src/generate.ts
@@ -278,6 +278,7 @@ export async function generate<
   };
 
   return await runWithStreamingCallback(
+    registry,
     resolvedOptions.streamingCallback,
     async () => {
       const response = await generateHelper(
4 changes: 3 additions & 1 deletion js/ai/src/generate/action.ts
@@ -83,6 +83,7 @@ export async function generateHelper(
 ): Promise<GenerateResponseData> {
   // do tracing
   return await runInNewSpan(
+    registry,
     {
       metadata: {
         name: 'generate',
@@ -139,8 +140,9 @@ async function generate(

   const accumulatedChunks: GenerateResponseChunkData[] = [];
 
-  const streamingCallback = getStreamingCallback();
+  const streamingCallback = getStreamingCallback(registry);
   const response = await runWithStreamingCallback(
+    registry,
     streamingCallback
       ? (chunk: GenerateResponseChunkData) => {
           // Store accumulated chunk data
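The `generate/action.ts` hunks pair the two sides of the streaming contract: `runWithStreamingCallback(registry, cb, fn)` stores the callback and `getStreamingCallback(registry)` retrieves it, so both ends must be given the same `Registry`. A rough sketch of that flow, with imports and the `registry` value assumed rather than taken from the diff:

```ts
// Rough sketch of the registry-scoped streaming flow implied by these hunks;
// `registry` and the imports of the two helpers are assumed, not shown here.
const chunks: GenerateResponseChunkData[] = [];

const result = await runWithStreamingCallback(
  registry,
  (chunk: GenerateResponseChunkData) => {
    chunks.push(chunk); // collect streamed chunks as they arrive
  },
  async () => {
    // Deeper in the same async context (e.g. inside a model runner), the
    // callback registered above is recovered from the same registry:
    const cb = getStreamingCallback(registry);
    // `cb` is the collector above here, and undefined outside this scope.
    return 'done';
  }
);
```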
2 changes: 1 addition & 1 deletion js/ai/src/model.ts
@@ -372,7 +372,7 @@ export function defineModel<
     (input) => {
       const startTimeMs = performance.now();
 
-      return runner(input, getStreamingCallback()).then((response) => {
+      return runner(input, getStreamingCallback(registry)).then((response) => {
         const timedResponse = {
           ...response,
           latencyMs: performance.now() - startTimeMs,
16 changes: 9 additions & 7 deletions js/ai/src/session.ts
@@ -16,7 +16,6 @@

 import { z } from '@genkit-ai/core';
 import { Registry } from '@genkit-ai/core/registry';
-import { AsyncLocalStorage } from 'node:async_hooks';
 import { v4 as uuidv4 } from 'uuid';
 import { Chat, ChatOptions, MAIN_THREAD, PromptRenderOptions } from './chat';
 import {
@@ -192,7 +191,7 @@ export class Session<S = any> {
     maybeOptionsOrPreamble?: ChatOptions<I, S> | ExecutablePrompt<I>,
     maybeOptions?: ChatOptions<I, S>
   ): Chat {
-    return runWithSession(this, () => {
+    return runWithSession(this.registry, this, () => {
       let options: ChatOptions<S> | undefined;
       let threadName = MAIN_THREAD;
       let preamble: ExecutablePrompt<I> | undefined;
@@ -266,7 +265,7 @@ export class Session<S = any> {
    * `ai.currentSession().state`
    */
   run<O>(fn: () => O) {
-    return runWithSession(this, fn);
+    return runWithSession(this.registry, this, fn);
   }
 
   toJSON() {
@@ -280,21 +279,24 @@ export interface SessionData<S = any> {
   threads?: Record<string, MessageData[]>;
 }
 
-const sessionAls = new AsyncLocalStorage<Session<any>>();
+const sessionAlsKey = 'ai.session';
 
 /**
  * Executes provided function within the provided session state.
  */
 export function runWithSession<S = any, O = any>(
+  registry: Registry,
   session: Session<S>,
   fn: () => O
 ): O {
-  return sessionAls.run(session, fn);
+  return registry.asyncStore.run(sessionAlsKey, session, fn);
 }
 
 /** Returns the current session. */
-export function getCurrentSession<S = any>(): Session<S> | undefined {
-  return sessionAls.getStore();
+export function getCurrentSession<S = any>(
+  registry: Registry
+): Session<S> | undefined {
+  return registry.asyncStore.getStore(sessionAlsKey);
 }
 
 /** Throw when session state errors occur, ex. missing state, etc. */
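With the `session.ts` change, the active session is stored in the registry's async store under the `'ai.session'` key instead of a module-level `AsyncLocalStorage`, so `getCurrentSession` now needs the registry. A hedged usage sketch, where `registry`, `session`, and the `MyState` type are assumed to already exist:

```ts
// Sketch only: reading the active session back while inside a session scope.
// `registry`, `session`, and MyState are placeholders, not part of the diff.
const state = runWithSession(registry, session, () => {
  // Resolved via registry.asyncStore.getStore('ai.session'); returns
  // undefined when called outside runWithSession for this registry.
  const current = getCurrentSession<MyState>(registry);
  if (!current) {
    throw new Error('not running within a session');
  }
  return current.state;
});
```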
72 changes: 38 additions & 34 deletions js/ai/src/testing/model-tester.ts
@@ -170,44 +170,48 @@ export async function testModels(
     }
   );
 
-  return await runInNewSpan({ metadata: { name: 'testModels' } }, async () => {
-    const report: TestReport = [];
-    for (const test of Object.keys(tests)) {
-      await runInNewSpan({ metadata: { name: test } }, async () => {
-        report.push({
-          description: test,
-          models: [],
-        });
-        const caseReport = report[report.length - 1];
-        for (const model of models) {
-          caseReport.models.push({
-            name: model,
-            passed: true, // optimistically
-          });
-          const modelReport = caseReport.models[caseReport.models.length - 1];
-          try {
-            await tests[test](registry, model);
-          } catch (e) {
-            modelReport.passed = false;
-            if (e instanceof SkipTestError) {
-              modelReport.skipped = true;
-            } else if (e instanceof Error) {
-              modelReport.error = {
-                message: e.message,
-                stack: e.stack,
-              };
-            } else {
-              modelReport.error = {
-                message: `${e}`,
-              };
-            }
-          }
-        }
-      });
-    }
-
-    return report;
-  });
+  return await runInNewSpan(
+    registry,
+    { metadata: { name: 'testModels' } },
+    async () => {
+      const report: TestReport = [];
+      for (const test of Object.keys(tests)) {
+        await runInNewSpan(registry, { metadata: { name: test } }, async () => {
+          report.push({
+            description: test,
+            models: [],
+          });
+          const caseReport = report[report.length - 1];
+          for (const model of models) {
+            caseReport.models.push({
+              name: model,
+              passed: true, // optimistically
+            });
+            const modelReport = caseReport.models[caseReport.models.length - 1];
+            try {
+              await tests[test](registry, model);
+            } catch (e) {
+              modelReport.passed = false;
+              if (e instanceof SkipTestError) {
+                modelReport.skipped = true;
+              } else if (e instanceof Error) {
+                modelReport.error = {
+                  message: e.message,
+                  stack: e.stack,
+                };
+              } else {
+                modelReport.error = {
+                  message: `${e}`,
+                };
+              }
+            }
+          }
+        });
+      }
+
+      return report;
+    }
+  );
 }
 
 class SkipTestError extends Error {}
5 changes: 3 additions & 2 deletions js/ai/src/tool.ts
@@ -72,14 +72,15 @@ export type ToolArgument<
  * Converts an action to a tool action by setting the appropriate metadata.
  */
 export function asTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
+  registry: Registry,
   action: Action<I, O>
 ): ToolAction<I, O> {
   if (action.__action?.metadata?.type === 'tool') {
     return action as ToolAction<I, O>;
   }
 
   const fn = ((input) => {
-    setCustomMetadataAttributes({ subtype: 'tool' });
+    setCustomMetadataAttributes(registry, { subtype: 'tool' });
     return action(input);
   }) as ToolAction<I, O>;
   fn.__action = {
@@ -105,7 +106,7 @@
     if (typeof ref === 'string') {
       return await lookupToolByName(registry, ref);
     } else if ((ref as Action).__action) {
-      return asTool(ref as Action);
+      return asTool(registry, ref as Action);
     } else if (typeof (ref as ExecutablePrompt).asTool === 'function') {
       return await (ref as ExecutablePrompt).asTool();
     } else if (ref.name) {