feat: Expanded Vertex AI support #55

Merged · 2 commits · Jan 13, 2025
116 changes: 88 additions & 28 deletions drivers/src/vertexai/index.ts
@@ -1,12 +1,15 @@
import { GenerateContentRequest, VertexAI } from "@google-cloud/vertexai";
import { AIModel, AbstractDriver, Completion, CompletionChunkObject, DriverOptions, EmbeddingsResult, ExecutionOptions, ModelSearchPayload, PromptOptions, PromptSegment } from "@llumiverse/core";
import { AIModel, AbstractDriver, Completion, CompletionChunkObject, DriverOptions, EmbeddingsResult, ExecutionOptions, ImageGenExecutionOptions, ImageGeneration, Modalities, ModelSearchPayload, PromptOptions, PromptSegment } from "@llumiverse/core";
import { FetchClient } from "api-fetch-client";
import { GoogleAuth, GoogleAuthOptions } from "google-auth-library";
import { JSONClient } from "google-auth-library/build/src/auth/googleauth.js";
import { TextEmbeddingsOptions, getEmbeddingsForText } from "./embeddings/embeddings-text.js";
import { BuiltinModels, getModelDefinition } from "./models.js";
import { getModelDefinition } from "./models.js";
import { EmbeddingsOptions } from "@llumiverse/core";
import { getEmbeddingsForImages } from "./embeddings/embeddings-image.js";
import { v1beta1 } from '@google-cloud/aiplatform';
import { AnthropicVertex } from '@anthropic-ai/vertex-sdk';
import { ImagenModelDefinition } from "./models/imagen.js";


export interface VertexAIDriverOptions extends DriverOptions {
@@ -15,18 +18,28 @@ export interface VertexAIDriverOptions {
googleAuthOptions?: GoogleAuthOptions;
}

export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, GenerateContentRequest> {
//General Prompt type for VertexAI
export type VertexAIPrompt = GenerateContentRequest;

export function trimModelName(model: string) {
const i = model.lastIndexOf('@');
return i > -1 ? model.substring(0, i) : model;
}

export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, VertexAIPrompt> {
static PROVIDER = "vertexai";
provider = VertexAIDriver.PROVIDER;

//aiplatform: v1.ModelServiceClient;
aiplatform: v1beta1.ModelServiceClient;
vertexai: VertexAI;
fetchClient: FetchClient;
authClient: JSONClient | GoogleAuth<JSONClient>;

anthropicClient: AnthropicVertex | undefined;

constructor( options: VertexAIDriverOptions) {
super(options);
//this.aiplatform = new v1.ModelServiceClient();

this.anthropicClient = undefined;

this.authClient = options.googleAuthOptions?.authClient ?? new GoogleAuth(options.googleAuthOptions);

Expand All @@ -43,43 +56,90 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Genera
const token = await this.authClient.getAccessToken();
return `Bearer ${token}`;
});
// this.aiplatform = new v1.ModelServiceClient({
// projectId: this.options.project,
// apiEndpoint: `${this.options.region}-${API_BASE_PATH}`,
// });
this.aiplatform = new v1beta1.ModelServiceClient({
projectId: this.options.project,
apiEndpoint: `${this.options.region}-${API_BASE_PATH}`,
});
}

public getAnthropicClient() : AnthropicVertex {
//Lazy initialisation
if (!this.anthropicClient) {
this.anthropicClient = new AnthropicVertex({region: "us-east5", projectId: process.env.GOOGLE_PROJECT_ID});
}
return this.anthropicClient;
}

protected canStream(options: ExecutionOptions): Promise<boolean> {
if (options.output_modality == Modalities.image) {
return Promise.resolve(false);
}
return Promise.resolve(getModelDefinition(options.model).model.can_stream === true);
}

public createPrompt(segments: PromptSegment[], options: PromptOptions): Promise<GenerateContentRequest> {
public createPrompt(segments: PromptSegment[], options: PromptOptions): Promise<VertexAIPrompt> {
return getModelDefinition(options.model).createPrompt(this, segments, options);
}

async requestCompletion(prompt: GenerateContentRequest, options: ExecutionOptions): Promise<Completion<any>> {
async requestCompletion(prompt: VertexAIPrompt, options: ExecutionOptions): Promise<Completion<any>> {
return getModelDefinition(options.model).requestCompletion(this, prompt, options);
}
async requestCompletionStream(prompt: GenerateContentRequest, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
async requestCompletionStream(prompt: VertexAIPrompt, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
return getModelDefinition(options.model).requestCompletionStream(this, prompt, options);
}

async requestImageGeneration(_prompt: GenerateContentRequest, _options: ImageGenExecutionOptions): Promise <Completion<ImageGeneration>> {
const splits = _options.model.split("/");
const modelName = trimModelName(splits[splits.length - 1]);
return new ImagenModelDefinition(modelName).requestImageGeneration(this, _prompt, _options);
}

async listModels(_params?: ModelSearchPayload): Promise<AIModel<string>[]> {
return BuiltinModels;
// try {
// const response = await this.fetchClient.get('/publishers/google/models/gemini-pro');
// console.log('>>>>>>>>', response);
// } catch (err: any) {
// console.error('+++++VETREX ERROR++++++', err);
// throw err;
// }

// TODO uncomment this to use apiplatform instead of the fetch client
// const response = await this.aiplatform.listModels({
// parent: `projects/${this.options.project}/locations/${this.options.region}`,
// });

return []; //TODO
let models: AIModel<string>[] = [];
const modelGarden = new v1beta1.ModelGardenServiceClient({
projectId: this.options.project,
apiEndpoint: `${this.options.region}-${API_BASE_PATH}`,
});

//Project specific deployed models
const [response] = await this.aiplatform.listModels({
parent: `projects/${this.options.project}/locations/${this.options.region}`,
});
models = models.concat(response.map(model => ({
id: model.name?.split('/').pop() ?? '',
name: model.displayName ?? '',
provider: 'vertexai',
})));

//Model Garden Publisher models - Pretrained models
const publishers = ['google', 'anthropic']
const supportedModels = {google: ['gemini','imagen'], anthropic: ['claude']}

for (const publisher of publishers) {
const [response] = await modelGarden.listPublisherModels({
parent: `publishers/${publisher}`,
orderBy: 'name',
//filter: `name eq name`,
//list_all_versions: 'true',
//As of 27/12/24 list_all_versions is not supported yet, see if https://github.com/googleapis/google-cloud-node/pull/5836 is merged
});

models = models.concat(response.map(model => ({
id: model.name ?? '',
name: model.name?.split('/').pop() ?? '',
provider: 'vertexai',
owner: publisher,
})).filter(model => {
const modelFamily = supportedModels[publisher as keyof typeof supportedModels];
for (const family of modelFamily) {
if (model.name.includes(family)) {
return true;
}
}
}));
}

return models;
}

validateConnection(): Promise<boolean> {
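For reviewers who want to try the branch locally, here is a minimal usage sketch of the updated driver. It only calls methods and option names visible in this diff (`project`, `region`, `listModels`, `createPrompt`, `requestCompletion`); the `PromptSegment` shape, the exact `ExecutionOptions` fields, and the model ID used are assumptions, so the casts are placeholders rather than the real `@llumiverse/core` types.

```ts
import { VertexAIDriver } from "./index.js";

async function demo() {
    // "project" and "region" are the option names referenced inside the driver;
    // any other required DriverOptions fields are elided behind the cast.
    const driver = new VertexAIDriver({
        project: "my-gcp-project",
        region: "us-central1",
    } as any);

    // listModels() now merges project-deployed models with Model Garden publisher
    // models (the google and anthropic families listed in supportedModels).
    const models = await driver.listModels();
    console.log(models.map(m => m.id));

    // Model IDs in the "publishers/<publisher>/models/<name>" form are what
    // getModelDefinition() expects; this particular ID is only an example.
    const options = { model: "publishers/google/models/gemini-1.5-pro" } as any;

    const segments = [{ role: "user", content: "Say hello" }] as any;
    const prompt = await driver.createPrompt(segments, options);
    const completion = await driver.requestCompletion(prompt, options);
    console.log(completion);
}

demo().catch(console.error);
```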
108 changes: 13 additions & 95 deletions drivers/src/vertexai/models.ts
@@ -1,9 +1,7 @@
import { AIModel, Completion, CompletionChunkObject, ExecutionOptions, ModelType, PromptOptions, PromptSegment } from "@llumiverse/core";
import { VertexAIDriver } from "./index.js";
import { AIModel, Completion, CompletionChunkObject, ExecutionOptions, PromptOptions, PromptSegment } from "@llumiverse/core";
import { VertexAIDriver , trimModelName} from "./index.js";
import { GeminiModelDefinition } from "./models/gemini.js";



import { ClaudeModelDefinition } from "./models/claude.js";

export interface ModelDefinition<PromptT = any> {
model: AIModel;
@@ -13,97 +11,17 @@ export interface ModelDefinition<PromptT = any> {
requestCompletionStream: (driver: VertexAIDriver, promp: PromptT, options: ExecutionOptions) => Promise<AsyncIterable<CompletionChunkObject>>;
}

export function getModelName(model: string) {
const i = model.lastIndexOf('@');
return i > -1 ? model.substring(0, i) : model;
}

export function getModelDefinition(model: string): ModelDefinition {
const modelName = getModelName(model);
const def = Models[modelName];
if (!def) {
throw new Error(`Unknown model ${model}`);
const splits = model.split("/");
const publisher = splits[1];
const modelName = trimModelName(splits[splits.length - 1]);

if (publisher?.includes("anthropic")) {
return new ClaudeModelDefinition(modelName);
} else if (publisher?.includes("google")) {
return new GeminiModelDefinition(modelName);
}
return def;
}

export function getAIModels() {
return Object.values(Models).map(m => m.model);
}

// Builtin models. VertexAI doesn't provide an API to list models. so we have to hardcode them here.
export const BuiltinModels: AIModel<string>[] = [
{
id: "gemini-1.5-flash-001",
name: "Gemini Pro 1.5 Flash 001",
provider: "vertexai",
owner: "google",
type: ModelType.MultiModal,
can_stream: true,
is_multimodal: true
},
{
id: "gemini-1.5-flash-002",
name: "Gemini Pro 1.5 Flash 002",
provider: "vertexai",
owner: "google",
type: ModelType.MultiModal,
can_stream: true,
is_multimodal: true
},
{
id: "gemini-1.5-flash",
name: "Gemini Pro 1.5 Flash",
provider: "vertexai",
owner: "google",
type: ModelType.MultiModal,
can_stream: true,
is_multimodal: true
},
{
id: "gemini-1.5-pro-001",
name: "Gemini Pro 1.5 Pro 001",
provider: "vertexai",
owner: "google",
type: ModelType.MultiModal,
can_stream: true,
is_multimodal: true
},
{
id: "gemini-1.5-pro-002",
name: "Gemini Pro 1.5 Pro 002",
provider: "vertexai",
owner: "google",
type: ModelType.MultiModal,
can_stream: true,
is_multimodal: true
},
{
id: "gemini-1.5-pro",
name: "Gemini Pro 1.5 Pro",
provider: "vertexai",
owner: "google",
type: ModelType.MultiModal,
can_stream: true,
is_multimodal: true
},
{
id: "gemini-1.0-pro",
name: "Gemini Pro 1.0",
provider: "vertexai",
owner: "google",
type: ModelType.Text,
can_stream: true,
},
]

// Must be updated when adding new models
const Models: Record<string, ModelDefinition> = {
"gemini-1.5-flash-002": new GeminiModelDefinition("gemini-1.5-flash-002"),
"gemini-1.5-flash-001": new GeminiModelDefinition("gemini-1.5-flash-001"),
"gemini-1.5-flash": new GeminiModelDefinition("gemini-1.5-flash"),
"gemini-1.5-pro-002": new GeminiModelDefinition("gemini-1.5-pro-002"),
"gemini-1.5-pro-001": new GeminiModelDefinition("gemini-1.5-pro-001"),
"gemini-1.5-pro": new GeminiModelDefinition("gemini-1.5-pro"),
"gemini-1.0-pro": new GeminiModelDefinition(),
//Fallback, assume it is Gemini.
return new GeminiModelDefinition(modelName);
}
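To make the new routing in `getModelDefinition` easier to review, the sketch below traces a few representative model IDs through `trimModelName` and the publisher check; the concrete IDs are illustrative, not a list the driver guarantees to support.

```ts
// Illustrative traces of getModelDefinition(); the IDs below are examples only.

getModelDefinition("publishers/google/models/gemini-1.5-pro-002");
// splits[1] === "google"            -> new GeminiModelDefinition("gemini-1.5-pro-002")

getModelDefinition("publishers/anthropic/models/claude-3-5-sonnet@20240620");
// splits[1] === "anthropic", "@20240620" trimmed by trimModelName
//                                    -> new ClaudeModelDefinition("claude-3-5-sonnet")

getModelDefinition("gemini-1.5-flash");
// no "/" so splits[1] is undefined  -> falls through to the Gemini fallback
```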