diff --git a/js/ai/src/chat.ts b/js/ai/src/chat.ts index 437e9887e..c565b0045 100644 --- a/js/ai/src/chat.ts +++ b/js/ai/src/chat.ts @@ -133,43 +133,47 @@ export class Chat { >( options: string | Part[] | ChatGenerateOptions ): Promise>> { - return runWithSession(this.session, () => - runInNewSpan({ metadata: { name: 'send' } }, async () => { - let resolvedOptions; - let streamingCallback = undefined; - - // string - if (typeof options === 'string') { - resolvedOptions = { - prompt: options, - } as ChatGenerateOptions; - } else if (Array.isArray(options)) { - // Part[] - resolvedOptions = { - prompt: options, - } as ChatGenerateOptions; - } else { - resolvedOptions = options as ChatGenerateOptions; - streamingCallback = resolvedOptions.streamingCallback; + return runWithSession(this.session.registry, this.session, () => + runInNewSpan( + this.session.registry, + { metadata: { name: 'send' } }, + async () => { + let resolvedOptions; + let streamingCallback = undefined; + + // string + if (typeof options === 'string') { + resolvedOptions = { + prompt: options, + } as ChatGenerateOptions; + } else if (Array.isArray(options)) { + // Part[] + resolvedOptions = { + prompt: options, + } as ChatGenerateOptions; + } else { + resolvedOptions = options as ChatGenerateOptions; + streamingCallback = resolvedOptions.streamingCallback; + } + let request: GenerateOptions = { + ...(await this.requestBase), + messages: this.messages, + ...resolvedOptions, + }; + let response = await generate(this.session.registry, { + ...request, + streamingCallback, + }); + this.requestBase = Promise.resolve({ + ...(await this.requestBase), + // these things may get changed by tools calling within generate. + tools: response?.request?.tools, + config: response?.request?.config, + }); + await this.updateMessages(response.messages); + return response; } - let request: GenerateOptions = { - ...(await this.requestBase), - messages: this.messages, - ...resolvedOptions, - }; - let response = await generate(this.session.registry, { - ...request, - streamingCallback, - }); - this.requestBase = Promise.resolve({ - ...(await this.requestBase), - // these things may get changed by tools calling within generate. - tools: response?.request?.tools, - config: response?.request?.config, - }); - await this.updateMessages(response.messages); - return response; - }) + ) ); } @@ -179,47 +183,54 @@ export class Chat { >( options: string | Part[] | GenerateStreamOptions ): Promise>> { - return runWithSession(this.session, () => - runInNewSpan({ metadata: { name: 'send' } }, async () => { - let resolvedOptions; - - // string - if (typeof options === 'string') { - resolvedOptions = { - prompt: options, - } as GenerateStreamOptions; - } else if (Array.isArray(options)) { - // Part[] - resolvedOptions = { - prompt: options, - } as GenerateStreamOptions; - } else { - resolvedOptions = options as GenerateStreamOptions; - } + return runWithSession(this.session.registry, this.session, () => + runInNewSpan( + this.session.registry, + { metadata: { name: 'send' } }, + async () => { + let resolvedOptions; - const { response, stream } = await generateStream( - this.session.registry, - { - ...(await this.requestBase), - messages: this.messages, - ...resolvedOptions, + // string + if (typeof options === 'string') { + resolvedOptions = { + prompt: options, + } as GenerateStreamOptions; + } else if (Array.isArray(options)) { + // Part[] + resolvedOptions = { + prompt: options, + } as GenerateStreamOptions; + } else { + resolvedOptions = options as GenerateStreamOptions< + O, + CustomOptions + >; } - ); - return { - response: response.finally(async () => { - const resolvedResponse = await response; - this.requestBase = Promise.resolve({ + const { response, stream } = await generateStream( + this.session.registry, + { ...(await this.requestBase), - // these things may get changed by tools calling within generate. - tools: resolvedResponse?.request?.tools, - config: resolvedResponse?.request?.config, - }); - this.updateMessages(resolvedResponse.messages); - }), - stream, - }; - }) + messages: this.messages, + ...resolvedOptions, + } + ); + + return { + response: response.finally(async () => { + const resolvedResponse = await response; + this.requestBase = Promise.resolve({ + ...(await this.requestBase), + // these things may get changed by tools calling within generate. + tools: resolvedResponse?.request?.tools, + config: resolvedResponse?.request?.config, + }); + this.updateMessages(resolvedResponse.messages); + }), + stream, + }; + } + ) ); } diff --git a/js/ai/src/evaluator.ts b/js/ai/src/evaluator.ts index 65b75e3e4..64ef67434 100644 --- a/js/ai/src/evaluator.ts +++ b/js/ai/src/evaluator.ts @@ -175,6 +175,7 @@ export function defineEvaluator< }; try { await runInNewSpan( + registry, { metadata: { name: `Test Case ${datapoint.testCaseId}`, diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index e6d1bb2af..1478405bd 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -278,6 +278,7 @@ export async function generate< }; return await runWithStreamingCallback( + registry, resolvedOptions.streamingCallback, async () => { const response = await generateHelper( diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index 48d6b6e9b..bea20b728 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -83,6 +83,7 @@ export async function generateHelper( ): Promise { // do tracing return await runInNewSpan( + registry, { metadata: { name: 'generate', @@ -139,8 +140,9 @@ async function generate( const accumulatedChunks: GenerateResponseChunkData[] = []; - const streamingCallback = getStreamingCallback(); + const streamingCallback = getStreamingCallback(registry); const response = await runWithStreamingCallback( + registry, streamingCallback ? (chunk: GenerateResponseChunkData) => { // Store accumulated chunk data diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index c0eb0b9c8..a1693b8ec 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -372,7 +372,7 @@ export function defineModel< (input) => { const startTimeMs = performance.now(); - return runner(input, getStreamingCallback()).then((response) => { + return runner(input, getStreamingCallback(registry)).then((response) => { const timedResponse = { ...response, latencyMs: performance.now() - startTimeMs, diff --git a/js/ai/src/session.ts b/js/ai/src/session.ts index 6e84c3039..2d5af2fd9 100644 --- a/js/ai/src/session.ts +++ b/js/ai/src/session.ts @@ -16,7 +16,6 @@ import { z } from '@genkit-ai/core'; import { Registry } from '@genkit-ai/core/registry'; -import { AsyncLocalStorage } from 'node:async_hooks'; import { v4 as uuidv4 } from 'uuid'; import { Chat, ChatOptions, MAIN_THREAD, PromptRenderOptions } from './chat'; import { @@ -192,7 +191,7 @@ export class Session { maybeOptionsOrPreamble?: ChatOptions | ExecutablePrompt, maybeOptions?: ChatOptions ): Chat { - return runWithSession(this, () => { + return runWithSession(this.registry, this, () => { let options: ChatOptions | undefined; let threadName = MAIN_THREAD; let preamble: ExecutablePrompt | undefined; @@ -266,7 +265,7 @@ export class Session { * `ai.currentSession().state` */ run(fn: () => O) { - return runWithSession(this, fn); + return runWithSession(this.registry, this, fn); } toJSON() { @@ -280,21 +279,24 @@ export interface SessionData { threads?: Record; } -const sessionAls = new AsyncLocalStorage>(); +const sessionAlsKey = 'ai.session'; /** * Executes provided function within the provided session state. */ export function runWithSession( + registry: Registry, session: Session, fn: () => O ): O { - return sessionAls.run(session, fn); + return registry.asyncStore.run(sessionAlsKey, session, fn); } /** Returns the current session. */ -export function getCurrentSession(): Session | undefined { - return sessionAls.getStore(); +export function getCurrentSession( + registry: Registry +): Session | undefined { + return registry.asyncStore.getStore(sessionAlsKey); } /** Throw when session state errors occur, ex. missing state, etc. */ diff --git a/js/ai/src/testing/model-tester.ts b/js/ai/src/testing/model-tester.ts index 7caa4b0cc..626a4c3af 100644 --- a/js/ai/src/testing/model-tester.ts +++ b/js/ai/src/testing/model-tester.ts @@ -170,44 +170,48 @@ export async function testModels( } ); - return await runInNewSpan({ metadata: { name: 'testModels' } }, async () => { - const report: TestReport = []; - for (const test of Object.keys(tests)) { - await runInNewSpan({ metadata: { name: test } }, async () => { - report.push({ - description: test, - models: [], - }); - const caseReport = report[report.length - 1]; - for (const model of models) { - caseReport.models.push({ - name: model, - passed: true, // optimistically + return await runInNewSpan( + registry, + { metadata: { name: 'testModels' } }, + async () => { + const report: TestReport = []; + for (const test of Object.keys(tests)) { + await runInNewSpan(registry, { metadata: { name: test } }, async () => { + report.push({ + description: test, + models: [], }); - const modelReport = caseReport.models[caseReport.models.length - 1]; - try { - await tests[test](registry, model); - } catch (e) { - modelReport.passed = false; - if (e instanceof SkipTestError) { - modelReport.skipped = true; - } else if (e instanceof Error) { - modelReport.error = { - message: e.message, - stack: e.stack, - }; - } else { - modelReport.error = { - message: `${e}`, - }; + const caseReport = report[report.length - 1]; + for (const model of models) { + caseReport.models.push({ + name: model, + passed: true, // optimistically + }); + const modelReport = caseReport.models[caseReport.models.length - 1]; + try { + await tests[test](registry, model); + } catch (e) { + modelReport.passed = false; + if (e instanceof SkipTestError) { + modelReport.skipped = true; + } else if (e instanceof Error) { + modelReport.error = { + message: e.message, + stack: e.stack, + }; + } else { + modelReport.error = { + message: `${e}`, + }; + } } } - } - }); - } + }); + } - return report; - }); + return report; + } + ); } class SkipTestError extends Error {} diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index a56f9ad59..33a9d4645 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -72,6 +72,7 @@ export type ToolArgument< * Converts an action to a tool action by setting the appropriate metadata. */ export function asTool( + registry: Registry, action: Action ): ToolAction { if (action.__action?.metadata?.type === 'tool') { @@ -79,7 +80,7 @@ export function asTool( } const fn = ((input) => { - setCustomMetadataAttributes({ subtype: 'tool' }); + setCustomMetadataAttributes(registry, { subtype: 'tool' }); return action(input); }) as ToolAction; fn.__action = { @@ -105,7 +106,7 @@ export async function resolveTools< if (typeof ref === 'string') { return await lookupToolByName(registry, ref); } else if ((ref as Action).__action) { - return asTool(ref as Action); + return asTool(registry, ref as Action); } else if (typeof (ref as ExecutablePrompt).asTool === 'function') { return await (ref as ExecutablePrompt).asTool(); } else if (ref.name) { diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 3428370e3..19c28784e 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -15,7 +15,6 @@ */ import { JSONSchema7 } from 'json-schema'; -import { AsyncLocalStorage } from 'node:async_hooks'; import * as z from 'zod'; import { ActionType, Registry } from './registry.js'; import { parseSchema } from './schema.js'; @@ -105,6 +104,7 @@ export type Action< options?: ActionRunOptions ) => Promise>) & { __action: ActionMetadata; + __registry: Registry; run( input: z.infer, options?: ActionRunOptions> @@ -114,7 +114,7 @@ export type Action< /** * Action factory params. */ -type ActionParams< +export type ActionParams< I extends z.ZodTypeAny, O extends z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, @@ -169,6 +169,7 @@ export function actionWithMiddleware< return (await wrapped.run(req)).result; }) as Action; wrapped.__action = action.__action; + wrapped.__registry = action.__registry; wrapped.run = async ( req: z.infer, options?: ActionRunOptions> @@ -217,6 +218,7 @@ export function action< O extends z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, config: ActionParams, fn: ( input: z.infer, @@ -230,6 +232,7 @@ export function action< const actionFn = async (input: I, options?: ActionRunOptions>) => { return (await actionFn.run(input, options)).result; }; + actionFn.__registry = registry; actionFn.__action = { name: actionName, description: config.description, @@ -250,6 +253,7 @@ export function action< let traceId; let spanId; let output = await newTrace( + registry, { name: actionName, labels: { @@ -259,9 +263,9 @@ export function action< }, }, async (metadata, span) => { - setCustomMetadataAttributes({ subtype: config.actionType }); + setCustomMetadataAttributes(registry, { subtype: config.actionType }); if (options?.context) { - setCustomMetadataAttributes({ + setCustomMetadataAttributes(registry, { context: JSON.stringify(options.context), }); } @@ -345,7 +349,7 @@ export function defineAction< options: ActionFnArg> ) => Promise> ): Action { - if (isInRuntimeContext()) { + if (isInRuntimeContext(registry)) { throw new Error( 'Cannot define new actions at runtime.\n' + 'See: https://github.com/firebase/genkit/blob/main/docs/errors/no_new_actions_at_runtime.md' @@ -356,10 +360,14 @@ export function defineAction< } else { validateActionId(config.name.actionId); } - const act = action(config, async (i: I, options): Promise> => { - await registry.initializeAllPlugins(); - return await runInActionRuntimeContext(() => fn(i, options)); - }); + const act = action( + registry, + config, + async (i: I, options): Promise> => { + await registry.initializeAllPlugins(); + return await runInActionRuntimeContext(registry, () => fn(i, options)); + } + ); act.__action.actionType = config.actionType; registry.registerAction(config.actionType, act); return act; @@ -368,7 +376,7 @@ export function defineAction< // Streaming callback function. export type StreamingCallback = (chunk: T) => void; -const streamingAls = new AsyncLocalStorage>(); +const streamingAlsKey = 'core.action.streamingCallback'; const sentinelNoopCallback = () => null; /** @@ -376,35 +384,43 @@ const sentinelNoopCallback = () => null; * using {@link getStreamingCallback}. */ export function runWithStreamingCallback( + registry: Registry, streamingCallback: StreamingCallback | undefined, fn: () => O ): O { - return streamingAls.run(streamingCallback || sentinelNoopCallback, fn); + return registry.asyncStore.run( + streamingAlsKey, + streamingCallback || sentinelNoopCallback, + fn + ); } /** * Retrieves the {@link StreamingCallback} previously set by {@link runWithStreamingCallback} */ -export function getStreamingCallback(): StreamingCallback | undefined { - const cb = streamingAls.getStore(); +export function getStreamingCallback( + registry: Registry +): StreamingCallback | undefined { + const cb = + registry.asyncStore.getStore>(streamingAlsKey); if (cb === sentinelNoopCallback) { return undefined; } return cb; } -const runtimeCtxAls = new AsyncLocalStorage(); +const runtimeContextAslKey = 'core.action.runtimeContext'; /** * Checks whether the caller is currently in the runtime context of an action. */ -export function isInRuntimeContext() { - return !!runtimeCtxAls.getStore(); +export function isInRuntimeContext(registry: Registry) { + return !!registry.asyncStore.getStore(runtimeContextAslKey); } /** * Execute the provided function in the action runtime context. */ -export function runInActionRuntimeContext(fn: () => R) { - return runtimeCtxAls.run('runtime', fn); +export function runInActionRuntimeContext(registry: Registry, fn: () => R) { + return registry.asyncStore.run(runtimeContextAslKey, 'runtime', fn); } diff --git a/js/core/src/auth.ts b/js/core/src/auth.ts index 753be5153..769865593 100644 --- a/js/core/src/auth.ts +++ b/js/core/src/auth.ts @@ -16,16 +16,24 @@ import { AsyncLocalStorage } from 'node:async_hooks'; import { runInActionRuntimeContext } from './action.js'; +import { HasRegistry, Registry } from './registry.js'; -const contextAsyncLocalStorage = new AsyncLocalStorage(); +const contextAlsKey = 'core.auth.context'; +const legacyContextAsyncLocalStorage = new AsyncLocalStorage(); /** * Execute the provided function in the runtime context. Call {@link getFlowContext()} anywhere * within the async call stack to retrieve the context. */ -export function runWithContext(context: any, fn: () => R) { - return contextAsyncLocalStorage.run(context, () => - runInActionRuntimeContext(fn) +export function runWithContext( + registry: Registry, + context: any, + fn: () => R +) { + return legacyContextAsyncLocalStorage.run(context, () => + registry.asyncStore.run(contextAlsKey, context, () => + runInActionRuntimeContext(registry, fn) + ) ); } @@ -34,13 +42,20 @@ export function runWithContext(context: any, fn: () => R) { * * @deprecated use {@link getFlowContext} */ -export function getFlowAuth(): any { - return contextAsyncLocalStorage.getStore(); +export function getFlowAuth(registry?: Registry | HasRegistry): any { + return getFlowContext(registry); } /** * Gets the runtime context of the current flow. */ -export function getFlowContext(): any { - return contextAsyncLocalStorage.getStore(); +export function getFlowContext(registry?: Registry | HasRegistry): any { + if (!registry) { + return legacyContextAsyncLocalStorage.getStore(); + } + if ((registry as HasRegistry).registry) { + registry = (registry as HasRegistry).registry; + } + registry = registry as Registry; + return registry.asyncStore.getStore(contextAlsKey); } diff --git a/js/core/src/flow.ts b/js/core/src/flow.ts index 9fac492c6..7052180f0 100644 --- a/js/core/src/flow.ts +++ b/js/core/src/flow.ts @@ -18,6 +18,7 @@ import * as bodyParser from 'body-parser'; import cors, { CorsOptions } from 'cors'; import express from 'express'; import { Server } from 'http'; +import { AsyncLocalStorage } from 'node:async_hooks'; import { z } from 'zod'; import { Action, @@ -28,7 +29,7 @@ import { import { runWithContext } from './auth.js'; import { getErrorMessage, getErrorStack } from './error.js'; import { logger } from './logging.js'; -import { Registry } from './registry.js'; +import { HasRegistry, Registry } from './registry.js'; import { runInNewSpan, SPAN_TYPE_ATTR } from './tracing.js'; const streamDelimiter = '\n\n'; @@ -163,7 +164,7 @@ export class Flow< readonly action: Action; constructor( - private registry: Registry, + readonly registry: Registry, config: FlowConfig | StreamingFlowConfig, action: Action ) { @@ -549,16 +550,25 @@ function defineFlowAction< }, async (input, { sendChunk, context }) => { await config.authPolicy?.(context, input); - return await runWithContext(context, () => fn(input, sendChunk)); + return await legacyRegistryAls.run(registry, () => + runWithContext(registry, context, () => fn(input, sendChunk)) + ); } ); } -export function run(name: string, func: () => Promise): Promise; +const legacyRegistryAls = new AsyncLocalStorage(); + +export function run( + name: string, + func: () => Promise, + registry?: Registry +): Promise; export function run( name: string, input: any, - func: (input?: any) => Promise + func: (input?: any) => Promise, + registry?: Registry ): Promise; /** @@ -567,14 +577,47 @@ export function run( export function run( name: string, funcOrInput: () => Promise, - fn?: (input?: any) => Promise + fnOrRegistry?: Registry | HasRegistry | ((input?: any) => Promise), + maybeRegistry?: Registry | HasRegistry ): Promise { - const func = arguments.length === 3 ? fn : funcOrInput; - const input = arguments.length === 3 ? funcOrInput : undefined; + let func; + let input; + let registry: Registry | undefined; + if (typeof funcOrInput === 'function') { + func = funcOrInput; + } else { + input = funcOrInput; + } + if (typeof fnOrRegistry === 'function') { + func = fnOrRegistry; + } else if ( + fnOrRegistry instanceof Registry || + (fnOrRegistry as HasRegistry)?.registry + ) { + registry = (fnOrRegistry as HasRegistry)?.registry + ? (fnOrRegistry as HasRegistry)?.registry + : (fnOrRegistry as Registry); + } + if (maybeRegistry) { + registry = (maybeRegistry as HasRegistry).registry + ? (maybeRegistry as HasRegistry).registry + : (maybeRegistry as Registry); + } + + if (!registry) { + registry = legacyRegistryAls.getStore(); + } + if (!registry) { + throw new Error( + 'Unable to resolve registry. Consider explicitly passing Genkit instance.' + ); + } + if (!func) { throw new Error('unable to resolve run function'); } return runInNewSpan( + registry, { metadata: { name }, labels: { diff --git a/js/core/src/plugin.ts b/js/core/src/plugin.ts index 34276c1d9..ae5bfe340 100644 --- a/js/core/src/plugin.ts +++ b/js/core/src/plugin.ts @@ -15,7 +15,7 @@ */ import { z } from 'zod'; -import { Action, isInRuntimeContext } from './action.js'; +import { Action } from './action.js'; export interface Provider { id: string; @@ -57,12 +57,6 @@ export function genkitPlugin( pluginName: string, initFn: T ): Plugin> { - if (isInRuntimeContext()) { - throw new Error( - 'Cannot define new plugins at runtime.\n' + - 'See: https://github.com/firebase/genkit/blob/main/docs/errors/no_new_actions_at_runtime.md' - ); - } return (...args: Parameters) => ({ name: pluginName, initializer: async () => { diff --git a/js/core/src/reflection.ts b/js/core/src/reflection.ts index 3201e15e5..9ef844503 100644 --- a/js/core/src/reflection.ts +++ b/js/core/src/reflection.ts @@ -167,8 +167,10 @@ export class ReflectionServer { const callback = (chunk) => { response.write(JSON.stringify(chunk) + '\n'); }; - const result = await runWithStreamingCallback(callback, () => - action.run(input, { context, onChunk: callback }) + const result = await runWithStreamingCallback( + this.registry, + callback, + () => action.run(input, { context, onChunk: callback }) ); await flushTracing(); response.write( diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index ed8414887..2127453f7 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -14,6 +14,7 @@ * limitations under the License. */ +import { AsyncLocalStorage } from 'node:async_hooks'; import * as z from 'zod'; import { Action } from './action.js'; import { logger } from './logging.js'; @@ -66,6 +67,8 @@ export class Registry { private valueByTypeAndName: Record> = {}; private allPluginsInitialized = false; + readonly asyncStore = new AsyncStore(); + constructor(public parent?: Registry) {} /** @@ -234,3 +237,28 @@ export class Registry { return this.schemasByName[name] || this.parent?.lookupSchema(name); } } + +/** + * Manages AsyncLocalStorage instances in a single place. + */ +export class AsyncStore { + private asls: Record> = {}; + + getStore(key: string): T | undefined { + return this.asls[key]?.getStore(); + } + + run(key: string, store: T, callback: () => R): R { + if (!this.asls[key]) { + this.asls[key] = new AsyncLocalStorage(); + } + return this.asls[key].run(store, callback); + } +} + +/** + * An object that has a reference to Genkit Registry. + */ +export interface HasRegistry { + get registry(): Registry; +} diff --git a/js/core/src/tracing/instrumentation.ts b/js/core/src/tracing/instrumentation.ts index 0dd0173cb..397460c40 100644 --- a/js/core/src/tracing/instrumentation.ts +++ b/js/core/src/tracing/instrumentation.ts @@ -21,13 +21,13 @@ import { SpanStatusCode, trace, } from '@opentelemetry/api'; -import { AsyncLocalStorage } from 'node:async_hooks'; import { performance } from 'node:perf_hooks'; +import { HasRegistry, Registry } from '../registry.js'; import { ensureBasicTelemetryInstrumentation } from '../tracing.js'; import { PathMetadata, SpanMetadata, TraceMetadata } from './types.js'; -export const spanMetadataAls = new AsyncLocalStorage(); -export const traceMetadataAls = new AsyncLocalStorage(); +export const spanMetadataAlsKey = 'core.tracing.instrumentation.span'; +export const traceMetadataAlsKey = 'core.tracing.instrumentation.trace'; export const ATTR_PREFIX = 'genkit'; export const SPAN_TYPE_ATTR = ATTR_PREFIX + ':type'; @@ -38,6 +38,7 @@ const TRACER_VERSION = 'v1'; * */ export async function newTrace( + registry: Registry | HasRegistry, opts: { name: string; labels?: Record; @@ -45,14 +46,21 @@ export async function newTrace( }, fn: (metadata: SpanMetadata, rootSpan: ApiSpan) => Promise ) { + registry = (registry as HasRegistry).registry + ? (registry as HasRegistry).registry + : (registry as Registry); + await ensureBasicTelemetryInstrumentation(); - const traceMetadata: TraceMetadata = traceMetadataAls.getStore() || { + const traceMetadata: TraceMetadata = registry.asyncStore.getStore( + traceMetadataAlsKey + ) || { paths: new Set(), timestamp: performance.now(), featureName: opts.name, }; - return await traceMetadataAls.run(traceMetadata, () => + return await registry.asyncStore.run(traceMetadataAlsKey, traceMetadata, () => runInNewSpan( + registry, { metadata: { name: opts.name, @@ -68,9 +76,10 @@ export async function newTrace( } /** - * + * Runs the provided function in a new span. */ export async function runInNewSpan( + registry: Registry | HasRegistry, opts: { metadata: SpanMetadata; labels?: Record; @@ -79,9 +88,13 @@ export async function runInNewSpan( fn: (metadata: SpanMetadata, otSpan: ApiSpan, isRoot: boolean) => Promise ): Promise { await ensureBasicTelemetryInstrumentation(); + const resolvedRegistry = (registry as HasRegistry).registry + ? (registry as HasRegistry).registry + : (registry as Registry); const tracer = trace.getTracer(TRACER_NAME, TRACER_VERSION); - const parentStep = spanMetadataAls.getStore(); + const parentStep = + resolvedRegistry.asyncStore.getStore(spanMetadataAlsKey); const isInRoot = parentStep?.isRoot === true; if (!parentStep) opts.metadata.isRoot ||= true; return await tracer.startActiveSpan( @@ -96,17 +109,19 @@ export async function runInNewSpan( opts.labels ); - const output = await spanMetadataAls.run(opts.metadata, () => - fn(opts.metadata, otSpan, isInRoot) + const output = await resolvedRegistry.asyncStore.run( + spanMetadataAlsKey, + opts.metadata, + () => fn(opts.metadata, otSpan, isInRoot) ); if (opts.metadata.state !== 'error') { opts.metadata.state = 'success'; } - recordPath(opts.metadata); + recordPath(resolvedRegistry, opts.metadata); return output; } catch (e) { - recordPath(opts.metadata, e); + recordPath(resolvedRegistry, opts.metadata, e); opts.metadata.state = 'error'; otSpan.setStatus({ code: SpanStatusCode.ERROR, @@ -183,8 +198,12 @@ function metadataToAttributes(metadata: SpanMetadata): Record { /** * Sets provided attribute value in the current span. */ -export function setCustomMetadataAttribute(key: string, value: string) { - const currentStep = getCurrentSpan(); +export function setCustomMetadataAttribute( + registry: Registry, + key: string, + value: string +) { + const currentStep = getCurrentSpan(registry); if (!currentStep) { return; } @@ -197,8 +216,11 @@ export function setCustomMetadataAttribute(key: string, value: string) { /** * Sets provided attribute values in the current span. */ -export function setCustomMetadataAttributes(values: Record) { - const currentStep = getCurrentSpan(); +export function setCustomMetadataAttributes( + registry: Registry, + values: Record +) { + const currentStep = getCurrentSpan(registry); if (!currentStep) { return; } @@ -216,8 +238,8 @@ export function toDisplayPath(path: string): string { return Array.from(path.matchAll(pathPartRegex), (m) => m[1]).join(' > '); } -function getCurrentSpan(): SpanMetadata { - const step = spanMetadataAls.getStore(); +function getCurrentSpan(registry: Registry): SpanMetadata { + const step = registry.asyncStore.getStore(spanMetadataAlsKey); if (!step) { throw new Error('running outside step context'); } @@ -236,24 +258,29 @@ function buildPath( return parentPath + `/{${name}${stepType}}`; } -function recordPath(spanMeta: SpanMetadata, err?: any) { +function recordPath(registry: Registry, spanMeta: SpanMetadata, err?: any) { const path = spanMeta.path || ''; const decoratedPath = decoratePathWithSubtype(spanMeta); // Only add the path if a child has not already been added. In the event that // an error is rethrown, we don't want to add each step in the unwind. const paths = Array.from( - traceMetadataAls.getStore()?.paths || new Set() + registry.asyncStore.getStore(traceMetadataAlsKey)?.paths || + new Set() ); const status = err ? 'failure' : 'success'; if (!paths.some((p) => p.path.startsWith(path) && p.status === status)) { const now = performance.now(); - const start = traceMetadataAls.getStore()?.timestamp || now; - traceMetadataAls.getStore()?.paths?.add({ - path: decoratedPath, - error: err?.name, - latency: now - start, - status, - }); + const start = + registry.asyncStore.getStore(traceMetadataAlsKey) + ?.timestamp || now; + registry.asyncStore + .getStore(traceMetadataAlsKey) + ?.paths?.add({ + path: decoratedPath, + error: err?.name, + latency: now - start, + status, + }); } spanMeta.path = decoratedPath; } diff --git a/js/core/tests/action_test.ts b/js/core/tests/action_test.ts index 410cc4f4a..30437639c 100644 --- a/js/core/tests/action_test.ts +++ b/js/core/tests/action_test.ts @@ -15,13 +15,20 @@ */ import assert from 'node:assert'; -import { describe, it } from 'node:test'; +import { beforeEach, describe, it } from 'node:test'; import { z } from 'zod'; import { action } from '../src/action.js'; +import { Registry } from '../src/registry.js'; describe('action', () => { + var registry: Registry; + beforeEach(() => { + registry = new Registry(); + }); + it('applies middleware', async () => { const act = action( + registry, { name: 'foo', inputSchema: z.string(), @@ -31,6 +38,7 @@ describe('action', () => { async (input, opts, next) => (await next(input + 'middle2', opts)) + 2, ], + actionType: 'util', }, async (input) => { return input.length; @@ -45,6 +53,7 @@ describe('action', () => { it('returns telemetry info', async () => { const act = action( + registry, { name: 'foo', inputSchema: z.string(), @@ -54,6 +63,7 @@ describe('action', () => { async (input, opts, next) => (await next(input + 'middle2', opts)) + 2, ], + actionType: 'util', }, async (input) => { return input.length; @@ -79,10 +89,12 @@ describe('action', () => { it('run the action with options', async () => { let passedContext; const act = action( + registry, { name: 'foo', inputSchema: z.string(), outputSchema: z.number(), + actionType: 'util', }, async (input, { sendChunk, context }) => { passedContext = context; diff --git a/js/core/tests/registry_test.ts b/js/core/tests/registry_test.ts index d54fdd415..2b52dde0b 100644 --- a/js/core/tests/registry_test.ts +++ b/js/core/tests/registry_test.ts @@ -28,12 +28,14 @@ describe('registry class', () => { describe('listActions', () => { it('returns all registered actions', async () => { const fooSomethingAction = action( - { name: 'foo_something' }, + registry, + { name: 'foo_something', actionType: 'util' }, async () => null ); registry.registerAction('model', fooSomethingAction); const barSomethingAction = action( - { name: 'bar_something' }, + registry, + { name: 'bar_something', actionType: 'util' }, async () => null ); registry.registerAction('model', barSomethingAction); @@ -53,11 +55,13 @@ describe('registry class', () => { }, }); const fooSomethingAction = action( + registry, { name: { pluginId: 'foo', actionId: 'something', }, + actionType: 'util', }, async () => null ); @@ -69,11 +73,13 @@ describe('registry class', () => { }, }); const barSomethingAction = action( + registry, { name: { pluginId: 'bar', actionId: 'something', }, + actionType: 'util', }, async () => null ); @@ -88,12 +94,14 @@ describe('registry class', () => { const child = Registry.withParent(registry); const fooSomethingAction = action( - { name: 'foo_something' }, + registry, + { name: 'foo_something', actionType: 'util' }, async () => null ); registry.registerAction('model', fooSomethingAction); const barSomethingAction = action( - { name: 'bar_something' }, + registry, + { name: 'bar_something', actionType: 'util' }, async () => null ); child.registerAction('model', barSomethingAction); @@ -140,12 +148,14 @@ describe('registry class', () => { it('returns registered action', async () => { const fooSomethingAction = action( - { name: 'foo_something' }, + registry, + { name: 'foo_something', actionType: 'util' }, async () => null ); registry.registerAction('model', fooSomethingAction); const barSomethingAction = action( - { name: 'bar_something' }, + registry, + { name: 'bar_something', actionType: 'util' }, async () => null ); registry.registerAction('model', barSomethingAction); @@ -169,11 +179,13 @@ describe('registry class', () => { }, }); const somethingAction = action( + registry, { name: { pluginId: 'foo', actionId: 'something', }, + actionType: 'util', }, async () => null ); @@ -194,7 +206,11 @@ describe('registry class', () => { it('should lookup parent registry when child missing action', async () => { const childRegistry = new Registry(registry); - const fooAction = action({ name: 'foo' }, async () => null); + const fooAction = action( + registry, + { name: 'foo', actionType: 'util' }, + async () => null + ); registry.registerAction('model', fooAction); assert.strictEqual(await registry.lookupAction('/model/foo'), fooAction); @@ -209,7 +225,11 @@ describe('registry class', () => { assert.strictEqual(childRegistry.parent, registry); - const fooAction = action({ name: 'foo' }, async () => null); + const fooAction = action( + registry, + { name: 'foo', actionType: 'util' }, + async () => null + ); childRegistry.registerAction('model', fooAction); assert.strictEqual(await registry.lookupAction('/model/foo'), undefined); diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 673ff8fd3..a7c81abc1 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -121,6 +121,7 @@ import { StreamingFlowConfig, z, } from '@genkit-ai/core'; +import { HasRegistry } from '@genkit-ai/core/registry'; import { defineDotprompt, defineHelper, @@ -168,7 +169,7 @@ export type PromptMetadata< * * There may be multiple Genkit instances in a single codebase. */ -export class Genkit { +export class Genkit implements HasRegistry { /** Developer-configured options. */ readonly options: GenkitOptions; /** Environments that have been configured (at minimum dev). */ @@ -1036,7 +1037,7 @@ export class Genkit { * Gets the current session from async local storage. */ currentSession(): Session { - const currentSession = getCurrentSession(); + const currentSession = getCurrentSession(this.registry); if (!currentSession) { throw new SessionError('not running within a session'); } diff --git a/js/genkit/src/tracing.ts b/js/genkit/src/tracing.ts index 252c8df96..766e0d1ef 100644 --- a/js/genkit/src/tracing.ts +++ b/js/genkit/src/tracing.ts @@ -38,9 +38,7 @@ export { setCustomMetadataAttribute, setCustomMetadataAttributes, setTelemetryServerUrl, - spanMetadataAls, toDisplayPath, - traceMetadataAls, type PathMetadata, type SpanData, type SpanMetadata, diff --git a/js/plugins/checks/src/evaluation.ts b/js/plugins/checks/src/evaluation.ts index a3eb9afcc..bd1fd5df9 100644 --- a/js/plugins/checks/src/evaluation.ts +++ b/js/plugins/checks/src/evaluation.ts @@ -124,6 +124,7 @@ function createPolicyEvaluator( }; const response = await checksEvalInstance( + ai, projectId, auth, partialRequest, @@ -149,12 +150,14 @@ function createPolicyEvaluator( } async function checksEvalInstance( + ai: Genkit, projectId: string, auth: GoogleAuth, partialRequest: any, responseSchema: ResponseType ): Promise> { return await runInNewSpan( + ai, { metadata: { name: 'EvaluationService#evaluateInstances', diff --git a/js/plugins/dotprompt/src/prompt.ts b/js/plugins/dotprompt/src/prompt.ts index a78a14e98..5e55da350 100644 --- a/js/plugins/dotprompt/src/prompt.ts +++ b/js/plugins/dotprompt/src/prompt.ts @@ -182,8 +182,8 @@ export class Dotprompt implements PromptMetadata { */ renderMessages(input?: I, options?: RenderMetadata): MessageData[] { let sessionStateData: Record | undefined = undefined; - if (getCurrentSession()) { - sessionStateData = { state: getCurrentSession()?.state }; + if (getCurrentSession(this.registry)) { + sessionStateData = { state: getCurrentSession(this.registry)?.state }; } input = parseSchema(input, { schema: this.input?.schema, @@ -278,6 +278,7 @@ export class Dotprompt implements PromptMetadata { >(opt: PromptGenerateOptions): Promise> { const spanName = this.variant ? `${this.name}.${this.variant}` : this.name; return runInNewSpan( + this.registry, { metadata: { name: spanName, @@ -288,7 +289,11 @@ export class Dotprompt implements PromptMetadata { }, }, async (metadata) => { - setCustomMetadataAttribute('prompt_fingerprint', this.hash); + setCustomMetadataAttribute( + this.registry, + 'prompt_fingerprint', + this.hash + ); const generateOptions = this._generateOptions(opt); metadata.output = generateOptions; return generateOptions; diff --git a/js/plugins/express/src/index.ts b/js/plugins/express/src/index.ts index 9afdf7dca..0d229e8f0 100644 --- a/js/plugins/express/src/index.ts +++ b/js/plugins/express/src/index.ts @@ -23,6 +23,7 @@ import { z, } from 'genkit'; import { logger } from 'genkit/logging'; +import { Registry } from 'genkit/registry'; import { getErrorMessage, getErrorStack } from './utils'; const streamDelimiter = '\n\n'; @@ -82,6 +83,7 @@ export function handler< ? (f as CallableFlow).flow : undefined; const action: Action = flow ? flow.action : (f as Action); + const registry: Registry = flow ? flow.registry : action.__registry; return async ( request: RequestWithAuth, response: express.Response @@ -121,7 +123,7 @@ export function handler< 'data: ' + JSON.stringify({ message: chunk }) + streamDelimiter ); }; - const result = await runWithStreamingCallback(onChunk, () => + const result = await runWithStreamingCallback(registry, onChunk, () => action.run(input, { onChunk, context: auth, diff --git a/js/plugins/vertexai/src/evaluation/evaluator_factory.ts b/js/plugins/vertexai/src/evaluation/evaluator_factory.ts index 821f4631b..4b33144e3 100644 --- a/js/plugins/vertexai/src/evaluation/evaluator_factory.ts +++ b/js/plugins/vertexai/src/evaluation/evaluator_factory.ts @@ -47,6 +47,7 @@ export class EvaluatorFactory { async (datapoint: BaseEvalDataPoint) => { const responseSchema = config.responseSchema; const response = await this.evaluateInstances( + ai, toRequest(datapoint), responseSchema ); @@ -60,11 +61,13 @@ export class EvaluatorFactory { } async evaluateInstances( + ai: Genkit, partialRequest: any, responseSchema: ResponseType ): Promise> { const locationName = `projects/${this.projectId}/locations/${this.location}`; return await runInNewSpan( + ai, { metadata: { name: 'EvaluationService#evaluateInstances', diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index d2c7db228..c1711a784 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -1329,7 +1329,7 @@ importers: version: link:../../plugins/ollama genkitx-openai: specifier: ^0.10.1 - version: 0.10.1(@genkit-ai/ai@0.9.7)(@genkit-ai/core@0.9.7) + version: 0.10.1(@genkit-ai/ai@1.0.0-dev.0)(@genkit-ai/core@1.0.0-dev.0) devDependencies: rimraf: specifier: ^6.0.1 @@ -1384,6 +1384,9 @@ importers: '@genkit-ai/vertexai': specifier: workspace:* version: link:../../plugins/vertexai + '@google-cloud/firestore': + specifier: ^7.11.0 + version: 7.11.0(encoding@0.1.13) firebase-admin: specifier: '>=12.2' version: 12.3.1(encoding@0.1.13) @@ -2143,11 +2146,11 @@ packages: '@firebase/util@1.9.5': resolution: {integrity: sha512-PP4pAFISDxsf70l3pEy34Mf3GkkUcVQ3MdKp6aSVb7tcpfUQxnsdV7twDd8EkfB6zZylH6wpUAoangQDmCUMqw==} - '@genkit-ai/ai@0.9.7': - resolution: {integrity: sha512-CHnN12+J/577EN3uo7qw8tSTa124qwIE+NPHn+TBEtoEkHc70Ck+CxJm0UCfN606hAQyJjj0i6BJ6M4Lr0sMzw==} + '@genkit-ai/ai@1.0.0-dev.0': + resolution: {integrity: sha512-dtQyym12Z/yPw04h7n8n7amU0dpO8iB+uTyDTpSaRe7lkf3SUkNhWM+DDoKUZ70A6cqpAvXN0/haqD9ZpSN+FA==} - '@genkit-ai/core@0.9.7': - resolution: {integrity: sha512-dNKw172HSzgjgRwf8gwyyjYGnopYwfW3iVPHUaCvIXTCX+C/7kJGYea4Xes7b9ushWYOmhG34q/uea7rhLS/Qg==} + '@genkit-ai/core@1.0.0-dev.0': + resolution: {integrity: sha512-mHSAdziskC7YOXkGXCK8299RmGu111dPnxzW2hy8epRpyTfis3SXkB1iJEZE3lxA9tKV9vG2zDxDI0bk4iQbXg==} '@google-cloud/aiplatform@3.25.0': resolution: {integrity: sha512-qKnJgbyCENjed8e1G5zZGFTxxNKhhaKQN414W2KIVHrLxMFmlMuG+3QkXPOWwXBnT5zZ7aMxypt5og0jCirpHg==} @@ -2161,12 +2164,12 @@ packages: resolution: {integrity: sha512-7NBC5vD0au75nkctVs2vEGpdUPFs1BaHTMpeI+RVEgQSMe5/wEU6dx9p0fmZA0bj4HgdpobMKeegOcLUiEoxng==} engines: {node: '>=14.0.0'} - '@google-cloud/firestore@7.6.0': - resolution: {integrity: sha512-WUDbaLY8UnPxgwsyIaxj6uxCtSDAaUyvzWJykNH5rZ9i92/SZCsPNNMN0ajrVpAR81hPIL4amXTaMJ40y5L+Yg==} + '@google-cloud/firestore@7.11.0': + resolution: {integrity: sha512-88uZ+jLsp1aVMj7gh3EKYH1aulTAMFAp8sH/v5a9w8q8iqSG27RiWLoxSAFr/XocZ9hGiWH1kEnBw+zl3xAgNA==} engines: {node: '>=14.0.0'} - '@google-cloud/firestore@7.9.0': - resolution: {integrity: sha512-c4ALHT3G08rV7Zwv8Z2KG63gZh66iKdhCBeDfCpIkLrjX6EAjTD/szMdj14M+FnQuClZLFfW5bAgoOjfNmLtJg==} + '@google-cloud/firestore@7.6.0': + resolution: {integrity: sha512-WUDbaLY8UnPxgwsyIaxj6uxCtSDAaUyvzWJykNH5rZ9i92/SZCsPNNMN0ajrVpAR81hPIL4amXTaMJ40y5L+Yg==} engines: {node: '>=14.0.0'} '@google-cloud/logging-winston@6.0.0': @@ -6927,9 +6930,9 @@ snapshots: dependencies: tslib: 2.6.2 - '@genkit-ai/ai@0.9.7': + '@genkit-ai/ai@1.0.0-dev.0': dependencies: - '@genkit-ai/core': 0.9.7 + '@genkit-ai/core': 1.0.0-dev.0 '@opentelemetry/api': 1.9.0 '@types/node': 20.16.9 colorette: 2.0.20 @@ -6940,7 +6943,7 @@ snapshots: transitivePeerDependencies: - supports-color - '@genkit-ai/core@0.9.7': + '@genkit-ai/core@1.0.0-dev.0': dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/context-async-hooks': 1.25.1(@opentelemetry/api@1.9.0) @@ -7001,26 +7004,26 @@ snapshots: - encoding - supports-color - '@google-cloud/firestore@7.6.0(encoding@0.1.13)': + '@google-cloud/firestore@7.11.0(encoding@0.1.13)': dependencies: + '@opentelemetry/api': 1.9.0 fast-deep-equal: 3.1.3 functional-red-black-tree: 1.0.1 - google-gax: 4.3.2(encoding@0.1.13) - protobufjs: 7.2.6 + google-gax: 4.4.1(encoding@0.1.13) + protobufjs: 7.3.2 transitivePeerDependencies: - encoding - supports-color - '@google-cloud/firestore@7.9.0(encoding@0.1.13)': + '@google-cloud/firestore@7.6.0(encoding@0.1.13)': dependencies: fast-deep-equal: 3.1.3 functional-red-black-tree: 1.0.1 - google-gax: 4.3.7(encoding@0.1.13) - protobufjs: 7.3.2 + google-gax: 4.3.2(encoding@0.1.13) + protobufjs: 7.2.6 transitivePeerDependencies: - encoding - supports-color - optional: true '@google-cloud/logging-winston@6.0.0(encoding@0.1.13)(winston@3.13.0)': dependencies: @@ -7154,7 +7157,7 @@ snapshots: dependencies: lodash.camelcase: 4.3.0 long: 5.2.3 - protobufjs: 7.2.6 + protobufjs: 7.3.2 yargs: 17.7.2 '@grpc/proto-loader@0.7.13': @@ -9337,7 +9340,7 @@ snapshots: node-forge: 1.3.1 uuid: 10.0.0 optionalDependencies: - '@google-cloud/firestore': 7.9.0(encoding@0.1.13) + '@google-cloud/firestore': 7.11.0(encoding@0.1.13) '@google-cloud/storage': 7.10.1(encoding@0.1.13) transitivePeerDependencies: - encoding @@ -9459,10 +9462,10 @@ snapshots: - encoding - supports-color - genkitx-openai@0.10.1(@genkit-ai/ai@0.9.7)(@genkit-ai/core@0.9.7): + genkitx-openai@0.10.1(@genkit-ai/ai@1.0.0-dev.0)(@genkit-ai/core@1.0.0-dev.0): dependencies: - '@genkit-ai/ai': 0.9.7 - '@genkit-ai/core': 0.9.7 + '@genkit-ai/ai': 1.0.0-dev.0 + '@genkit-ai/core': 1.0.0-dev.0 openai: 4.53.0(encoding@0.1.13) zod: 3.23.8 transitivePeerDependencies: @@ -11130,7 +11133,7 @@ snapshots: proto3-json-serializer@2.0.1: dependencies: - protobufjs: 7.2.6 + protobufjs: 7.3.2 proto3-json-serializer@2.0.2: dependencies: diff --git a/js/testapps/rag/package.json b/js/testapps/rag/package.json index 3b2bdf5a5..6a610fc0c 100644 --- a/js/testapps/rag/package.json +++ b/js/testapps/rag/package.json @@ -21,9 +21,10 @@ "@genkit-ai/firebase": "workspace:*", "@genkit-ai/googleai": "workspace:*", "@genkit-ai/vertexai": "workspace:*", + "@google-cloud/firestore": "^7.11.0", + "firebase-admin": ">=12.2", "genkit": "workspace:*", "genkitx-chromadb": "workspace:*", - "firebase-admin": ">=12.2", "genkitx-pinecone": "workspace:*", "google-auth-library": "^9.6.3", "llm-chunk": "^0.0.1",