Ollama LLM provider tools support #14623

Merged · 1 commit · Jan 22, 2025
170 changes: 134 additions & 36 deletions packages/ai-ollama/src/node/ollama-language-model.ts
@@ -20,7 +20,9 @@ import {
     LanguageModelRequest,
     LanguageModelRequestMessage,
     LanguageModelResponse,
+    LanguageModelStreamResponse,
     LanguageModelStreamResponsePart,
+    ToolCall,
     ToolRequest
 } from '@theia/ai-core';
 import { CancellationToken } from '@theia/core';
@@ -31,7 +33,9 @@ export const OllamaModelIdentifier = Symbol('OllamaModelIdentifier');
 export class OllamaModel implements LanguageModel {
 
     protected readonly DEFAULT_REQUEST_SETTINGS: Partial<Omit<ChatRequest, 'stream' | 'model'>> = {
-        keep_alive: '15m'
+        keep_alive: '15m',
+        // options see: https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
+        options: {}
     };
 
     readonly providerId = 'ollama';
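
(For context, not part of this diff: the empty options object is the hook for the Ollama parameters documented at the link above. A minimal sketch of a settings bag a caller might pass; the values are illustrative, and getSettings() below forwards the bag verbatim as ChatRequest.options.)

    // request.settings (or defaultRequestSettings) for this provider
    const settings = {
        temperature: 0.2, // lower values give more deterministic sampling
        num_ctx: 4096     // context window size in tokens
    };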
@@ -50,62 +54,125 @@ export class OllamaModel implements LanguageModel {
         public defaultRequestSettings?: { [key: string]: unknown }
     ) { }
 
+    async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
+        const settings = this.getSettings(request);
+        const ollama = this.initializeOllama();
+
+        const ollamaRequest: ExtendedChatRequest = {
+            model: this.model,
+            ...this.DEFAULT_REQUEST_SETTINGS,
+            ...settings,
+            messages: request.messages.map(this.toOllamaMessage),
+            tools: request.tools?.map(this.toOllamaTool)
+        };
+        const structured = request.response_format?.type === 'json_schema';
+        return this.dispatchRequest(ollama, ollamaRequest, structured, cancellationToken);
+    }
+
     /**
      * Retrieves the settings for the chat request, merging the request-specific settings with the default settings.
      * @param request The language model request containing specific settings.
      * @returns A partial ChatRequest object containing the merged settings.
      */
     protected getSettings(request: LanguageModelRequest): Partial<ChatRequest> {
         const settings = request.settings ?? this.defaultRequestSettings ?? {};
         return {
             options: settings as Partial<Options>
         };
     }
 
-    async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
-        const settings = this.getSettings(request);
-        const ollama = this.initializeOllama();
+    protected async dispatchRequest(ollama: Ollama, ollamaRequest: ExtendedChatRequest, structured: boolean, cancellation?: CancellationToken): Promise<LanguageModelResponse> {
 
+        // Handle structured output request
+        if (structured) {
+            return this.handleStructuredOutputRequest(ollama, ollamaRequest);
+        }
+
-        if (request.response_format?.type === 'json_schema') {
-            return this.handleStructuredOutputRequest(ollama, request);
+        // Handle tool request - response may call tools
+        if (ollamaRequest.tools && ollamaRequest.tools?.length > 0) {
+            return this.handleToolsRequest(ollama, ollamaRequest);
         }
 
+        // Handle standard chat request
         const response = await ollama.chat({
-            model: this.model,
-            ...this.DEFAULT_REQUEST_SETTINGS,
-            ...settings,
-            messages: request.messages.map(this.toOllamaMessage),
-            stream: true,
-            tools: request.tools?.map(this.toOllamaTool),
+            ...ollamaRequest,
+            stream: true
         });
+        return this.handleCancellationAndWrapIterator(response, cancellation);
+    }
 
-        cancellationToken?.onCancellationRequested(() => {
-            response.abort();
+    protected async handleToolsRequest(ollama: Ollama, chatRequest: ExtendedChatRequest, prevResponse?: ChatResponse): Promise<LanguageModelResponse> {
+        const response = prevResponse || await ollama.chat({
+            ...chatRequest,
+            stream: false
         });
 
-        async function* wrapAsyncIterator<T>(inputIterable: AsyncIterable<ChatResponse>): AsyncIterable<LanguageModelStreamResponsePart> {
-            for await (const item of inputIterable) {
-                // TODO handle tool calls
-                yield { content: item.message.content };
+        if (response.message.tool_calls) {
+            const tools: ToolWithHandler[] = chatRequest.tools ?? [];
+            // Add response message to chat history
+            chatRequest.messages.push(response.message);
+            const tool_calls: ToolCall[] = [];
+            for (const [idx, toolCall] of response.message.tool_calls.entries()) {
+                const functionToCall = tools.find(tool => tool.function.name === toolCall.function.name);
+                if (functionToCall) {
+                    const args = JSON.stringify(toolCall.function?.arguments);
+                    const funcResult = await functionToCall.handler(args);
+                    chatRequest.messages.push({
+                        role: 'tool',
+                        content: `Tool call ${functionToCall.function.name} returned: ${String(funcResult)}`,
+                    });
+                    let resultString = String(funcResult);
+                    if (resultString.length > 1000) {
+                        // truncate result string if it is too long
+                        resultString = resultString.substring(0, 1000) + '...';
+                    }
+                    tool_calls.push({
+                        id: `ollama_${response.created_at}_${idx}`,
+                        function: {
+                            name: functionToCall.function.name,
+                            arguments: Object.values(toolCall.function?.arguments ?? {}).join(', ')
+                        },
+                        result: resultString,
+                        finished: true
+                    });
+                }
             }
+            // Get final response from model with function outputs
+            const finalResponse = await ollama.chat({ ...chatRequest, stream: false });
+            if (finalResponse.message.tool_calls) {
+                // If the final response also calls tools, recursively handle them
+                return this.handleToolsRequest(ollama, chatRequest, finalResponse);
+            }
+            return { stream: this.createAsyncIterable([{ tool_calls }, { content: finalResponse.message.content }]) };
         }
-        return { stream: wrapAsyncIterator(response) };
+        return { text: response.message.content };
     }
 
-    protected async handleStructuredOutputRequest(ollama: Ollama, request: LanguageModelRequest): Promise<LanguageModelParsedResponse> {
-        const settings = this.getSettings(request);
-        const result = await ollama.chat({
-            ...settings,
-            ...this.DEFAULT_REQUEST_SETTINGS,
-            model: this.model,
-            messages: request.messages.map(this.toOllamaMessage),
+    protected createAsyncIterable<T>(items: T[]): AsyncIterable<T> {
+        return {
+            [Symbol.asyncIterator]: async function* (): AsyncIterableIterator<T> {
+                for (const item of items) {
+                    yield item;
+                }
+            }
+        };
+    }
+
+    protected async handleStructuredOutputRequest(ollama: Ollama, chatRequest: ChatRequest): Promise<LanguageModelParsedResponse> {
+        const response = await ollama.chat({
+            ...chatRequest,
             format: 'json',
             stream: false,
         });
         try {
             return {
-                content: result.message.content,
-                parsed: JSON.parse(result.message.content)
+                content: response.message.content,
+                parsed: JSON.parse(response.message.content)
             };
         } catch (error) {
             // TODO use ILogger
             console.log('Failed to parse structured response from the language model.', error);
             return {
-                content: result.message.content,
+                content: response.message.content,
                 parsed: {}
             };
         }
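
(An aside, not part of this diff: with handleToolsRequest above, one tool round trip evolves chatRequest.messages roughly as sketched below. Message contents and the get_weather name are illustrative; the 'tool' message format matches the code above.)

    // 1. initial request
    [{ role: 'user', content: 'What is the weather in Berlin?' }]
    // 2. the model answers with tool_calls; that assistant message is pushed onto the history
    // 3. each matched handler runs; its result is pushed as a 'tool' message
    [..., { role: 'tool', content: 'Tool call get_weather returned: Sunny in Berlin' }]
    // 4. ollama.chat runs once more; if the reply calls tools again, the method recurses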
@@ -119,11 +186,21 @@ export class OllamaModel implements LanguageModel {
         return new Ollama({ host: host });
     }
 
-    protected toOllamaTool(tool: ToolRequest): Tool {
-        const transform = (props: Record<string, {
-            [key: string]: unknown;
-            type: string;
-        }> | undefined) => {
+    protected handleCancellationAndWrapIterator(response: AbortableAsyncIterable<ChatResponse>, token?: CancellationToken): LanguageModelStreamResponse {
+        token?.onCancellationRequested(() => {
+            // maybe it is better to use ollama.abort() as we are using one client per request
+            response.abort();
+        });
+        async function* wrapAsyncIterator<T>(inputIterable: AsyncIterable<ChatResponse>): AsyncIterable<LanguageModelStreamResponsePart> {
+            for await (const item of inputIterable) {
+                yield { content: item.message.content };
+            }
+        }
+        return { stream: wrapAsyncIterator(response) };
+    }
+
+    protected toOllamaTool(tool: ToolRequest): ToolWithHandler {
+        const transform = (props: Record<string, { [key: string]: unknown; type: string; }> | undefined) => {
             if (!props) {
                 return undefined;
             }
@@ -148,7 +225,8 @@ export class OllamaModel implements LanguageModel {
                     required: Object.keys(tool.parameters?.properties ?? {}),
                     properties: transform(tool.parameters?.properties) ?? {}
                 },
-            }
+            },
+            handler: tool.handler
         };
     }
 
@@ -165,3 +243,23 @@ export class OllamaModel implements LanguageModel {
         return { role: 'system', content: '' };
     }
 }
+
+/**
+ * Extended Tool containing a handler
+ * @see Tool
+ */
+type ToolWithHandler = Tool & { handler: (arg_string: string) => Promise<unknown> };
+
+/**
+ * Extended chat request with mandatory messages and ToolWithHandler tools
+ *
+ * @see ChatRequest
+ * @see ToolWithHandler
+ */
+type ExtendedChatRequest = ChatRequest & {
+    messages: Message[]
+    tools?: ToolWithHandler[]
+};
+
+// Ollama doesn't export this type, so we have to define it here
+type AbortableAsyncIterable<T> = AsyncIterable<T> & { abort: () => void };
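
(For illustration, not part of this diff: a minimal ToolWithHandler in the shape the types above describe. The get_weather tool, its parameter schema, and the handler body are hypothetical; note that handleToolsRequest invokes the handler with JSON.stringify(toolCall.function.arguments) as its single argument.)

    const weatherTool: ToolWithHandler = {
        type: 'function',
        function: {
            name: 'get_weather',
            description: 'Returns the current weather for a city',
            parameters: {
                type: 'object',
                required: ['city'],
                properties: {
                    city: { type: 'string', description: 'Name of the city' }
                }
            }
        },
        // receives the tool-call arguments as a JSON string
        handler: async (argString: string) => {
            const { city } = JSON.parse(argString);
            return `Sunny in ${city}`;
        }
    };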