Skip to content

Commit

Permalink
Merge pull request #246 from Oneirocom/google-and-system-prompt
Browse files Browse the repository at this point in the history
Support Google models in generation
  • Loading branch information
lalalune authored Nov 10, 2024
2 parents ce4d327 + c7e9bf0 commit 9a04908
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 40 deletions.
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ OPENAI_API_KEY=sk-* # OpenAI API key, starting with sk-
REDPILL_API_KEY= # REDPILL API Key
GROQ_API_KEY=gsk_*
OPENROUTER_API_KEY=
GOOGLE_GENERATIVE_AI_API_KEY= # Gemini API key

ELEVENLABS_XI_API_KEY= # API key from elevenlabs

Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ DISCORD_APPLICATION_ID=
DISCORD_API_TOKEN= # Bot token
OPENAI_API_KEY=sk-* # OpenAI API key, starting with sk-
ELEVENLABS_XI_API_KEY= # API key from elevenlabs
GOOGLE_GENERATIVE_AI_API_KEY= # Gemini API key
# ELEVENLABS SETTINGS
ELEVENLABS_MODEL_ID=eleven_multilingual_v2
Expand Down
29 changes: 25 additions & 4 deletions packages/core/src/generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { default as tiktoken, TiktokenModel } from "tiktoken";
import Together from "together-ai";
import { elizaLogger } from "./index.ts";
import models from "./models.ts";
import { createGoogleGenerativeAI } from "@ai-sdk/google";
import {
parseBooleanFromText,
parseJsonArrayFromText,
Expand Down Expand Up @@ -104,6 +105,25 @@ export async function generateText({
break;
}

case ModelProviderName.GOOGLE:
const google = createGoogleGenerativeAI();

const { text: anthropicResponse } = await aiGenerateText({
model: google(model),
prompt: context,
system:
runtime.character.system ??
settings.SYSTEM_PROMPT ??
undefined,
temperature: temperature,
maxTokens: max_response_length,
frequencyPenalty: frequency_penalty,
presencePenalty: presence_penalty,
});

response = anthropicResponse;
break;

case ModelProviderName.ANTHROPIC: {
elizaLogger.log("Initializing Anthropic model.");

Expand Down Expand Up @@ -214,7 +234,6 @@ export async function generateText({
break;
}


case ModelProviderName.OPENROUTER: {
elizaLogger.log("Initializing OpenRouter model.");
const serverUrl = models[provider].endpoint;
Expand All @@ -238,7 +257,6 @@ export async function generateText({
break;
}


case ModelProviderName.OLLAMA:
{
console.log("Initializing Ollama model.");
Expand Down Expand Up @@ -425,10 +443,13 @@ export async function generateTrueOrFalse({
modelClass: string;
}): Promise<boolean> {
let retryDelay = 1000;
console.log("modelClass", modelClass)
console.log("modelClass", modelClass);

const stop = Array.from(
new Set([...(models[runtime.modelProvider].settings.stop || []), ["\n"]])
new Set([
...(models[runtime.modelProvider].settings.stop || []),
["\n"],
])
) as string[];

while (true) {
Expand Down
9 changes: 4 additions & 5 deletions packages/core/src/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ const models: Models = {
temperature: 0.3,
},
model: {
[ModelClass.SMALL]: "gemini-1.5-flash",
[ModelClass.MEDIUM]: "gemini-1.5-flash",
[ModelClass.LARGE]: "gemini-1.5-pro",
[ModelClass.SMALL]: "gemini-1.5-flash-latest",
[ModelClass.MEDIUM]: "gemini-1.5-flash-latest",
[ModelClass.LARGE]: "gemini-1.5-pro-latest",
[ModelClass.EMBEDDING]: "text-embedding-004",
},
},
Expand Down Expand Up @@ -187,8 +187,7 @@ const models: Models = {
settings.LARGE_OPENROUTER_MODEL ||
settings.OPENROUTER_MODEL ||
"nousresearch/hermes-3-llama-3.1-405b",
[ModelClass.EMBEDDING]:
"text-embedding-3-small",
[ModelClass.EMBEDDING]: "text-embedding-3-small",
},
},
[ModelProviderName.OLLAMA]: {
Expand Down
6 changes: 3 additions & 3 deletions packages/core/src/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -498,14 +498,14 @@ export class AgentRuntime implements IAgentRuntime {
* @returns The results of the evaluation.
*/
async evaluate(message: Memory, state?: State, didRespond?: boolean) {
console.log("Evaluate: ", didRespond)
console.log("Evaluate: ", didRespond);
const evaluatorPromises = this.evaluators.map(
async (evaluator: Evaluator) => {
console.log("Evaluating", evaluator.name)
console.log("Evaluating", evaluator.name);
if (!evaluator.handler) {
return null;
}
if(!didRespond && !evaluator.alwaysRun) {
if (!didRespond && !evaluator.alwaysRun) {
return null;
}
const result = await evaluator.validate(this, message, state);
Expand Down
6 changes: 5 additions & 1 deletion packages/core/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,11 @@ export interface IAgentRuntime {
state?: State,
callback?: HandlerCallback
): Promise<void>;
evaluate(message: Memory, state?: State, didRespond?: boolean): Promise<string[]>;
evaluate(
message: Memory,
state?: State,
didRespond?: boolean
): Promise<string[]>;
ensureParticipantExists(userId: UUID, roomId: UUID): Promise<void>;
ensureUserExists(
userId: UUID,
Expand Down
27 changes: 0 additions & 27 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 9a04908

Please sign in to comment.