diff --git a/plugins/bedrock/index.test.ts b/plugins/bedrock/bedrock.test.ts similarity index 58% rename from plugins/bedrock/index.test.ts rename to plugins/bedrock/bedrock.test.ts index 0c95eb5d1..17922dc36 100644 --- a/plugins/bedrock/index.test.ts +++ b/plugins/bedrock/bedrock.test.ts @@ -1,7 +1,7 @@ import { PluginContext, PluginParameters } from '../types'; -import { BedrockParameters, pluginHandler } from './index'; -import { bedrockPIIHandler } from './redactPii'; +import { pluginHandler } from './index'; import creds from './.creds.json'; +import { BedrockParameters } from './type'; /** * @example Parameters object @@ -90,7 +90,7 @@ describe('Credentials check', () => { expect(result).toBeDefined(); expect(result.verdict).toBe(false); expect(result.error).toBe(null); - expect(result.data.customWords).toHaveLength(1); + expect(result.data).toBeDefined(); }); test('Should be working with content_filter', async () => { @@ -113,77 +113,77 @@ describe('Credentials check', () => { expect(result).toBeDefined(); expect(result.verdict).toBe(false); expect(result.error).toBe(null); - expect(result.data.filters).toHaveLength(1); + expect(result.data).toBeDefined(); }); - test('Should work fine with redaction for sensitive info', async () => { - const context = { - response: { - json: { - choices: [ - { - message: { - content: - 'Hello, John doe. How are you doing?. I see your email is john@doe.com', - }, - }, - ], - }, - }, - requestType: 'chatComplete', - }; - - const parameters: PluginParameters = { - ...creds, - }; - - const result = await bedrockPIIHandler( - context as unknown as PluginContext, - parameters, - 'afterRequestHook', - { env: {} } - ); - - const outputMessage = - result.transformedData?.response.json.choices[0].message.content; - expect(result).toBeDefined(); - expect(result.verdict).toBe(true); - expect(outputMessage).toEqual( - 'Hello, {NAME}. How are you doing?. I see your email is {EMAIL}\n' - ); - }); - - test('Should work fine with regex redaction for sensitive info', async () => { - const context = { - response: { - json: { - choices: [ - { - message: { - content: 'bedrock-12121, bedrock-12121', - }, - }, - ], - }, - }, - requestType: 'chatComplete', - }; - - const parameters: PluginParameters = { - ...creds, - }; - - const result = await bedrockPIIHandler( - context as unknown as PluginContext, - parameters, - 'afterRequestHook', - { env: {} } - ); - - const outputMessage = - result.transformedData?.response.json.choices[0].message.content; - expect(result).toBeDefined(); - expect(result.verdict).toBe(true); - expect(outputMessage).toBe('{bedrock-id}, {bedrock-id}\n'); - }); + // test('Should work fine with redaction for sensitive info', async () => { + // const context = { + // response: { + // json: { + // choices: [ + // { + // message: { + // content: + // 'Hello, John doe. How are you doing?. I see your email is john@doe.com', + // }, + // }, + // ], + // }, + // }, + // requestType: 'chatComplete', + // }; + + // const parameters: PluginParameters = { + // ...creds, + // }; + + // const result = await bedrockPIIHandler( + // context as unknown as PluginContext, + // parameters, + // 'afterRequestHook', + // { env: {} } + // ); + + // const outputMessage = + // result.transformedData?.response.json.choices[0].message.content; + // expect(result).toBeDefined(); + // expect(result.verdict).toBe(true); + // expect(outputMessage).toEqual( + // 'Hello, {NAME}. How are you doing?. I see your email is {EMAIL}\n' + // ); + // }); + + // test('Should work fine with regex redaction for sensitive info', async () => { + // const context = { + // response: { + // json: { + // choices: [ + // { + // message: { + // content: 'bedrock-12121, bedrock-12121', + // }, + // }, + // ], + // }, + // }, + // requestType: 'chatComplete', + // }; + + // const parameters: PluginParameters = { + // ...creds, + // }; + + // const result = await bedrockPIIHandler( + // context as unknown as PluginContext, + // parameters, + // 'afterRequestHook', + // { env: {} } + // ); + + // const outputMessage = + // result.transformedData?.response.json.choices[0].message.content; + // expect(result).toBeDefined(); + // expect(result.verdict).toBe(true); + // expect(outputMessage).toBe('{bedrock-id}, {bedrock-id}\n'); + // }); }); diff --git a/plugins/bedrock/index.ts b/plugins/bedrock/index.ts index 25395ae07..f32b778e9 100644 --- a/plugins/bedrock/index.ts +++ b/plugins/bedrock/index.ts @@ -1,112 +1,15 @@ -import { PluginHandler } from '../types'; -import { getText, HttpError, post } from '../utils'; -import { generateAWSHeaders } from './util'; +import { HookEventType, PluginContext, PluginHandler } from '../types'; +import { + getCurrentContentPart, + getText, + HttpError, + setCurrentContentPart, +} from '../utils'; +import { BedrockBody, BedrockParameters } from './type'; +import { bedrockPost, redactPii } from './util'; const REQUIRED_CREDENTIAL_KEYS = ['accessKeyId', 'accessKeySecret', 'region']; -type BedrockFunction = 'contentFilter' | 'pii' | 'wordFilter'; -export type BedrockBody = { - source: 'INPUT' | 'OUTPUT'; - content: { text: { text: string } }[]; -}; -type PIIType = - | 'ADDRESS' - | 'AGE' - | 'AWS_ACCESS_KEY' - | 'AWS_SECRET_KEY' - | 'CA_HEALTH_NUMBER' - | 'CA_SOCIAL_INSURANCE_NUMBER' - | 'CREDIT_DEBIT_CARD_CVV' - | 'CREDIT_DEBIT_CARD_EXPIRY' - | 'CREDIT_DEBIT_CARD_NUMBER' - | 'DRIVER_ID' - | 'EMAIL' - | 'INTERNATIONAL_BANK_ACCOUNT_NUMBER' - | 'IP_ADDRESS' - | 'LICENSE_PLATE' - | 'MAC_ADDRESS' - | 'NAME' - | 'PASSWORD' - | 'PHONE' - | 'PIN' - | 'SWIFT_CODE' - | 'UK_NATIONAL_HEALTH_SERVICE_NUMBER' - | 'UK_NATIONAL_INSURANCE_NUMBER' - | 'UK_UNIQUE_TAXPAYER_REFERENCE_NUMBER' - | 'URL' - | 'USERNAME' - | 'US_BANK_ACCOUNT_NUMBER' - | 'US_BANK_ROUTING_NUMBER' - | 'US_INDIVIDUAL_TAX_IDENTIFICATION_NUMBER' - | 'US_PASSPORT_NUMBER' - | 'US_SOCIAL_SECURITY_NUMBER' - | 'VEHICLE_IDENTIFICATION_NUMBER'; - -interface BedrockAction { - action: 'BLOCKED' | T; -} - -interface ContentPolicy extends BedrockAction { - confidence: 'LOW' | 'NONE' | 'MEDIUM' | 'HIGH'; - type: - | 'INSULTS' - | 'HATE' - | 'SEXUAL' - | 'VIOLENCE' - | 'MISCONDUCT' - | 'PROMPT_ATTACK'; - filterStrength: 'LOW' | 'MEDIUM' | 'HIGH'; -} - -interface WordPolicy extends BedrockAction { - match: string; -} - -interface PIIFilter extends BedrockAction<'ANONYMIZED'> { - match: string; - type: PIIType; -} - -interface BedrockResponse { - action: 'NONE' | 'GUARDRAIL_INTERVENED'; - assessments: { - wordPolicy: { - customWords: WordPolicy[]; - managedWordLists: (WordPolicy & { type: 'PROFANITY' })[]; - }; - contentPolicy: { filters: ContentPolicy[] }; - sensitiveInformationPolicy: { - piiEntities: PIIFilter[]; - regexes: (Omit & { name: string; regex: string })[]; - }; - }[]; - output: { - text: string; - }[]; - usage: { - contentPolicyUnits: number; - sensitiveInformationPolicyUnits: number; - wordPolicyUnits: number; - }; -} - -export interface BedrockParameters { - credentials: { - accessKeyId: string; - accessKeySecret: string; - awsSessionToken?: string; - region: string; - }; - guardrailVersion: string; - guardrailId: string; -} - -enum ResponseKey { - pii = 'sensitiveInformationPolicy', - contentFilter = 'contentPolicy', - wordFilter = 'wordPolicy', -} - export const validateCreds = ( credentials?: BedrockParameters['credentials'] ) => { @@ -115,102 +18,107 @@ export const validateCreds = ( ); }; -export const bedrockPost = async ( - credentials: Record, - body: BedrockBody +const transformedData = { + request: { + json: null, + }, + response: { + json: null, + }, +}; + +const handleRedaction = async ( + context: PluginContext, + hookType: HookEventType, + credentials: Record ) => { - const url = `https://bedrock-runtime.${credentials?.region}.amazonaws.com/guardrail/${credentials?.guardrailId}/version/${credentials?.guardrailVersion}/apply`; + const { content, textArray } = getCurrentContentPart(context, hookType); - const headers = await generateAWSHeaders( - body, - { - 'Content-Type': 'application/json', - }, - url, - 'POST', - 'bedrock', - credentials?.region ?? 'us-east-1', - credentials?.accessKeyId!, - credentials?.accessKeySecret!, - credentials?.awsSessionToken || '' - ); + if (!content) { + return []; + } + const redactPromises = textArray.map(async (text) => { + const result = await redactPii(text, hookType, credentials); - return await post(url, body, { - headers, - method: 'POST', + if (result) { + setCurrentContentPart(context, hookType, transformedData, result); + } }); + + await Promise.all(redactPromises); }; -export const pluginHandler: PluginHandler = - async function ( - this: { fn: BedrockFunction }, - context, - parameters, - eventType - ) { - const credentials = parameters.credentials; - - const validate = validateCreds(credentials); - - const guardrailVersion = parameters.guardrailVersion; - const guardrailId = parameters.guardrailId; - - let verdict = true; - let error = null; - let data = null; - - if (!validate || !guardrailVersion || !guardrailId) { - return { - verdict, - error: 'Missing required credentials', - data, - }; - } +export const pluginHandler: PluginHandler< + BedrockParameters['credentials'] +> = async (context, parameters, eventType) => { + const credentials = parameters.credentials; - const body = {} as BedrockBody; + const validate = validateCreds(credentials); - if (eventType === 'beforeRequestHook') { - body.source = 'INPUT'; - } else { - body.source = 'OUTPUT'; - } + const guardrailVersion = parameters.guardrailVersion; + const guardrailId = parameters.guardrailId; + const pii = parameters?.piiCheck as boolean; - body.content = [ - { - text: { - text: getText(context, eventType), - }, - }, - ]; - - try { - const response = await bedrockPost( - { ...(credentials as any), guardrailId, guardrailVersion }, - body - ); - if (response.action === 'GUARDRAIL_INTERVENED') { - data = response.assessments[0]?.[ResponseKey[this.fn]]; - if (this.fn === 'pii' && !!data) { - verdict = false; - } - if (this.fn === 'contentFilter' && !!data) { - verdict = false; - } - - if (this.fn === 'wordFilter' && !!data) { - verdict = false; - } - } - } catch (e) { - if (e instanceof HttpError) { - error = e.response.body; - } else { - error = (e as Error).message; - } - } + let verdict = true; + let error = null; + let data = null; + if (!validate || !guardrailVersion || !guardrailId) { return { verdict, - error, + error: 'Missing required credentials', data, }; + } + + if (pii) { + await handleRedaction(context, eventType, { + ...credentials, + guardrailId, + guardrailVersion, + }); + + return { error, data, verdict: true, transformedData }; + } + + const body = {} as BedrockBody; + + if (eventType === 'beforeRequestHook') { + body.source = 'INPUT'; + } else { + body.source = 'OUTPUT'; + } + + body.content = [ + { + text: { + text: getText(context, eventType), + }, + }, + ]; + + try { + const response = await bedrockPost( + { ...(credentials as any), guardrailId, guardrailVersion }, + body + ); + if (response.action === 'GUARDRAIL_INTERVENED') { + verdict = false; + // Send assessments + data = response.assessments[0] as any; + + delete data['invocationMetrics']; + delete data['usage']; + } + } catch (e) { + if (e instanceof HttpError) { + error = e.response.body; + } else { + error = (e as Error).message; + } + } + return { + verdict, + error, + data, }; +}; diff --git a/plugins/bedrock/manifest.json b/plugins/bedrock/manifest.json index cbc6d698a..bd589cd9d 100644 --- a/plugins/bedrock/manifest.json +++ b/plugins/bedrock/manifest.json @@ -32,8 +32,8 @@ "functions": [ { - "name": "Content Filter", - "id": "contentFilter", + "name": "Apply Bedrock guardrail", + "id": "guard", "type": "guardrail", "supportedHooks": ["beforeRequestHook", "afterRequestHook"], "description": [ @@ -54,34 +54,15 @@ "type": "string", "label": "Guardrail ID", "description": "ID of the guardrail." + }, + "piiCheck": { + "type": "boolean", + "label": "PII Guard", + "description": "Enable Personally Identifiable Information(PII) check" } }, "required": ["guardrailVersion", "guardrailId"] } - }, - { - "id": "pii", - "name": "Sensitive Content", - "type": "guardrail", - "supportedHooks": ["beforeRequestHook", "afterRequestHook"], - "description": [ - { - "type": "subHeading", - "text": "Checks if the content contains any Personally Identifiable Information (PII)." - } - ] - }, - { - "id": "wordFilter", - "name": "Word Filter", - "type": "guardrail", - "supportedHooks": ["beforeRequestHook", "afterRequestHook"], - "description": [ - { - "type": "subHeading", - "text": "Filters out words that are not allowed in the content." - } - ] } ] } diff --git a/plugins/bedrock/redactPii.ts b/plugins/bedrock/redactPii.ts deleted file mode 100644 index 68ac6b9f4..000000000 --- a/plugins/bedrock/redactPii.ts +++ /dev/null @@ -1,121 +0,0 @@ -import { BedrockBody, BedrockParameters, bedrockPost, validateCreds } from '.'; -import { HookEventType, PluginHandler } from '../types'; -import { getCurrentContentPart, setCurrentContentPart } from '../utils'; - -const redactPii = async ( - text: string, - eventType: HookEventType, - credentials: BedrockParameters -) => { - const body = {} as BedrockBody; - - if (eventType === 'beforeRequestHook') { - body.source = 'INPUT'; - } else { - body.source = 'OUTPUT'; - } - - body.content = [ - { - text: { - text, - }, - }, - ]; - - try { - const response = await bedrockPost({ ...(credentials as any) }, body); - let maskedText = text; - const data = response.output?.[0]; - - maskedText = data?.text; - const isRegexMatch = - response.assessments[0].sensitiveInformationPolicy?.regexes?.length > 0; - if (isRegexMatch) { - response.assessments[0].sensitiveInformationPolicy.regexes.forEach( - (regex) => { - maskedText = maskedText.replaceAll(regex.match, `{${regex.name}}`); - } - ); - } - return maskedText; - } catch (e) { - return null; - } -}; - -export const bedrockPIIHandler: PluginHandler< - BedrockParameters['credentials'] -> = async function (context, parameters, eventType) { - let transformedData: Record = { - request: { - json: null, - }, - response: { - json: null, - }, - }; - - const credentials = parameters.credentials; - - const guardrailVersion = parameters.guardrailVersion; - const guardrailId = parameters.guardrailId; - - const validate = validateCreds(credentials); - - if (!validate || !guardrailVersion || !guardrailId) { - return { - verdict: true, - error: 'Missing required credentials', - data: null, - }; - } - - try { - const { content, textArray } = getCurrentContentPart(context, eventType); - - if (!content) { - return { - error: { message: 'request or response json is empty' }, - verdict: true, - data: null, - }; - } - - const transformedTextPromise = textArray.map((text) => - redactPii(text, eventType, { - ...(credentials as any), - guardrailId, - guardrailVersion, - }) - ); - - const transformedText = await Promise.all(transformedTextPromise); - - setCurrentContentPart( - context, - eventType, - transformedData, - null, - transformedText - ); - - return { - error: null, - verdict: true, - data: - transformedText.filter((text) => text !== null).length > 0 - ? { flagged: true } - : null, - transformedData, - }; - } catch (e: any) { - delete e.stack; - return { - error: e as Error, - verdict: true, - data: null, - transformedData, - }; - } -}; diff --git a/plugins/bedrock/type.ts b/plugins/bedrock/type.ts new file mode 100644 index 000000000..f208b5ebf --- /dev/null +++ b/plugins/bedrock/type.ts @@ -0,0 +1,99 @@ +export type BedrockBody = { + source: 'INPUT' | 'OUTPUT'; + content: { text: { text: string } }[]; +}; + +type PIIType = + | 'ADDRESS' + | 'AGE' + | 'AWS_ACCESS_KEY' + | 'AWS_SECRET_KEY' + | 'CA_HEALTH_NUMBER' + | 'CA_SOCIAL_INSURANCE_NUMBER' + | 'CREDIT_DEBIT_CARD_CVV' + | 'CREDIT_DEBIT_CARD_EXPIRY' + | 'CREDIT_DEBIT_CARD_NUMBER' + | 'DRIVER_ID' + | 'EMAIL' + | 'INTERNATIONAL_BANK_ACCOUNT_NUMBER' + | 'IP_ADDRESS' + | 'LICENSE_PLATE' + | 'MAC_ADDRESS' + | 'NAME' + | 'PASSWORD' + | 'PHONE' + | 'PIN' + | 'SWIFT_CODE' + | 'UK_NATIONAL_HEALTH_SERVICE_NUMBER' + | 'UK_NATIONAL_INSURANCE_NUMBER' + | 'UK_UNIQUE_TAXPAYER_REFERENCE_NUMBER' + | 'URL' + | 'USERNAME' + | 'US_BANK_ACCOUNT_NUMBER' + | 'US_BANK_ROUTING_NUMBER' + | 'US_INDIVIDUAL_TAX_IDENTIFICATION_NUMBER' + | 'US_PASSPORT_NUMBER' + | 'US_SOCIAL_SECURITY_NUMBER' + | 'VEHICLE_IDENTIFICATION_NUMBER'; + +interface BedrockAction { + action: 'BLOCKED' | T; +} + +interface ContentPolicy extends BedrockAction { + confidence: 'LOW' | 'NONE' | 'MEDIUM' | 'HIGH'; + type: + | 'INSULTS' + | 'HATE' + | 'SEXUAL' + | 'VIOLENCE' + | 'MISCONDUCT' + | 'PROMPT_ATTACK'; + filterStrength: 'LOW' | 'MEDIUM' | 'HIGH'; +} + +interface WordPolicy extends BedrockAction { + match: string; +} + +export interface PIIFilter extends BedrockAction { + match: string; + type: PIIType; +} + +export interface BedrockResponse { + action: 'NONE' | 'GUARDRAIL_INTERVENED'; + assessments: { + wordPolicy: { + customWords: WordPolicy[]; + managedWordLists: (WordPolicy & { type: 'PROFANITY' })[]; + }; + contentPolicy: { filters: ContentPolicy[] }; + sensitiveInformationPolicy: { + piiEntities: PIIFilter<'ANONYMIZED' | 'BLOCKED'>[]; + regexes: (Omit & { + name: string; + regex: string; + })[]; + }; + }[]; + output: { + text: string; + }[]; + usage: { + contentPolicyUnits: number; + sensitiveInformationPolicyUnits: number; + wordPolicyUnits: number; + }; +} + +export interface BedrockParameters { + credentials: { + accessKeyId: string; + accessKeySecret: string; + awsSessionToken?: string; + region: string; + }; + guardrailVersion: string; + guardrailId: string; +} diff --git a/plugins/bedrock/util.ts b/plugins/bedrock/util.ts index 1bc818c48..fcd822755 100644 --- a/plugins/bedrock/util.ts +++ b/plugins/bedrock/util.ts @@ -1,5 +1,13 @@ import { Sha256 } from '@aws-crypto/sha256-js'; import { SignatureV4 } from '@smithy/signature-v4'; +import { + BedrockBody, + BedrockParameters, + BedrockResponse, + PIIFilter, +} from './type'; +import { post } from '../utils'; +import { HookEventType } from '../types'; export const generateAWSHeaders = async ( body: Record, @@ -7,14 +15,14 @@ export const generateAWSHeaders = async ( url: string, method: string, awsService: string, - awsRegion: string, + region: string, awsAccessKeyID: string, awsSecretAccessKey: string, awsSessionToken: string | undefined ): Promise> => { const signer = new SignatureV4({ service: awsService, - region: awsRegion || 'us-east-1', + region: region || 'us-east-1', credentials: { accessKeyId: awsAccessKeyID, secretAccessKey: awsSecretAccessKey, @@ -44,3 +52,108 @@ export const generateAWSHeaders = async ( const signed = await signer.sign(request); return signed.headers; }; + +export const bedrockPost = async ( + credentials: Record, + body: BedrockBody +) => { + const url = `https://bedrock-runtime.${credentials?.region}.amazonaws.com/guardrail/${credentials?.guardrailId}/version/${credentials?.guardrailVersion}/apply`; + + const headers = await generateAWSHeaders( + body, + { + 'Content-Type': 'application/json', + }, + url, + 'POST', + 'bedrock', + credentials?.region ?? 'us-east-1', + credentials?.accessKeyId!, + credentials?.accessKeySecret!, + credentials?.awsSessionToken || '' + ); + + return await post(url, body, { + headers, + method: 'POST', + }); +}; + +const replaceMatches = ( + filter: PIIFilter & { name?: string }, + text: string, + isRegex?: boolean +) => { + // `filter.type` will be for PII, else use name to `mask` text. + return text.replaceAll( + filter.match, + `{${isRegex ? filter.name : filter.type}}` + ); +}; + +/** + * @description Redacts PII information for the text passed by invoking the bedrock endpoint. + * @param text + * @param eventType + * @param credentials + * @returns + */ +export const redactPii = async ( + text: string, + eventType: HookEventType, + credentials: Record +) => { + const body = {} as BedrockBody; + + if (eventType === 'beforeRequestHook') { + body.source = 'INPUT'; + } else { + body.source = 'OUTPUT'; + } + + body.content = [ + { + text: { + text, + }, + }, + ]; + + try { + const response = await bedrockPost({ ...(credentials as any) }, body); + // `ANONYMIZED` means text is already masked by api invokation + const isMasked = + response.assessments[0].sensitiveInformationPolicy.piiEntities?.find( + (entity) => entity.action === 'ANONYMIZED' + ); + + let maskedText = text; + if (isMasked) { + // Use the invoked text directly. + const data = response.output?.[0]; + + maskedText = data?.text; + } else { + // Replace the all entires of each filter sent from api. + response.assessments[0].sensitiveInformationPolicy.piiEntities.forEach( + (filter) => { + maskedText = replaceMatches(filter, maskedText, false); + } + ); + } + + // Replace the all entires of each filter sent from api for regex + const isRegexMatch = + response.assessments[0].sensitiveInformationPolicy?.regexes?.length > 0; + if (isRegexMatch) { + response.assessments[0].sensitiveInformationPolicy.regexes.forEach( + (regex) => { + maskedText = replaceMatches(regex as any, maskedText, true); + } + ); + } + return maskedText; + } catch (e) { + return null; + } +}; diff --git a/plugins/index.ts b/plugins/index.ts index 6a8256960..c29ea98bd 100644 --- a/plugins/index.ts +++ b/plugins/index.ts @@ -34,15 +34,11 @@ import { handler as patronustoxicity } from './patronus/toxicity'; import { handler as patronuscustom } from './patronus/custom'; import { mistralGuardrailHandler } from './mistral'; import { handler as pangeatextGuard } from './pangea/textGuard'; -import { handler as portkeyredactPii } from './portkey/redactPii'; -import { handler as promptfooRedactPii } from './promptfoo/redactPii'; +import { handler as promptfooPii } from './promptfoo/pii'; import { handler as promptfooHarm } from './promptfoo/harm'; import { handler as promptfooGuard } from './promptfoo/guard'; -import { handler as pangearedactPii } from './pangea/redactPii'; -import { handler as patronusredactPii } from './patronus/redactPii'; -import { handler as patronusredactPhi } from './patronus/redactPhi'; +import { handler as pangeapii } from './pangea/pii'; import { pluginHandler as bedrockHandler } from './bedrock/index'; -import { bedrockPIIHandler } from './bedrock/redactPii'; export const plugins = { default: { @@ -67,7 +63,6 @@ export const plugins = { language: portkeylanguage, pii: portkeypii, gibberish: portkeygibberish, - redactPii: portkeyredactPii, }, aporia: { validateProject: aporiavalidateProject, @@ -91,25 +86,20 @@ export const plugins = { retrievalAnswerRelevance: patronusretrievalAnswerRelevance, toxicity: patronustoxicity, custom: patronuscustom, - redactPii: patronusredactPii, - redactPhi: patronusredactPhi, }, mistral: { moderateContent: mistralGuardrailHandler, }, pangea: { textGuard: pangeatextGuard, - redactPii: pangearedactPii, + pii: pangeapii, }, promptfoo: { - redactPii: promptfooRedactPii, + pii: promptfooPii, harm: promptfooHarm, guard: promptfooGuard, }, bedrock: { - pii: bedrockHandler.bind({ fn: 'pii' }), - contentFilter: bedrockHandler.bind({ fn: 'contentFilter' }), - wordFilter: bedrockHandler.bind({ fn: 'wordFilter' }), - redactPii: bedrockPIIHandler, + guard: bedrockHandler, }, }; diff --git a/plugins/pangea/pangea.test.ts b/plugins/pangea/pangea.test.ts index af3f1b271..5d52c3ec3 100644 --- a/plugins/pangea/pangea.test.ts +++ b/plugins/pangea/pangea.test.ts @@ -1,5 +1,7 @@ import { handler as textGuardContentHandler } from './textGuard'; +import { handler as piiHandler } from './pii'; import testCreds from './.creds.json'; +import { HookEventType, PluginContext } from '../types'; const options = { env: {}, @@ -183,3 +185,198 @@ describe('textGuardContentHandler', () => { expect(result.data).toBeNull(); }); }); + +describe('pii handler', () => { + it('should only detect PII', async () => { + const eventType = 'beforeRequestHook' as HookEventType; + const context = { + request: { + text: 'My email is abc@xyz.com and some random text', + json: { + messages: [ + { + role: 'user', + content: 'My email is abc@xyz.com and some random text', + }, + ], + }, + }, + requestType: 'chatComplete', + }; + const parameters = { + credentials: testCreds, + }; + + const result = await piiHandler( + context as PluginContext, + parameters, + eventType, + { + env: {}, + } + ); + expect(result).toBeDefined(); + expect(result.verdict).toBe(false); + expect(result.error).toBeNull(); + expect(result.data).toBeDefined(); + expect(result.transformedData?.request?.json).toBeNull(); + }); + + it('should detect and redact PII in request text', async () => { + const context = { + request: { + text: 'My email is abc@xyz.com and some random text', + json: { + messages: [ + { + role: 'user', + content: 'My email is abc@xyz.com and some random text', + }, + ], + }, + }, + requestType: 'chatComplete', + }; + const parameters = { + redact: true, + credentials: testCreds, + }; + + const result = await piiHandler( + context as PluginContext, + parameters, + 'beforeRequestHook', + { + env: {}, + } + ); + expect(result.error).toBeNull(); + expect(result.verdict).toBe(false); + expect(result.data).toBeDefined(); + expect(result.transformedData?.request?.json?.messages?.[0]?.content).toBe( + 'My email is and some random text' + ); + }); + + it('should detect and redact PII in request text with multiple content parts', async () => { + const context = { + request: { + text: 'My email is abc@xyz.com My email is abc@xyz.com and some random text', + json: { + messages: [ + { + role: 'user', + content: [ + { + type: 'text', + text: 'My email is abc@xyz.com', + }, + { + type: 'text', + text: 'My email is abc@xyz.com and some random text', + }, + ], + }, + ], + }, + }, + requestType: 'chatComplete', + }; + const parameters = { + redact: true, + credentials: testCreds, + }; + + const result = await piiHandler( + context as PluginContext, + parameters, + 'beforeRequestHook', + { + env: {}, + } + ); + + expect(result.error).toBeNull(); + expect(result.verdict).toBe(false); + expect(result.data).toBeDefined; + expect( + result.transformedData?.request?.json?.messages?.[0]?.content?.[0]?.text + ).toBe('My email is '); + expect( + result.transformedData?.request?.json?.messages?.[0]?.content?.[1]?.text + ).toBe('My email is and some random text'); + }); + + it('should detect and redact PII in response text', async () => { + const context = { + response: { + text: 'My email is abc@xyz.com and some random text', + json: { + choices: [ + { + message: { + role: 'assistant', + content: 'My email is abc@xyz.com and some random text', + }, + }, + ], + }, + }, + requestType: 'chatComplete', + }; + const parameters = { + redact: true, + credentials: testCreds, + }; + + const result = await piiHandler( + context as PluginContext, + parameters, + 'afterRequestHook', + { + env: {}, + } + ); + + expect(result.error).toBeNull(); + expect(result.verdict).toBe(false); + expect(result.data).toBeDefined(); + expect( + result.transformedData?.response?.json?.choices?.[0]?.message?.content + ).toBe('My email is and some random text'); + }); + + it('should pass text without PII', async () => { + const eventType = 'beforeRequestHook' as HookEventType; + const context = { + request: { + text: 'Hello world', + json: { + messages: [ + { + role: 'assistant', + content: 'Hello world', + }, + ], + }, + }, + requestType: 'chatComplete', + }; + const parameters = { + credentials: testCreds, + }; + + const result = await piiHandler( + context as PluginContext, + parameters, + eventType, + { + env: {}, + } + ); + expect(result).toBeDefined(); + expect(result.verdict).toBe(true); + expect(result.error).toBeNull(); + expect(result.data).toBeDefined(); + }); +}); diff --git a/plugins/pangea/redactPii.ts b/plugins/pangea/pii.ts similarity index 89% rename from plugins/pangea/redactPii.ts rename to plugins/pangea/pii.ts index 9857f8b6b..e20432566 100644 --- a/plugins/pangea/redactPii.ts +++ b/plugins/pangea/pii.ts @@ -20,6 +20,7 @@ export const handler: PluginHandler = async ( json: null, }, }; + const redact = parameters.redact || false; try { if (!parameters.credentials?.domain) { @@ -66,8 +67,12 @@ export const handler: PluginHandler = async ( }; const response = await post(url, request, requestOptions); + const piiDetected = + response.result?.count > 0 && response.result.redacted_data + ? true + : false; - if (response.result?.count > 0 && response.result.redacted_data) { + if (piiDetected && redact) { setCurrentContentPart( context, eventType, @@ -78,7 +83,7 @@ export const handler: PluginHandler = async ( return { error: null, - verdict: true, + verdict: !piiDetected, data: { summary: response.summary, }, diff --git a/plugins/patronus/patronus.test.ts b/plugins/patronus/patronus.test.ts index a50dc5ead..46054d381 100644 --- a/plugins/patronus/patronus.test.ts +++ b/plugins/patronus/patronus.test.ts @@ -4,112 +4,384 @@ import { handler as piiHandler } from './pii'; import { handler as toxicityHandler } from './toxicity'; import { handler as retrievalAnswerRelevanceHandler } from './retrievalAnswerRelevance'; import { handler as customHandler } from './custom'; +import { HookEventType, PluginContext } from '../types'; describe('phi handler', () => { - it('should fail if beforeRequestHook is used', async () => { + it('should pass when text is clean', async () => { const eventType = 'beforeRequestHook'; const context = { - request: { text: 'this is a test string for moderations' }, + request: { + text: 'this is a test string for moderations', + json: { + messages: [ + { + role: 'user', + content: 'this is a test string for moderations', + }, + ], + }, + }, + requestType: 'chatComplete', }; + const parameters = { credentials: testCreds }; - const result = await phiHandler(context, parameters, eventType); - // console.log(result); + const result = await phiHandler( + context as PluginContext, + parameters, + eventType + ); expect(result).toBeDefined(); - expect(result.error).toBeDefined(); - expect(result.data).toBeNull(); + expect(result.verdict).toBe(true); + expect(result.error).toBeNull(); + expect(result.data).toBeDefined(); }); - it('should pass when text is clean', async () => { - const eventType = 'afterRequestHook'; + it('should fail when text contains PHI', async () => { + const eventType = 'beforeRequestHook'; const context = { - request: { text: 'this is a test string for moderations' }, - response: { text: 'this is a test string for moderations' }, + request: { + text: 'John Doe has a history of heart disease', + json: { + messages: [ + { + role: 'user', + content: 'John Doe has a history of heart disease', + }, + ], + }, + }, + requestType: 'chatComplete', }; const parameters = { credentials: testCreds }; - const result = await phiHandler(context, parameters, eventType); - // console.log(result); + const result = await phiHandler( + context as PluginContext, + parameters, + eventType + ); expect(result).toBeDefined(); - expect(result.verdict).toBe(true); + expect(result.verdict).toBe(false); expect(result.error).toBeNull(); expect(result.data).toBeDefined(); + expect(result.transformedData?.response?.json).toBeNull(); + expect(result.transformedData?.request?.json).toBeNull(); }); - it('should fail when text contains PHI', async () => { - const eventType = 'afterRequestHook'; + it('should detect and redact PII in request text', async () => { const context = { request: { - text: `Your hospital's patient - John Doe. What is he in for?`, + text: 'John Doe has a history of heart disease', + json: { + messages: [ + { + role: 'user', + content: 'John Doe has a history of heart disease', + }, + ], + }, }, - response: { - text: 'John Doe is in the hospital for a bad case of carpal tunnel.', + requestType: 'chatComplete', + }; + const parameters = { + redact: true, + credentials: testCreds, + }; + + const result = await phiHandler( + context as PluginContext, + parameters, + 'beforeRequestHook', + { + env: {}, + } + ); + + expect(result.error).toBeNull(); + expect(result.verdict).toBe(false); + expect(result.data).toBeDefined(); + expect(result.transformedData?.request?.json?.messages?.[0]?.content).toBe( + 'J******e has a history of heart disease' + ); + }); + + it('should detect and redact PII in request text with multiple content parts', async () => { + const context = { + request: { + text: 'John Doe has a history of heart disease John Doe has a history of heart disease and some random text', + json: { + messages: [ + { + role: 'user', + content: [ + { + type: 'text', + text: 'John Doe has a history of heart disease', + }, + { + type: 'text', + text: 'John Doe has a history of heart disease and some random text', + }, + ], + }, + ], + }, }, + requestType: 'chatComplete', + }; + const parameters = { + redact: true, + credentials: testCreds, }; - const parameters = { credentials: testCreds }; + const result = await phiHandler( + context as PluginContext, + parameters, + 'beforeRequestHook', + { + env: {}, + } + ); - const result = await phiHandler(context, parameters, eventType); - // console.log(result); - expect(result).toBeDefined(); + expect(result.error).toBeNull(); expect(result.verdict).toBe(false); + expect(result.data).toBeDefined; + expect( + result.transformedData?.request?.json?.messages?.[0]?.content?.[0]?.text + ).toBe('J******e has a history of heart disease'); + expect( + result.transformedData?.request?.json?.messages?.[0]?.content?.[1]?.text + ).toBe('J******e has a history of heart disease and some random text'); + }); + + it('should detect and redact PHI in response text', async () => { + const context = { + response: { + text: 'John Doe has a history of heart disease and some random text', + json: { + choices: [ + { + message: { + role: 'assistant', + content: + 'John Doe has a history of heart disease and some random text', + }, + }, + ], + }, + }, + requestType: 'chatComplete', + }; + const parameters = { + redact: true, + credentials: testCreds, + }; + + const result = await phiHandler( + context as PluginContext, + parameters, + 'afterRequestHook', + { + env: {}, + } + ); + expect(result.error).toBeNull(); + expect(result.verdict).toBe(false); expect(result.data).toBeDefined(); + expect( + result.transformedData?.response?.json?.choices?.[0]?.message?.content + ).toBe('J******e has a history of heart disease and some random text'); }); }); describe('pii handler', () => { - it('should fail if beforeRequestHook is used', async () => { + it('should pass when text is clean', async () => { const eventType = 'beforeRequestHook'; const context = { - request: { text: 'this is a test string for moderations' }, + request: { + text: 'this is a test string for moderations', + json: { + messages: [ + { + role: 'user', + content: 'this is a test string for moderations', + }, + ], + }, + }, + requestType: 'chatComplete', }; + const parameters = { credentials: testCreds }; - const result = await piiHandler(context, parameters, eventType); - // console.log(result); + const result = await piiHandler( + context as PluginContext, + parameters, + eventType + ); expect(result).toBeDefined(); - expect(result.error).toBeDefined(); - expect(result.data).toBeNull(); + expect(result.verdict).toBe(true); + expect(result.error).toBeNull(); + expect(result.data).toBeDefined(); }); - it('should pass when text is clean', async () => { - const eventType = 'afterRequestHook'; + it('should fail when text contains PII', async () => { + const eventType = 'beforeRequestHook'; const context = { - request: { text: 'this is a test string for moderations' }, - response: { text: 'this is a test string for moderations' }, + request: { + text: 'My email is abc@xyz.com and some random text', + json: { + messages: [ + { + role: 'user', + content: 'My email is abc@xyz.com and some random text', + }, + ], + }, + }, + requestType: 'chatComplete', }; const parameters = { credentials: testCreds }; - const result = await piiHandler(context, parameters, eventType); - // console.log(result); + const result = await piiHandler( + context as PluginContext, + parameters, + eventType + ); expect(result).toBeDefined(); - expect(result.verdict).toBe(true); + expect(result.verdict).toBe(false); expect(result.error).toBeNull(); expect(result.data).toBeDefined(); + expect(result.transformedData?.response?.json).toBeNull(); + expect(result.transformedData?.request?.json).toBeNull(); }); - it('should fail when text contains PII', async () => { - const eventType = 'afterRequestHook'; + it('should detect and redact PII in request text', async () => { const context = { request: { - text: `Your hospital's patient - John Doe. What is he in for?`, + text: 'My email is abc@xyz.com and some random text', + json: { + messages: [ + { + role: 'user', + content: 'My email is abc@xyz.com and some random text', + }, + ], + }, }, - response: { - text: `Sure! Happy to provide the SSN of John Doe - it's 123-45-6789.`, + requestType: 'chatComplete', + }; + const parameters = { + redact: true, + credentials: testCreds, + }; + + const result = await piiHandler( + context as PluginContext, + parameters, + 'beforeRequestHook', + { + env: {}, + } + ); + + expect(result.error).toBeNull(); + expect(result.verdict).toBe(false); + expect(result.data).toBeDefined(); + expect(result.transformedData?.request?.json?.messages?.[0]?.content).toBe( + 'My email is a*********m and some random text' + ); + }); + + it('should detect and redact PII in request text with multiple content parts', async () => { + const context = { + request: { + text: 'My email is abc@xyz.com My email is abc@xyz.com and some random text', + json: { + messages: [ + { + role: 'user', + content: [ + { + type: 'text', + text: 'My email is abc@xyz.com', + }, + { + type: 'text', + text: 'My email is abc@xyz.com and some random text', + }, + ], + }, + ], + }, }, + requestType: 'chatComplete', + }; + const parameters = { + redact: true, + credentials: testCreds, }; - const parameters = { credentials: testCreds }; + const result = await piiHandler( + context as PluginContext, + parameters, + 'beforeRequestHook', + { + env: {}, + } + ); - const result = await piiHandler(context, parameters, eventType); - // console.log(result); - expect(result).toBeDefined(); + expect(result.error).toBeNull(); expect(result.verdict).toBe(false); + expect(result.data).toBeDefined; + expect( + result.transformedData?.request?.json?.messages?.[0]?.content?.[0]?.text + ).toBe('My email is a*********m'); + expect( + result.transformedData?.request?.json?.messages?.[0]?.content?.[1]?.text + ).toBe('My email is a*********m and some random text'); + }); + + it('should detect and redact PII in response text', async () => { + const context = { + response: { + text: 'My email is abc@xyz.com and some random text', + json: { + choices: [ + { + message: { + role: 'assistant', + content: 'My email is abc@xyz.com and some random text', + }, + }, + ], + }, + }, + requestType: 'chatComplete', + }; + const parameters = { + redact: true, + credentials: testCreds, + }; + + const result = await piiHandler( + context as PluginContext, + parameters, + 'afterRequestHook', + { + env: {}, + } + ); + expect(result.error).toBeNull(); + expect(result.verdict).toBe(false); expect(result.data).toBeDefined(); + expect( + result.transformedData?.response?.json?.choices?.[0]?.message?.content + ).toBe('My email is a*********m and some random text'); }); }); @@ -122,7 +394,6 @@ describe('toxicity handler', () => { const parameters = { credentials: testCreds }; const result = await toxicityHandler(context, parameters, eventType); - // console.log(result); expect(result).toBeDefined(); expect(result.error).toBeDefined(); expect(result.data).toBeNull(); @@ -138,7 +409,6 @@ describe('toxicity handler', () => { const parameters = { credentials: testCreds }; const result = await toxicityHandler(context, parameters, eventType); - console.log(result); expect(result).toBeDefined(); expect(result.verdict).toBe(true); expect(result.error).toBeNull(); @@ -155,7 +425,6 @@ describe('toxicity handler', () => { const parameters = { credentials: testCreds }; const result = await toxicityHandler(context, parameters, eventType); - console.log(result); expect(result).toBeDefined(); expect(result.verdict).toBe(false); expect(result.error).toBeNull(); @@ -176,7 +445,6 @@ describe('retrieval answer relevance handler', () => { parameters, eventType ); - // console.log(result); expect(result).toBeDefined(); expect(result.error).toBeDefined(); expect(result.data).toBeNull(); @@ -198,7 +466,6 @@ describe('retrieval answer relevance handler', () => { parameters, eventType ); - console.log(result); expect(result).toBeDefined(); expect(result.verdict).toBe(true); expect(result.error).toBeNull(); @@ -219,7 +486,6 @@ describe('retrieval answer relevance handler', () => { parameters, eventType ); - console.log(result); expect(result).toBeDefined(); expect(result.verdict).toBe(false); expect(result.error).toBeNull(); @@ -243,7 +509,6 @@ describe('custom handler (is-concise)', () => { }; const result = await customHandler(context, parameters, eventType); - console.log(result); expect(result).toBeDefined(); expect(result.verdict).toBe(true); expect(result.error).toBeNull(); @@ -265,7 +530,6 @@ describe('custom handler (is-concise)', () => { }; const result = await customHandler(context, parameters, eventType); - console.log(result); expect(result).toBeDefined(); expect(result.verdict).toBe(false); expect(result.error).toBeNull(); diff --git a/plugins/patronus/phi.ts b/plugins/patronus/phi.ts index 2c62e7882..aa6033896 100644 --- a/plugins/patronus/phi.ts +++ b/plugins/patronus/phi.ts @@ -4,7 +4,38 @@ import { PluginHandler, PluginParameters, } from '../types'; -import { postPatronus } from './globals'; +import { getCurrentContentPart, setCurrentContentPart } from '../utils'; +import { findAllLongestPositions, postPatronus } from './globals'; +import { maskEntities } from './pii'; + +const redactPhi = async (text: string, credentials: any) => { + if (!text) return { maskedText: null, data: null }; + const evaluator = 'phi'; + + const evaluationBody: any = { + output: text, + }; + + const result: any = await postPatronus( + evaluator, + credentials, + evaluationBody + ); + const evalResult = result.results[0]; + + const positionsData = evalResult.evaluation_result.additional_info; + if ( + Array.isArray(positionsData?.positions) && + positionsData.positions.length > 0 + ) { + const longestPosition = findAllLongestPositions(positionsData); + if (longestPosition?.positions && longestPosition.positions.length > 0) { + const maskedText = maskEntities(text, longestPosition.positions); + return { maskedText, data: result.results[0] }; + } + } + return { maskedText: null, data: result.results[0] }; +}; export const handler: PluginHandler = async ( context: PluginContext, @@ -14,42 +45,59 @@ export const handler: PluginHandler = async ( let error = null; let verdict = false; let data = null; + const transformedData: Record = { + request: { + json: null, + }, + response: { + json: null, + }, + }; - const evaluator = 'phi'; + try { + const { content, textArray } = getCurrentContentPart(context, eventType); - if (eventType !== 'afterRequestHook') { - return { - error: { - message: 'Patronus guardrails only support after_request_hooks.', - }, - verdict: true, - data, - }; - } + if (!content) { + return { + error: { message: 'request or response json is empty' }, + verdict: true, + data: null, + transformedData, + }; + } - const evaluationBody: any = { - input: context.request.text, - output: context.response.text, - }; + const results = await Promise.all( + textArray.map((text) => redactPhi(text, parameters.credentials)) + ); - try { - const result: any = await postPatronus( - evaluator, - parameters.credentials, - evaluationBody + const hasPHI = results.some( + (result) => result?.data?.evaluation_result?.pass === false ); + const phiData = + results.find((result) => result?.maskedText)?.data ?? results[0]?.data; + error = + results.find((result) => result?.data?.error_message)?.data + ?.error_message || null; + error = phiData.error_message; - const evalResult = result.results[0]; - error = evalResult.error_message; + if (parameters?.redact && hasPHI) { + const maskedTexts = results.map((result) => result?.maskedText ?? null); + setCurrentContentPart( + context, + eventType, + transformedData, + null, + maskedTexts + ); + } // verdict can be true/false - verdict = evalResult.evaluation_result.pass; - - data = evalResult.evaluation_result.additional_info; + verdict = !hasPHI; + data = phiData.evaluation_result.additional_info; } catch (e: any) { delete e.stack; error = e; } - return { error, verdict, data }; + return { error, verdict, data, transformedData }; }; diff --git a/plugins/patronus/pii.ts b/plugins/patronus/pii.ts index 4495788b5..725596c34 100644 --- a/plugins/patronus/pii.ts +++ b/plugins/patronus/pii.ts @@ -4,7 +4,74 @@ import { PluginHandler, PluginParameters, } from '../types'; -import { postPatronus } from './globals'; +import { getCurrentContentPart, setCurrentContentPart } from '../utils'; +import { findAllLongestPositions, postPatronus } from './globals'; + +export function maskEntities( + text: string, + positions: [number, number][] +): string { + if (!text || !positions.length) return text; + + let result = ''; + let lastIndex = 0; + + // Sort positions by start index to handle them in order + positions.sort((a, b) => a[0] - b[0]); + + for (const [start, end] of positions) { + // Add text before the masked section + result += text.slice(lastIndex, start); + + // Get the section to be masked + const section = text.slice(start, end); + if (section.length <= 4) { + // If section is 4 chars or less, mask everything + result += '*'.repeat(section.length); + } else { + // Keep first 2 and last 2 chars, mask the rest + result += + section.slice(0, 1) + + '*'.repeat(section.length - 2) + + section.slice(-1); + } + lastIndex = end; + } + + // Add any remaining text after the last masked section + result += text.slice(lastIndex); + + return result; +} + +const redactPii = async (text: string, credentials: any) => { + if (!text) return { maskedText: null, data: null }; + const evaluator = 'pii'; + const evaluationBody: any = { + output: text, + }; + + const result: any = await postPatronus( + evaluator, + credentials, + evaluationBody + ); + + const evalResult = result.results[0]; + + const positionsData = evalResult.evaluation_result.additional_info; + if ( + Array.isArray(positionsData?.positions) && + positionsData.positions.length > 0 + ) { + const longestPosition = findAllLongestPositions(positionsData); + if (longestPosition?.positions && longestPosition.positions.length > 0) { + const maskedText = maskEntities(text, longestPosition.positions); + return { maskedText, data: result.results[0] }; + } + } + return { maskedText: null, data: result.results[0] }; +}; export const handler: PluginHandler = async ( context: PluginContext, @@ -14,42 +81,59 @@ export const handler: PluginHandler = async ( let error = null; let verdict = false; let data = null; + const transformedData: Record = { + request: { + json: null, + }, + response: { + json: null, + }, + }; - const evaluator = 'pii'; + try { + const { content, textArray } = getCurrentContentPart(context, eventType); - if (eventType !== 'afterRequestHook') { - return { - error: { - message: 'Patronus guardrails only support after_request_hooks.', - }, - verdict: true, - data, - }; - } + if (!content) { + return { + error: { message: 'request or response json is empty' }, + verdict: true, + data: null, + transformedData, + }; + } - const evaluationBody: any = { - input: context.request.text, - output: context.response.text, - }; + const results = await Promise.all( + textArray.map((text) => redactPii(text, parameters.credentials)) + ); - try { - const result: any = await postPatronus( - evaluator, - parameters.credentials, - evaluationBody + const hasPII = results.some( + (result) => result?.data?.evaluation_result.pass === false ); - const evalResult = result.results[0]; - error = evalResult.error_message; + const piiData = + results.find((result) => result?.maskedText)?.data ?? results[0]?.data; + error = + results.find((result) => result?.data?.error_message)?.data + ?.error_message || null; - // verdict can be true/false - verdict = evalResult.evaluation_result.pass; + if (parameters?.redact && hasPII) { + const maskedTexts = results.map((result) => result?.maskedText ?? null); + setCurrentContentPart( + context, + eventType, + transformedData, + null, + maskedTexts + ); + } - data = evalResult.evaluation_result.additional_info; + // verdict can be true/false + verdict = !hasPII; + data = piiData?.evaluation_result?.additional_info; } catch (e: any) { delete e.stack; error = e; } - return { error, verdict, data }; + return { error, verdict, data, transformedData }; }; diff --git a/plugins/patronus/redactPhi.ts b/plugins/patronus/redactPhi.ts deleted file mode 100644 index 92b1ab5ca..000000000 --- a/plugins/patronus/redactPhi.ts +++ /dev/null @@ -1,94 +0,0 @@ -import { - HookEventType, - PluginContext, - PluginHandler, - PluginParameters, -} from '../types'; -import { getCurrentContentPart, setCurrentContentPart } from '../utils'; -import { findAllLongestPositions, postPatronus } from './globals'; -import { maskEntities } from './redactPii'; - -const redactPhi = async (text: string, credentials: any) => { - const evaluator = 'phi'; - - const evaluationBody: any = { - output: text, - }; - - const result: any = await postPatronus( - evaluator, - credentials, - evaluationBody - ); - const evalResult = result.results[0]; - - const positionsData = evalResult.evaluation_result.additional_info; - if ( - Array.isArray(positionsData?.positions) && - positionsData.positions.length > 0 - ) { - const longestPosition = findAllLongestPositions(positionsData); - if (longestPosition?.positions && longestPosition.positions.length > 0) { - const maskedText = maskEntities(text, longestPosition.positions); - return maskedText; - } - } - return text; -}; - -export const handler: PluginHandler = async ( - context: PluginContext, - parameters: PluginParameters, - eventType: HookEventType -) => { - const transformedData: Record = { - request: { - json: null, - }, - response: { - json: null, - }, - }; - - try { - const { content, textArray } = getCurrentContentPart(context, eventType); - - if (!content) { - return { - error: { message: 'request or response json is empty' }, - verdict: true, - data: null, - transformedData, - }; - } - - const transformedTextPromise = textArray.map((text) => - redactPhi(text, parameters.credentials) - ); - - const transformedText = await Promise.all(transformedTextPromise); - - setCurrentContentPart( - context, - eventType, - transformedData, - null, - transformedText - ); - - return { - error: null, - verdict: true, - data: null, - transformedData, - }; - } catch (e: any) { - delete e.stack; - return { - error: e as Error, - verdict: true, - data: null, - transformedData, - }; - } -}; diff --git a/plugins/patronus/redactPii.ts b/plugins/patronus/redactPii.ts deleted file mode 100644 index d4590182a..000000000 --- a/plugins/patronus/redactPii.ts +++ /dev/null @@ -1,133 +0,0 @@ -import { - HookEventType, - PluginContext, - PluginHandler, - PluginParameters, -} from '../types'; -import { getCurrentContentPart, setCurrentContentPart } from '../utils'; -import { findAllLongestPositions, postPatronus } from './globals'; - -export function maskEntities( - text: string, - positions: [number, number][] -): string { - if (!text || !positions.length) return text; - - let result = ''; - let lastIndex = 0; - - // Sort positions by start index to handle them in order - positions.sort((a, b) => a[0] - b[0]); - - for (const [start, end] of positions) { - // Add text before the masked section - result += text.slice(lastIndex, start); - - // Get the section to be masked - const section = text.slice(start, end); - if (section.length <= 4) { - // If section is 4 chars or less, mask everything - result += '*'.repeat(section.length); - } else { - // Keep first 2 and last 2 chars, mask the rest - result += - section.slice(0, 1) + - '*'.repeat(section.length - 2) + - section.slice(-1); - } - lastIndex = end; - } - - // Add any remaining text after the last masked section - result += text.slice(lastIndex); - - return result; -} - -const redactPii = async (text: string, credentials: any) => { - const evaluator = 'pii'; - const evaluationBody: any = { - output: text, - }; - - const result: any = await postPatronus( - evaluator, - credentials, - evaluationBody - ); - - const evalResult = result.results[0]; - - const positionsData = evalResult.evaluation_result.additional_info; - if ( - Array.isArray(positionsData?.positions) && - positionsData.positions.length > 0 - ) { - const longestPosition = findAllLongestPositions(positionsData); - if (longestPosition?.positions && longestPosition.positions.length > 0) { - const maskedText = maskEntities(text, longestPosition.positions); - return maskedText; - } - } - return null; -}; - -export const handler: PluginHandler = async ( - context: PluginContext, - parameters: PluginParameters, - eventType: HookEventType -) => { - const transformedData: Record = { - request: { - json: null, - }, - response: { - json: null, - }, - }; - - try { - const { content, textArray } = getCurrentContentPart(context, eventType); - - if (!content) { - return { - error: { message: 'request or response json is empty' }, - verdict: true, - data: null, - transformedData, - }; - } - - const transformedTextPromise = textArray.map((text) => - redactPii(text, parameters.credentials) - ); - - const transformedText = await Promise.all(transformedTextPromise); - - setCurrentContentPart( - context, - eventType, - transformedData, - null, - transformedText - ); - - return { - error: null, - verdict: true, - data: - transformedText.filter((text) => text !== null).length > 0 - ? { flagged: true } - : null, - transformedData, - }; - } catch (e: any) { - delete e.stack; - return { - error: e as Error, - verdict: true, - data: null, - transformedData, - }; - } -}; diff --git a/plugins/portkey/globals.ts b/plugins/portkey/globals.ts index 6a7d4af22..773e11afd 100644 --- a/plugins/portkey/globals.ts +++ b/plugins/portkey/globals.ts @@ -1,7 +1,8 @@ import { getRuntimeKey } from 'hono/adapter'; import { post, postWithCloudflareServiceBinding } from '../utils'; +import { PluginParameters } from '../types'; -export const BASE_URL = 'https://api.portkey.ai/v1/execute-guardrails'; +export const BASE_URL = 'https://api.portkeydev.com/v1/execute-guardrails'; export const PORTKEY_ENDPOINTS = { MODERATIONS: '/moderations', @@ -10,6 +11,29 @@ export const PORTKEY_ENDPOINTS = { GIBBERISH: '/gibberish', }; +interface PIIEntity { + text: string; + labels: Record; +} + +export interface PIIResponse { + entities: PIIEntity[]; + processed_text: string; +} + +export interface PIIResult { + detectedPIICategories: string[]; + PIIData: PIIEntity[]; + redactedText: string; +} + +interface PIIParameters extends PluginParameters { + categories: string[]; + credentials: Record; + not?: boolean; + redact?: boolean; +} + export const fetchPortkey = async ( env: Record, endpoint: string, diff --git a/plugins/portkey/pii.ts b/plugins/portkey/pii.ts index d52320687..baad192de 100644 --- a/plugins/portkey/pii.ts +++ b/plugins/portkey/pii.ts @@ -4,52 +4,43 @@ import { PluginHandler, PluginParameters, } from '../types'; -import { getText } from '../utils'; -import { PORTKEY_ENDPOINTS, fetchPortkey } from './globals'; +import { + getCurrentContentPart, + getText, + setCurrentContentPart, +} from '../utils'; +import { + PIIResponse, + PIIResult, + PORTKEY_ENDPOINTS, + fetchPortkey, +} from './globals'; export async function detectPII( textArray: Array | string, - credentials: any, + parameters: any, env: Record -): Promise< - Array<{ - detectedPIICategories: Array; - PIIData: Array; - redactedText: string; - }> -> { - const result = await fetchPortkey(env, PORTKEY_ENDPOINTS.PII, credentials, { - input: textArray, - }); - const mappedResult: Array = []; - - result.forEach((item: any) => { - // Identify all the PII categories in the text - let detectedPIICategories = item.entities - .map((entity: any) => { - return Object.keys(entity.labels); - }) - .flat() - .filter((value: any, index: any, self: string | any[]) => { - return self.indexOf(value) === index; - }); - - // Generate the detailed data to be sent along with detectedPIICategories - let detailedData = item.entities.map((entity: any) => { - return { - text: entity.text, - labels: entity.labels, - }; - }); - - mappedResult.push({ - detectedPIICategories, - PIIData: detailedData, - redactedText: item.processed_text, - }); - }); +): Promise> { + const result: PIIResponse[] = await fetchPortkey( + env, + PORTKEY_ENDPOINTS.PII, + parameters.credentials, + { + input: textArray, + ...(parameters.categories && { categories: parameters.categories }), + } + ); - return mappedResult; + return result.map((item) => ({ + detectedPIICategories: [ + ...new Set(item.entities.flatMap((entity) => Object.keys(entity.labels))), + ], + PIIData: item.entities.map((entity) => ({ + text: entity.text, + labels: entity.labels, + })), + redactedText: item.processed_text, + })); } export const handler: PluginHandler = async ( @@ -61,24 +52,85 @@ export const handler: PluginHandler = async ( let error = null; let verdict = false; let data: any = null; + let transformedData: Record = { + request: { + json: null, + }, + response: { + json: null, + }, + }; try { - const text = getText(context, eventType); - const categoriesToCheck = parameters.categories; + const { content, textArray } = getCurrentContentPart(context, eventType); + const textExcerpt = textArray.filter((text) => text).join('\n'); + + if (!content) { + return { + error: { message: 'request or response json is empty' }, + verdict: true, + data: null, + transformedData, + }; + } + + if (!parameters.categories?.length) { + return { + error: { message: 'No PII categories are configured' }, + verdict: true, + data: null, + transformedData, + }; + } + + if (!parameters.credentials) { + return { + error: { message: 'Credentials not found' }, + verdict: true, + data: null, + transformedData, + }; + } + + let mappedResult = await detectPII( + textArray, + parameters, + options?.env || {} + ); + + const categoriesToCheck = parameters.categories || []; const not = parameters.not || false; - let { detectedPIICategories, PIIData } = - ( - await detectPII(text, parameters.credentials, options?.env || {}) - )?.[0] || {}; + let detectedCategories: any = new Set(); + const mappedTextArray: Array = []; + mappedResult.forEach((result) => { + if (result.detectedPIICategories.length > 0 && result.redactedText) { + result.detectedPIICategories.forEach((category) => + detectedCategories.add(category) + ); + mappedTextArray.push(result.redactedText); + } else { + mappedTextArray.push(null); + } + }); - let filteredCategories = detectedPIICategories.filter((category: string) => + detectedCategories = [...detectedCategories]; + let filteredCategories = detectedCategories.filter((category: string) => categoriesToCheck.includes(category) ); const hasPII = filteredCategories.length > 0; - verdict = not ? !hasPII : !hasPII; + if (parameters.redact && hasPII) { + setCurrentContentPart( + context, + eventType, + transformedData, + null, + mappedTextArray + ); + } + verdict = not ? hasPII : !hasPII; data = { verdict, not, @@ -89,10 +141,13 @@ export const handler: PluginHandler = async ( : not ? 'No PII was found in the text when it should have been.' : `Found restricted PII in the text: ${filteredCategories.join(', ')}`, - detectedPII: PIIData, + // detectedPII: PIIData, restrictedCategories: categoriesToCheck, - detectedCategories: detectedPIICategories, - textExcerpt: text.length > 100 ? text.slice(0, 100) + '...' : text, + detectedCategories: detectedCategories, + textExcerpt: + textExcerpt.length > 100 + ? textExcerpt.slice(0, 100) + '...' + : textExcerpt, }; } catch (e) { error = e as Error; @@ -109,5 +164,5 @@ export const handler: PluginHandler = async ( }; } - return { error, verdict, data }; + return { error, verdict, data, transformedData }; }; diff --git a/plugins/portkey/portkey.test.ts b/plugins/portkey/portkey.test.ts index c74352033..735935b2b 100644 --- a/plugins/portkey/portkey.test.ts +++ b/plugins/portkey/portkey.test.ts @@ -3,6 +3,7 @@ import { handler as piiHandler } from './pii'; import { handler as languageHandler } from './language'; import { handler as gibberishHandler } from './gibberish'; import testCreds from './.creds.json'; +import { PluginContext } from '../types'; describe('moderateContentHandler', () => { const mockOptions = { env: {} }; @@ -97,11 +98,21 @@ describe('moderateContentHandler', () => { describe('piiHandler', () => { const mockOptions = { env: {} }; - it('should detect PII in text', async () => { + it('should only detect PII in text', async () => { const context = { request: { - text: 'My credit card number is 0123 0123 0123 0123, and I live in Wilmington, Delaware', + text: 'My credit card number is 0123 0123 0123 0123, and my email is abc@xyz.com', + json: { + messages: [ + { + role: 'user', + content: + 'My credit card number is 0123 0123 0123 0123, and my email is abc@xyz.com', + }, + ], + }, }, + requestType: 'chatComplete', }; const parameters = { categories: ['CREDIT_CARD', 'LOCATION_ADDRESS'], @@ -110,7 +121,7 @@ describe('piiHandler', () => { }; const result = await piiHandler( - context, + context as PluginContext, parameters, 'beforeRequestHook', mockOptions @@ -124,6 +135,151 @@ describe('piiHandler', () => { explanation: expect.stringContaining('Found restricted PII'), restrictedCategories: ['CREDIT_CARD', 'LOCATION_ADDRESS'], }); + expect(result.transformedData?.request?.json).toBeNull(); + }); + + it('should detect and redact PII in request text', async () => { + const context = { + request: { + text: 'My credit card number is 0123 0123 0123 0123, and my email is abc@xyz.com', + json: { + messages: [ + { + role: 'user', + content: + 'My credit card number is 0123 0123 0123 0123, and my email is abc@xyz.com', + }, + ], + }, + }, + requestType: 'chatComplete', + }; + const parameters = { + categories: ['CREDIT_CARD', 'EMAIL_ADDRESS'], + credentials: testCreds, + redact: true, + not: false, + }; + + const result = await piiHandler( + context as PluginContext, + parameters, + 'beforeRequestHook', + mockOptions + ); + + expect(result.error).toBeNull(); + expect(result.verdict).toBe(false); + expect(result.data).toMatchObject({ + verdict: false, + not: false, + explanation: expect.stringContaining('Found restricted PII'), + restrictedCategories: ['CREDIT_CARD', 'EMAIL_ADDRESS'], + }); + expect(result.transformedData?.request?.json?.messages?.[0]?.content).toBe( + 'My credit card number is [CREDIT_CARD_1], and my email is [EMAIL_ADDRESS_1]' + ); + }); + + it('should detect and redact PII in request text with multiple content parts', async () => { + const context = { + request: { + text: 'My credit card number is 0123 0123 0123 0123, and my email is abc@xyz.com', + json: { + messages: [ + { + role: 'user', + content: [ + { + type: 'text', + text: 'My credit card number is 0123 0123 0123 0123,', + }, + { + type: 'text', + text: 'and my email is abc@xyz.com', + }, + ], + }, + ], + }, + }, + requestType: 'chatComplete', + }; + const parameters = { + categories: ['CREDIT_CARD', 'EMAIL_ADDRESS'], + credentials: testCreds, + redact: true, + not: false, + }; + + const result = await piiHandler( + context as PluginContext, + parameters, + 'beforeRequestHook', + mockOptions + ); + + expect(result.error).toBeNull(); + expect(result.verdict).toBe(false); + expect(result.data).toMatchObject({ + verdict: false, + not: false, + explanation: expect.stringContaining('Found restricted PII'), + restrictedCategories: ['CREDIT_CARD', 'EMAIL_ADDRESS'], + }); + expect( + result.transformedData?.request?.json?.messages?.[0]?.content?.[0]?.text + ).toBe('My credit card number is [CREDIT_CARD_1],'); + expect( + result.transformedData?.request?.json?.messages?.[0]?.content?.[1]?.text + ).toBe('and my email is [EMAIL_ADDRESS_1]'); + }); + + it('should detect and redact PII in response text', async () => { + const context = { + response: { + text: 'My credit card number is 0123 0123 0123 0123, and my email is abc@xyz.com', + json: { + choices: [ + { + message: { + role: 'assistant', + content: + 'My credit card number is 0123 0123 0123 0123, and my email is abc@xyz.com', + }, + }, + ], + }, + }, + requestType: 'chatComplete', + }; + const parameters = { + categories: ['CREDIT_CARD', 'EMAIL_ADDRESS'], + credentials: testCreds, + redact: true, + not: false, + }; + + const result = await piiHandler( + context as PluginContext, + parameters, + 'afterRequestHook', + mockOptions + ); + + expect(result.error).toBeNull(); + expect(result.verdict).toBe(false); + expect(result.data).toMatchObject({ + verdict: false, + not: false, + explanation: expect.stringContaining('Found restricted PII'), + restrictedCategories: ['CREDIT_CARD', 'EMAIL_ADDRESS'], + }); + expect( + result.transformedData?.response?.json?.choices?.[0]?.message?.content + ).toBe( + 'My credit card number is [CREDIT_CARD_1], and my email is [EMAIL_ADDRESS_1]' + ); }); it('should pass text without PII', async () => { diff --git a/plugins/portkey/redactPii.ts b/plugins/portkey/redactPii.ts deleted file mode 100644 index 1cfe06e7f..000000000 --- a/plugins/portkey/redactPii.ts +++ /dev/null @@ -1,75 +0,0 @@ -import { - HookEventType, - PluginContext, - PluginHandler, - PluginParameters, -} from '../types'; -import { getCurrentContentPart, setCurrentContentPart } from '../utils'; -import { detectPII } from './pii'; - -export const handler: PluginHandler = async ( - context: PluginContext, - parameters: PluginParameters, - eventType: HookEventType, - options -) => { - let transformedData: Record = { - request: { - json: null, - }, - response: { - json: null, - }, - }; - - try { - const { content, textArray } = getCurrentContentPart(context, eventType); - - if (!content) { - return { - error: { message: 'request or response json is empty' }, - verdict: true, - data: null, - transformedData, - }; - } - - let mappedResult = await detectPII( - textArray, - parameters.credentials, - options?.env || {} - ); - const { detectedPIICategories, PIIData } = mappedResult[0] || {}; - - const mappedTextArray = mappedResult.map((result) => { - if (result.detectedPIICategories.length > 0 && result.redactedText) { - return result.redactedText; - } - return null; - }); - - setCurrentContentPart( - context, - eventType, - transformedData, - null, - mappedTextArray - ); - return { - error: null, - verdict: true, - data: { - detectedPII: PIIData, - detectedCategories: detectedPIICategories, - }, - transformedData, - }; - } catch (error: any) { - return { - error: error as Error, - verdict: true, - data: null, - transformedData, - }; - } -}; diff --git a/plugins/promptfoo/pii.ts b/plugins/promptfoo/pii.ts index 49cbab210..f7794383c 100644 --- a/plugins/promptfoo/pii.ts +++ b/plugins/promptfoo/pii.ts @@ -4,9 +4,44 @@ import { PluginHandler, PluginParameters, } from '../types'; -import { getText } from '../utils'; +import { getCurrentContentPart, setCurrentContentPart } from '../utils'; import { postPromptfoo } from './globals'; -import { PIIResult, PromptfooResult } from './types'; +import { PIIEntity, PIIResult } from './types'; + +const maskPiiEntries = (text: string, piiEntries: PIIEntity[]): string => { + return piiEntries.reduce((maskedText, entry) => { + const maskText = `[${entry.entity_type.toUpperCase()}]`; + return ( + maskedText.slice(0, entry.start + 1) + + maskText + + maskedText.slice(entry.end) + ); + }, text); +}; + +export const redactPii = async (text: string) => { + if (!text) { + return { maskedText: null, data: null }; + } + const piiObject = { + input: text, + }; + + const result = await postPromptfoo('pii', piiObject); + const piiResult = result.results[0]; + + if (piiResult.flagged) { + // Sort PII entries in reverse order to avoid offset issues when replacing + const sortedPiiEntries = piiResult.payload.pii.sort( + (a, b) => b.start - a.start + ); + const maskedText = maskPiiEntries(text, sortedPiiEntries); + + return { maskedText, data: result.results[0] }; + } + + return { maskedText: null, data: result.results[0] }; +}; export const handler: PluginHandler = async ( context: PluginContext, @@ -16,24 +51,51 @@ export const handler: PluginHandler = async ( let error = null; let verdict = true; let data = null; + let transformedData: Record = { + request: { + json: null, + }, + response: { + json: null, + }, + }; try { - const piiObject = { - input: getText(context, eventType), - }; + const { content, textArray } = getCurrentContentPart(context, eventType); + + if (!content) { + return { + error: { message: 'request or response json is empty' }, + verdict: true, + data: null, + }; + } + + const redact = parameters.redact || false; + const results = await Promise.all(textArray.map(redactPii)); + + const hasPII = results.some((result) => result?.data?.flagged); - const result = await postPromptfoo('pii', piiObject); + const piiData = + results.find((result) => result?.maskedText)?.data ?? results[0]?.data; - // If PII is detected, set verdict to false - if (result.results[0].flagged) { - verdict = false; + if (hasPII && redact) { + const maskedTexts = results.map((result) => result?.maskedText ?? null); + setCurrentContentPart( + context, + eventType, + transformedData, + null, + maskedTexts + ); } - data = result.results[0]; + verdict = !hasPII; + data = piiData; } catch (e: any) { delete e.stack; error = e; } - return { error, verdict, data }; + return { error, verdict, data, transformedData }; }; diff --git a/plugins/promptfoo/promptfoo.test.ts b/plugins/promptfoo/promptfoo.test.ts index abaab7af8..74d25bc67 100644 --- a/plugins/promptfoo/promptfoo.test.ts +++ b/plugins/promptfoo/promptfoo.test.ts @@ -1,4 +1,4 @@ -import { HookEventType } from '../types'; +import { HookEventType, PluginContext } from '../types'; import { handler as guardHandler } from './guard'; import { handler as piiHandler } from './pii'; import { handler as harmHandler } from './harm'; @@ -40,20 +40,160 @@ describe('guard handler', () => { }); describe('pii handler', () => { - it('should detect PII', async () => { + it('should only detect PII', async () => { const eventType = 'beforeRequestHook' as HookEventType; const context = { - request: { text: 'My email is john@example.com and SSN is 123-45-6789' }, + request: { + text: 'My email is john@example.com and SSN is 123-45-6789', + json: { + messages: [ + { + role: 'user', + content: 'My email is john@example.com and SSN is 123-45-6789', + }, + ], + }, + }, + requestType: 'chatComplete', }; const parameters = {}; - const result = await piiHandler(context, parameters, eventType, { - env: {}, - }); + const result = await piiHandler( + context as PluginContext, + parameters, + eventType, + { + env: {}, + } + ); expect(result).toBeDefined(); expect(result.verdict).toBe(false); expect(result.error).toBeNull(); expect(result.data).toBeDefined(); + expect(result.transformedData?.request?.json).toBeNull(); + }); + + it('should detect and redact PII in request text', async () => { + const context = { + request: { + text: 'My SSN is 123-45-6789 and some random text', + json: { + messages: [ + { + role: 'user', + content: 'My SSN is 123-45-6789 and some random text', + }, + ], + }, + }, + requestType: 'chatComplete', + }; + const parameters = { + redact: true, + }; + + const result = await piiHandler( + context as PluginContext, + parameters, + 'beforeRequestHook', + { + env: {}, + } + ); + expect(result.error).toBeNull(); + expect(result.verdict).toBe(false); + expect(result.data).toBeDefined(); + expect(result.transformedData?.request?.json?.messages?.[0]?.content).toBe( + 'My SSN is [SOCIAL_SECURITY_NUMBER] and some random text' + ); + }); + + it('should detect and redact PII in request text with multiple content parts', async () => { + const context = { + request: { + text: 'My SSN is 123-45-6789 My SSN is 123-45-6789 and some random text', + json: { + messages: [ + { + role: 'user', + content: [ + { + type: 'text', + text: 'My SSN is 123-45-6789', + }, + { + type: 'text', + text: 'My SSN is 123-45-6789 and some random text', + }, + ], + }, + ], + }, + }, + requestType: 'chatComplete', + }; + const parameters = { + redact: true, + not: false, + }; + + const result = await piiHandler( + context as PluginContext, + parameters, + 'beforeRequestHook', + { + env: {}, + } + ); + + expect(result.error).toBeNull(); + expect(result.verdict).toBe(false); + expect(result.data).toBeDefined; + expect( + result.transformedData?.request?.json?.messages?.[0]?.content?.[0]?.text + ).toBe('My SSN is [SOCIAL_SECURITY_NUMBER]'); + expect( + result.transformedData?.request?.json?.messages?.[0]?.content?.[1]?.text + ).toBe('My SSN is [SOCIAL_SECURITY_NUMBER] and some random text'); + }); + + it('should detect and redact PII in response text', async () => { + const context = { + response: { + text: 'My SSN is 123-45-6789 and some random text', + json: { + choices: [ + { + message: { + role: 'assistant', + content: 'My SSN is 123-45-6789 and some random text', + }, + }, + ], + }, + }, + requestType: 'chatComplete', + }; + const parameters = { + redact: true, + not: false, + }; + + const result = await piiHandler( + context as PluginContext, + parameters, + 'afterRequestHook', + { + env: {}, + } + ); + + expect(result.error).toBeNull(); + expect(result.verdict).toBe(false); + expect(result.data).toBeDefined(); + expect( + result.transformedData?.response?.json?.choices?.[0]?.message?.content + ).toBe('My SSN is [SOCIAL_SECURITY_NUMBER] and some random text'); }); it('should pass text without PII', async () => { diff --git a/plugins/promptfoo/redactPii.ts b/plugins/promptfoo/redactPii.ts deleted file mode 100644 index d96f5173b..000000000 --- a/plugins/promptfoo/redactPii.ts +++ /dev/null @@ -1,94 +0,0 @@ -import { - HookEventType, - PluginContext, - PluginHandler, - PluginParameters, -} from '../types'; -import { getCurrentContentPart, setCurrentContentPart } from '../utils'; -import { postPromptfoo } from './globals'; -import { PIIResult } from './types'; - -export const redactPii = async (text: string) => { - const piiObject = { - input: text, - }; - - const result = await postPromptfoo('pii', piiObject); - - if (result.results[0].flagged) { - // Sort PII entries in reverse order to avoid offset issues when replacing - const piiEntries = result.results[0].payload.pii.sort( - (a, b) => b.start - a.start - ); - let maskedText = piiObject.input; - // Replace each PII instance with its masked version - for (const entry of piiEntries) { - const maskText = `[${entry.entity_type.toUpperCase()}]`; - maskedText = - maskedText.slice(0, entry.start) + - maskText + - maskedText.slice(entry.end); - } - - return maskedText; - } - - return null; -}; - -export const handler: PluginHandler = async ( - context: PluginContext, - parameters: PluginParameters, - eventType: HookEventType -) => { - let transformedData: Record = { - request: { - json: null, - }, - response: { - json: null, - }, - }; - - try { - const { content, textArray } = getCurrentContentPart(context, eventType); - - if (!content) { - return { - error: { message: 'request or response json is empty' }, - verdict: true, - data: null, - }; - } - - const transformedTextPromise = textArray.map((text) => redactPii(text)); - - const transformedText = await Promise.all(transformedTextPromise); - - setCurrentContentPart( - context, - eventType, - transformedData, - null, - transformedText - ); - - return { - error: null, - verdict: true, - data: - transformedText.filter((text) => text !== null).length > 0 - ? { flagged: true } - : null, - transformedData, - }; - } catch (e: any) { - delete e.stack; - return { - error: e as Error, - verdict: true, - data: null, - transformedData, - }; - } -}; diff --git a/plugins/utils.ts b/plugins/utils.ts index 0652e4fe3..c915f6f02 100644 --- a/plugins/utils.ts +++ b/plugins/utils.ts @@ -145,19 +145,23 @@ export const setCurrentContentPart = ( if (target === 'request') { const currentContent = updatedJson.messages[updatedJson.messages.length - 1].content; + // Only clone messages array if not already cloned + if (!newContent) { + updatedJson.messages = [...json.messages]; + updatedJson.messages[updatedJson.messages.length - 1] = { + ...updatedJson.messages[updatedJson.messages.length - 1], + }; + } + if (Array.isArray(currentContent)) { - // Only clone messages array if not already cloned - if (!newContent) { - updatedJson.messages = [...json.messages]; - updatedJson.messages[updatedJson.messages.length - 1] = { - ...updatedJson.messages[updatedJson.messages.length - 1], - }; - } updatedJson.messages[updatedJson.messages.length - 1].content = currentContent.map((item: any, index: number) => ({ ...item, text: textArray[index] || item.text, })); + } else { + updatedJson.messages[updatedJson.messages.length - 1].content = + textArray[0] || currentContent; } transformedData.request.json = updatedJson; } else { diff --git a/src/middlewares/hooks/index.ts b/src/middlewares/hooks/index.ts index ce52a1c4f..651b1df2e 100644 --- a/src/middlewares/hooks/index.ts +++ b/src/middlewares/hooks/index.ts @@ -329,29 +329,6 @@ export class HooksManager { return { ...hookResult, skipped: true }; } - if (hook.type === HookType.MUTATOR && hook.checks) { - for (const check of hook.checks) { - const result = await this.executeFunction( - span.getContext(), - check, - hook.eventType, - options - ); - if ( - result.transformedData && - (result.transformedData.response.json || - result.transformedData.request.json) - ) { - span.setContextAfterTransform( - result.transformedData.response.json, - result.transformedData.request.json - ); - } - delete result.transformedData; - checkResults.push(result); - } - } - if (hook.type === HookType.GUARDRAIL && hook.checks) { checkResults = await Promise.all( hook.checks @@ -365,6 +342,20 @@ export class HooksManager { ) ) ); + + checkResults.forEach((checkResult) => { + if ( + checkResult.transformedData && + (checkResult.transformedData.response.json || + checkResult.transformedData.request.json) + ) { + span.setContextAfterTransform( + checkResult.transformedData.response.json, + checkResult.transformedData.request.json + ); + } + delete checkResult.transformedData; + }); } hookResult = {