Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: video generation plugin #394

Merged
merged 2 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions packages/core/src/tests/videoGeneration.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import { IAgentRuntime, Memory, State } from "@ai16z/eliza";
import { videoGenerationPlugin } from "../index";

// Mock the fetch function
global.fetch = jest.fn();

// Mock the fs module
jest.mock('fs', () => ({
writeFileSync: jest.fn(),
existsSync: jest.fn(),
mkdirSync: jest.fn(),
}));

describe('Video Generation Plugin', () => {
let mockRuntime: IAgentRuntime;
let mockCallback: jest.Mock;

beforeEach(() => {
// Reset mocks
jest.clearAllMocks();

// Setup mock runtime
mockRuntime = {
getSetting: jest.fn().mockReturnValue('mock-api-key'),
agentId: 'mock-agent-id',
composeState: jest.fn().mockResolvedValue({}),
} as unknown as IAgentRuntime;

mockCallback = jest.fn();

// Setup fetch mock for successful response
(global.fetch as jest.Mock).mockImplementation(() =>
Promise.resolve({
ok: true,
json: () => Promise.resolve({
id: 'mock-generation-id',
status: 'completed',
assets: {
video: 'https://example.com/video.mp4'
}
}),
text: () => Promise.resolve(''),
})
);
});

it('should validate when API key is present', async () => {
const mockMessage = {} as Memory;
const result = await videoGenerationPlugin.actions[0].validate(mockRuntime, mockMessage);
expect(result).toBe(true);
expect(mockRuntime.getSetting).toHaveBeenCalledWith('LUMA_API_KEY');
});

it('should handle video generation request', async () => {
const mockMessage = {
content: {
text: 'Generate a video of a sunset'
}
} as Memory;
const mockState = {} as State;

await videoGenerationPlugin.actions[0].handler(
mockRuntime,
mockMessage,
mockState,
{},
mockCallback
);

// Check initial callback
expect(mockCallback).toHaveBeenCalledWith(
expect.objectContaining({
text: expect.stringContaining('I\'ll generate a video based on your prompt')
})
);

// Check final callback with video
expect(mockCallback).toHaveBeenCalledWith(
expect.objectContaining({
text: 'Here\'s your generated video!',
attachments: expect.arrayContaining([
expect.objectContaining({
source: 'videoGeneration'
})
])
}),
expect.arrayContaining([expect.stringMatching(/generated_video_.*\.mp4/)])
);
});

it('should handle API errors gracefully', async () => {
// Mock API error
(global.fetch as jest.Mock).mockImplementationOnce(() =>
Promise.resolve({
ok: false,
status: 500,
statusText: 'Internal Server Error',
text: () => Promise.resolve('API Error'),
})
);

const mockMessage = {
content: {
text: 'Generate a video of a sunset'
}
} as Memory;
const mockState = {} as State;

await videoGenerationPlugin.actions[0].handler(
mockRuntime,
mockMessage,
mockState,
{},
mockCallback
);

// Check error callback
expect(mockCallback).toHaveBeenCalledWith(
expect.objectContaining({
text: expect.stringContaining('Video generation failed'),
error: true
})
);
});
});
18 changes: 18 additions & 0 deletions packages/plugin-video-generation/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"name": "@ai16z/plugin-video-generation",
"version": "0.0.1",
"main": "dist/index.js",
"type": "module",
"types": "dist/index.d.ts",
"dependencies": {
"@ai16z/eliza": "workspace:*",
"tsup": "^8.3.5"
},
"scripts": {
"build": "tsup --format esm --dts",
"dev": "tsup --watch"
},
"peerDependencies": {
"whatwg-url": "7.1.0"
}
}
4 changes: 4 additions & 0 deletions packages/plugin-video-generation/src/constants.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
export const LUMA_CONSTANTS = {
API_URL: 'https://api.lumalabs.ai/dream-machine/v1/generations',
API_KEY_SETTING: "LUMA_API_KEY" // The setting name to fetch from runtime
};
221 changes: 221 additions & 0 deletions packages/plugin-video-generation/src/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
import { elizaLogger } from "@ai16z/eliza/src/logger.ts";
import {
Action,
HandlerCallback,
IAgentRuntime,
Memory,
Plugin,
State,
} from "@ai16z/eliza/src/types.ts";
import fs from "fs";
import { LUMA_CONSTANTS } from './constants';

const generateVideo = async (prompt: string, runtime: IAgentRuntime) => {
const API_KEY = runtime.getSetting(LUMA_CONSTANTS.API_KEY_SETTING);

try {
elizaLogger.log("Starting video generation with prompt:", prompt);

const response = await fetch(LUMA_CONSTANTS.API_URL, {
method: 'POST',
headers: {
'Authorization': `Bearer ${API_KEY}`,
'accept': 'application/json',
'Content-Type': 'application/json'
},
body: JSON.stringify({ prompt })
});

if (!response.ok) {
const errorText = await response.text();
elizaLogger.error("Luma API error:", {
status: response.status,
statusText: response.statusText,
error: errorText
});
throw new Error(`Luma API error: ${response.statusText} - ${errorText}`);
}

const data = await response.json();
elizaLogger.log("Generation request successful, received response:", data);

// Poll for completion
let status = data.status;
let videoUrl = null;
const generationId = data.id;

while (status !== 'completed' && status !== 'failed') {
await new Promise(resolve => setTimeout(resolve, 5000)); // Wait 5 seconds

const statusResponse = await fetch(`${LUMA_CONSTANTS.API_URL}/${generationId}`, {
method: 'GET',
headers: {
'Authorization': `Bearer ${API_KEY}`,
'accept': 'application/json'
}
});

if (!statusResponse.ok) {
const errorText = await statusResponse.text();
elizaLogger.error("Status check error:", {
status: statusResponse.status,
statusText: statusResponse.statusText,
error: errorText
});
throw new Error('Failed to check generation status: ' + errorText);
}

const statusData = await statusResponse.json();
elizaLogger.log("Status check response:", statusData);

status = statusData.state;
if (status === 'completed') {
videoUrl = statusData.assets?.video;
}
}

if (status === 'failed') {
throw new Error('Video generation failed');
}

if (!videoUrl) {
throw new Error('No video URL in completed response');
}

return {
success: true,
data: videoUrl
};
} catch (error) {
elizaLogger.error("Video generation error:", error);
return {
success: false,
error: error.message || 'Unknown error occurred'
};
}
}

const videoGeneration: Action = {
name: "GENERATE_VIDEO",
similes: [
"VIDEO_GENERATION",
"VIDEO_GEN",
"CREATE_VIDEO",
"MAKE_VIDEO",
"RENDER_VIDEO",
"ANIMATE",
"CREATE_ANIMATION",
"VIDEO_CREATE",
"VIDEO_MAKE"
],
description: "Generate a video based on a text prompt",
validate: async (runtime: IAgentRuntime, message: Memory) => {
elizaLogger.log("Validating video generation action");
const lumaApiKey = runtime.getSetting("LUMA_API_KEY");
elizaLogger.log("LUMA_API_KEY present:", !!lumaApiKey);
return !!lumaApiKey;
},
handler: async (
runtime: IAgentRuntime,
message: Memory,
state: State,
options: any,
callback: HandlerCallback
) => {
elizaLogger.log("Video generation request:", message);

// Clean up the prompt by removing mentions and commands
let videoPrompt = message.content.text
.replace(/<@\d+>/g, '') // Remove mentions
.replace(/generate video|create video|make video|render video/gi, '') // Remove commands
.trim();

if (!videoPrompt || videoPrompt.length < 5) {
callback({
text: "Could you please provide more details about what kind of video you'd like me to generate? For example: 'Generate a video of a sunset on a beach' or 'Create a video of a futuristic city'",
});
return;
}

elizaLogger.log("Video prompt:", videoPrompt);

callback({
text: `I'll generate a video based on your prompt: "${videoPrompt}". This might take a few minutes...`,
});

try {
const result = await generateVideo(videoPrompt, runtime);

if (result.success && result.data) {
// Download the video file
const response = await fetch(result.data);
const arrayBuffer = await response.arrayBuffer();
const videoFileName = `content_cache/generated_video_${Date.now()}.mp4`;

// Save video file
fs.writeFileSync(videoFileName, Buffer.from(arrayBuffer));

callback({
text: "Here's your generated video!",
attachments: [
{
id: crypto.randomUUID(),
url: result.data,
title: "Generated Video",
source: "videoGeneration",
description: videoPrompt,
text: videoPrompt,
},
],
}, [videoFileName]); // Add the video file to the attachments
} else {
callback({
text: `Video generation failed: ${result.error}`,
error: true
});
}
} catch (error) {
elizaLogger.error(`Failed to generate video. Error: ${error}`);
callback({
text: `Video generation failed: ${error.message}`,
error: true
});
}
},
examples: [
[
{
user: "{{user1}}",
content: { text: "Generate a video of a cat playing piano" },
},
{
user: "{{agentName}}",
content: {
text: "I'll create a video of a cat playing piano for you",
action: "GENERATE_VIDEO"
},
}
],
[
{
user: "{{user1}}",
content: { text: "Can you make a video of a sunset at the beach?" },
},
{
user: "{{agentName}}",
content: {
text: "I'll generate a beautiful beach sunset video for you",
action: "GENERATE_VIDEO"
},
}
]
]
} as Action;

export const videoGenerationPlugin: Plugin = {
name: "videoGeneration",
description: "Generate videos using Luma AI",
actions: [videoGeneration],
evaluators: [],
providers: [],
};
15 changes: 15 additions & 0 deletions packages/plugin-video-generation/tsconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"extends": "../../tsconfig.json",
"compilerOptions": {
"outDir": "dist",
"rootDir": ".",
"module": "ESNext",
"moduleResolution": "Bundler",
"types": [
"node"
]
},
"include": [
"src"
]
}
Loading
Loading