diff --git a/src/clients/discord/index.ts b/src/clients/discord/index.ts index e28cba03fa..e621f06bc5 100644 --- a/src/clients/discord/index.ts +++ b/src/clients/discord/index.ts @@ -41,6 +41,8 @@ import settings from "../../core/settings.ts"; import { AudioMonitor } from "./audioMonitor.ts"; import { commands } from "./commands.ts"; import { InterestChannels, ResponseType } from "./types.ts"; +import ImageRecognitionService from "../../services/imageRecognition.ts" +import { extractAnswer } from "../../core/util.ts"; import { SpeechSynthesizer } from "../../services/speechSynthesis.ts"; import WavEncoder from "wav-encoder"; @@ -101,6 +103,7 @@ export class DiscordClient extends EventEmitter { private agent: Agent; private bio: string; private transcriber: any; + private imageRecognitionService: ImageRecognitionService; speechSynthesizer: SpeechSynthesizer; constructor(agent: Agent, bio: string) { @@ -124,6 +127,9 @@ export class DiscordClient extends EventEmitter { this.initializeTranscriber(); + this.imageRecognitionService = new ImageRecognitionService(); + this.imageRecognitionService.initialize(); + this.client.once(Events.ClientReady, async (readyClient: { user: { tag: any; id: any } }) => { console.log(`Logged in as ${readyClient.user?.tag}`); console.log("Use this URL to add the bot to your server:"); @@ -210,6 +216,12 @@ export class DiscordClient extends EventEmitter { const user_id = message.author.id as UUID; const userName = message.author.username; const channelId = message.channel.id; + + // Check for image attachments + if (message.attachments.size > 0) { + await this.handleImageRecognition(message); + } + const textContent = message.content; try { @@ -316,6 +328,20 @@ export class DiscordClient extends EventEmitter { } } + private async handleImageRecognition(message: DiscordMessage) { + const attachment = message.attachments.first(); + if (attachment && attachment.contentType?.startsWith('image/')) { + try { + const recognizedText = await this.imageRecognitionService.recognizeImage(attachment.url); + const description = extractAnswer(recognizedText[0]); + // Add the image description to the completion context + message.content += `\nImage description: ${description}`; + } catch (error) { + console.error('Error recognizing image:', error); + await message.reply('Sorry, I encountered an error while processing the image.'); + } + } + } private async ensureUserExists(agentId: UUID, userName: string, botToken: string | null = null) { if (!userName && botToken) { diff --git a/src/clients/twitter/base.ts b/src/clients/twitter/base.ts index 3fcba683b9..85a2551b48 100644 --- a/src/clients/twitter/base.ts +++ b/src/clients/twitter/base.ts @@ -16,6 +16,7 @@ import settings from "../../core/settings.ts"; import { fileURLToPath } from 'url'; import ImageRecognitionService from "../../services/imageRecognition.ts"; +import { extractAnswer } from "../../core/util.ts"; const __filename = fileURLToPath(import.meta.url); const __dirname = path.dirname(__filename); @@ -123,8 +124,10 @@ export class ClientBase extends EventEmitter { async describeImage(imageUrl: string): Promise { try { - const description = await this.imageRecognitionService.recognizeImage(imageUrl); - return description[0] || 'Unable to describe the image.'; + const recognizedText = await this.imageRecognitionService.recognizeImage(imageUrl); + const description = extractAnswer(recognizedText[0]); + + return description || 'Unable to describe the image.'; } catch (error) { console.error('Error describing image:', error); return 'Error occurred while describing the image.'; diff --git a/src/core/util.ts b/src/core/util.ts index 0529e69c4d..22a338ee2d 100644 --- a/src/core/util.ts +++ b/src/core/util.ts @@ -41,4 +41,11 @@ export function prependWavHeader(readable: Readable, audioLength: number, sample passThrough.end(); }); return passThrough; -} \ No newline at end of file +} + + +export function extractAnswer(text: string): string { + const startIndex = text.indexOf('Answer: ') + 8; + const endIndex = text.indexOf('<|endoftext|>', 11); + return text.slice(startIndex, endIndex); +}; \ No newline at end of file diff --git a/src/services/imageRecognition.ts b/src/services/imageRecognition.ts index 5b0c61218d..a746708799 100644 --- a/src/services/imageRecognition.ts +++ b/src/services/imageRecognition.ts @@ -9,7 +9,8 @@ class ImageRecognitionService { constructor() { this.modelId = 'Xenova/moondream2'; - this.device = 'webgpu'; + // this.device = 'webgpu'; + this.device = 'cpu'; this.model = null; this.processor = null; this.tokenizer = null;