deepseek model adapter please enhance/extend #586

Closed
georgiedekker opened this issue Jan 2, 2025 · 6 comments

georgiedekker commented Jan 2, 2025

This is based on a copy of the OpenAI model in the latest pydantic-ai version; the imports point at a separate copy of pydantic-ai in my codebase's src folder. I mostly just had ChatGPT rewrite the _process_response method to account for empty messages from DeepSeek v3.
It works with DeepSeek v3 in a simple example:

Example pydantic-ai script:

import os
from dotenv import load_dotenv
import json
from dataclasses import dataclass, field
from typing import Optional
from uuid import UUID, uuid4
from pathlib import Path

from pydantic import BaseModel, Field
from pydantic_ai import Agent, RunContext, Tool
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.models.ollama import OllamaModel
from src.deepseek import DeepSeekModel
from src.db.sqlite_db import Database

load_dotenv()

class CalculatorResult(BaseModel):
    """Result type for calculator operations."""
    value: float = Field(description='The calculated result')
    operation: str = Field(description='The operation performed')
    description: str = Field(description='The description of the operation')


@dataclass
class CalculatorDeps:
    """Dependencies for the calculator agent."""
    memory: dict[str, float] = field(default_factory=dict)


# Calculator tools store the result in shared memory and return plain floats
async def add(ctx: RunContext[CalculatorDeps], a: float, b: float) -> float:
    """Add two numbers together."""
    result = a + b
    print(f"🔢 ADD TOOL CALLED: {a} + {b} = {result}")
    ctx.deps.memory['last_result'] = result
    print(f"🔢 MEMORY: {ctx.deps.memory}")
    return result


async def multiply(ctx: RunContext[CalculatorDeps], a: float, b: float) -> float:
    """Multiply two numbers together."""
    result = a * b
    print(f"🔢 MULTIPLY TOOL CALLED: {a} × {b} = {result}")
    ctx.deps.memory['last_result'] = result
    return result


async def get_last_result(ctx: RunContext[CalculatorDeps]) -> float:
    """Get the last calculated result from memory."""
    result = ctx.deps.memory.get('last_result', 0.0)
    print(f"🔢 GET_LAST_RESULT TOOL CALLED: {result}")
    return result

model = DeepSeekModel(
    model_name='deepseek-chat',
    base_url='https://api.deepseek.com/v1',
    api_key=os.getenv('DEEPSEEK_API_KEY'))
# model="ollama:llama3.2:3b-instruct-q8_0"
# model = OllamaModel(
#     model_name="llama3.2:3b-instruct-q8_0",
#     base_url='http://localhost:11434/v1',
#     api_key='ollama')

# Create calculator agent with string result type
calculator_agent = Agent(
    model=model,
    deps_type=CalculatorDeps,
    result_type=str,
    tools=[Tool(add), Tool(multiply), Tool(get_last_result)],
    system_prompt=(
        "You are a calculator assistant. When performing calculations, you should:"
        "1. Use the appropriate tool (add, multiply, or get_last_result)"
        "2. Return the tool's JSON response directly without modification"
        "3. Do not add any additional text or formatting"
        "\nExample:"
        "\nUser: What is 5 plus 3?"
        "\nAssistant: {\"value\": 8.0, \"operation\": \"addition\", \"description\": \"5.0 + 3.0 = 8.0\"}"
        "This an example of what I am not looking for: The answer to the question ..."
        "Do not respond with \"The answer to the question ...\" or anything like that."
        "This is an example of what I am looking for: {\"value\": 8.0, \"operation\": \"addition\", \"description\": \"5.0 + 3.0 = 8.0\"}"
        "Respond with a single floating point number for the \"value\" field of the JSON response."
        "Only respond with a float like this: 4.1"
        "Do not respond with any other text or formatting besides the JSON response."
        "Remove any text that is not a float for the \"value\" field of the JSON response."
        "This an example of what I am not looking for: The answer to the question of ..."
        "This is an example of what I am looking for: {\"value\": 8.0, \"operation\": \"addition\", \"description\": \"5.0 + 3.0 = 8.0\"}"
        "You are a calculator assistant. When performing calculations, you should:"
        "1. Use the appropriate tool (add, multiply, or get_last_result)"
        "2. Return the tool's JSON response directly without modification"
        "3. Do not add any additional text or formatting"
        "\nRESPOND LIKE THIS: {\"value\": 8.0, \"operation\": \"addition\", \"description\": \"5.0 + 3.0 = 8.0\"}"
        "\nRESPOND LIKE THIS: {\"value\": 8.0, \"operation\": \"multiply\", \"description\": \"5.0 x 3.0 = 8.0\"}"
        "\nRESPOND LIKE THIS: {\"value\": 8.0, \"operation\": \"get_last_result\", \"description\": \"The last result was 8.0\"}"
    ),
    retries=3
)


class ToolExampleAgent:
    """Example agent implementation with tool support."""

    def __init__(self, database: Database):
        """Initialize the agent with database configuration."""
        self.database = database
        self.agent_id = uuid4()
        self.deps = CalculatorDeps()
        self.calculator = calculator_agent

    async def process_message(self, message: str) -> str:
        """Process message with LLM and store in database."""
        if not message:
            return "Error: Message cannot be empty"
        
        print(f"\n📝 INPUT MESSAGE: {message}")
        
        result = await self.calculator.run(message, deps=self.deps)

        # print(f"🔢 RESULT: {result}")
        
        # Store messages in database - serialize only necessary fields
        messages_to_store = []
        for msg in result.new_messages():
            msg_dict = {
                "kind": msg.kind,
                "parts": [{
                    "part_kind": part.part_kind,
                    "content": part.content if hasattr(part, 'content') else None,
                    "tool_name": part.tool_name if hasattr(part, 'tool_name') else None,
                    "args": part.args.__dict__ if hasattr(part, 'args') and part.args else None
                } for part in msg.parts]
            }
            messages_to_store.append(msg_dict)
        
        # Convert to JSON with custom handling for special types
        json_str = json.dumps(
            messages_to_store,
            default=lambda x: str(x) if not isinstance(x, (dict, list, str, int, float, bool, type(None))) else x
        )
        
        await self.database.add_messages(json_str.encode('utf-8'))
        
        return str(result.data)

    async def get_history(self) -> list[dict]:
        """Retrieve conversation history."""
        print("\n" + "="*50)
        print("📚 FETCHING HISTORY")
        print("="*50)
        
        try:
            messages = await self.database.get_messages()
            print(f"\n📥 Retrieved {len(messages)} messages")
            return messages
        except Exception as e:
            print("\n❌ History Error:")
            print(f"  Type: {type(e).__name__}")
            print(f"  Message: {str(e)}")
            return [{"error": f"Failed to retrieve history: {str(e)}"}]

async def main():
    """Example usage of the ToolExampleAgent."""
    async with Database.connect(Path('.chat_app_messages.sqlite')) as database:
        agent = ToolExampleAgent(database=database)

        # Test basic calculation
        calc_result = await agent.process_message("What is 521312123123.2 plus 321321321.2?")
        print(f"Calc Result: {calc_result}")

        # Test memory
        memory_result = await agent.process_message("What was the last result?")
        print(f"Memory: {memory_result}")

        # Test complex operation
        complex_result = await agent.process_message("Multiply the last result by 2")
        print(f"Complex: {complex_result}")

        test_result = await agent.process_message("What is 123.2 plus 321.2 times 423?")
        print(f"Test: {test_result}")
        # Get history
        history = await agent.get_history()
        # print(f"History: {json.dumps(history, indent=2)}")

if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

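For a quick check without the database plumbing, here is a minimal smoke-test sketch. It assumes DEEPSEEK_API_KEY is set in the environment and that the DeepSeekModel class below is saved as src/deepseek.py; the prompt is just an example:

# Minimal smoke test (sketch) — assumes DEEPSEEK_API_KEY is set and the
# DeepSeekModel class shown below lives in src/deepseek.py.
import os

from pydantic_ai import Agent
from src.deepseek import DeepSeekModel

model = DeepSeekModel(
    model_name='deepseek-chat',
    base_url='https://api.deepseek.com/v1',
    api_key=os.getenv('DEEPSEEK_API_KEY'),
)

agent = Agent(model=model, result_type=str)

if __name__ == '__main__':
    # run_sync wraps the async run() for simple scripts
    result = agent.run_sync('What is 2 plus 2? Reply with just the number.')
    print(result.data)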
DeepSeek model, based on the OpenAI model:


from collections.abc import AsyncIterator, Iterable
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from itertools import chain
from typing import Literal, Union, overload

from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never

from pydantic_ai import UnexpectedModelBehavior, _utils, result
from pydantic_ai._utils import guard_tool_call_id as _guard_tool_call_id
from pydantic_ai.messages import (
    ModelMessage,
    ModelRequest,
    ModelResponse,
    ModelResponsePart,
    RetryPromptPart,
    SystemPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)
from pydantic_ai.result import Usage
from pydantic_ai.settings import ModelSettings
from pydantic_ai.tools import ToolDefinition
from pydantic_ai.models import (
    AgentModel,
    EitherStreamedResponse,
    Model,
    StreamStructuredResponse,
    StreamTextResponse,
    cached_async_http_client,
    check_allow_model_requests,
)

try:
    from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream
    from openai.types import ChatModel, chat
    from openai.types.chat import ChatCompletionChunk
    from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
except ImportError as _import_error:
    raise ImportError(
        'Please install `openai` to use the DeepSeek model, '
        "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
    ) from _import_error

DeepSeekModelName = Union[ChatModel, str]
"""
Using this broader type for the model name, instead of the ChatModel definition,
allows this model to be used more easily with other model types (e.g. Ollama).
"""


@dataclass(init=False)
class DeepSeekModel(Model):
    """A model that uses the DeepSeek API.

    Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact with the API, since DeepSeek's API is OpenAI-compatible.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    model_name: DeepSeekModelName
    client: AsyncOpenAI = field(repr=False)

    def __init__(
        self,
        model_name: DeepSeekModelName,
        *,
        base_url: str | None = None,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
    ):
        """Initialize an DeepSeek model.

        Args:
            model_name: The name of the DeepSeek model to use. List of model names available
                [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
                (Unfortunately, despite being asked to do so, DeepSeek does not provide `.inv` files for their API).
            base_url: The base url for the DeepSeek requests. If not provided, the `OPENAI_BASE_URL` environment variable
                will be used if available. Otherwise, defaults to DeepSeek's base url.
            api_key: The API key to use for authentication, if not provided, the `OPENAI_API_KEY` environment variable
                will be used if available.
            openai_client: An existing
                [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
                client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        self.model_name: DeepSeekModelName = model_name
        if openai_client is not None:
            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
            assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
            self.client = openai_client
        elif http_client is not None:
            self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
        else:
            self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        check_allow_model_requests()
        tools = [self._map_tool_definition(r) for r in function_tools]
        if result_tools:
            tools += [self._map_tool_definition(r) for r in result_tools]
        return DeepSeekAgentModel(
            self.client,
            self.model_name,
            allow_text_result,
            tools,
        )

    def name(self) -> str:
        return f'deepseek:{self.model_name}'

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
        return {
            'type': 'function',
            'function': {
                'name': f.name,
                'description': f.description,
                'parameters': f.parameters_json_schema,
            },
        }


@dataclass
class DeepSeekAgentModel(AgentModel):
    """Implementation of `AgentModel` for DeepSeek models."""

    client: AsyncOpenAI
    model_name: DeepSeekModelName
    allow_text_result: bool
    tools: list[chat.ChatCompletionToolParam]

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, result.Usage]:
        response = await self._completions_create(messages, False, model_settings)
        return self._process_response(response), _map_usage(response)

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[EitherStreamedResponse]:
        response = await self._completions_create(messages, True, model_settings)
        async with response:
            yield await self._process_streamed_response(response)

    @overload
    async def _completions_create(
        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
    ) -> AsyncStream[ChatCompletionChunk]:
        pass

    @overload
    async def _completions_create(
        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
    ) -> chat.ChatCompletion:
        pass

    async def _completions_create(
        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
    ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
        # standalone function to make it easier to override
        if not self.tools:
            tool_choice: Literal['none', 'required', 'auto'] | None = None
        elif not self.allow_text_result:
            tool_choice = 'required'
        else:
            tool_choice = 'auto'

        deepseek_messages = list(chain(*(self._map_message(m) for m in messages)))

        model_settings = model_settings or {}

        return await self.client.chat.completions.create(
            model=self.model_name,
            messages=deepseek_messages,
            n=1,
            parallel_tool_calls=True if self.tools else NOT_GIVEN,
            tools=self.tools or NOT_GIVEN,
            tool_choice=tool_choice or NOT_GIVEN,
            stream=stream,
            stream_options={'include_usage': True} if stream else NOT_GIVEN,
            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
            temperature=model_settings.get('temperature', NOT_GIVEN),
            top_p=model_settings.get('top_p', NOT_GIVEN),
            timeout=model_settings.get('timeout', NOT_GIVEN),
        )

    @staticmethod
    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
        """Process a non-streamed response and prepare a message to return."""
        # Ensure the response contains choices
        if not response.choices:
            raise UnexpectedModelBehavior(f'Received empty or invalid model response: {response}')

        # Extract the first choice
        choice = response.choices[0]
        timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
        items: list[ModelResponsePart] = []

        # Process tool calls if they exist
        if choice.message.tool_calls:
            for tool_call in choice.message.tool_calls:
                items.append(ToolCallPart.from_raw_args(
                    tool_call.function.name,
                    tool_call.function.arguments,
                    tool_call.id
                ))

        # If there's no content or tool calls, handle it gracefully
        if not items:
            if choice.finish_reason == "stop":
                # Add a placeholder message or handle gracefully
                # print(f"⚠️ No content or tool calls in response, adding default fallback: {response}")
                items.append(TextPart("Operation completed successfully, but no further output was provided."))
            else:
                raise UnexpectedModelBehavior(
                    f"Unexpected finish_reason with no content or tool calls: {response}"
                )

        return ModelResponse(items, timestamp=timestamp)

    @staticmethod
    async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        timestamp: datetime | None = None
        start_usage = Usage()
        # the first chunk may not contain enough information, so we iterate until we get either `tool_calls` or `content`
        while True:
            try:
                chunk = await response.__anext__()
            except StopAsyncIteration as e:
                raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e

            timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
            start_usage += _map_usage(chunk)

            if chunk.choices:
                delta = chunk.choices[0].delta

                if delta.content is not None:
                    return DeepSeekStreamTextResponse(delta.content, response, timestamp, start_usage)
                elif delta.tool_calls is not None:
                    return DeepSeekStreamStructuredResponse(
                        response,
                        {c.index: c for c in delta.tool_calls},
                        timestamp,
                        start_usage,
                    )
                # else continue until we get either delta.content or delta.tool_calls

    @classmethod
    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
        """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
        if isinstance(message, ModelRequest):
            yield from cls._map_user_message(message)
        elif isinstance(message, ModelResponse):
            texts: list[str] = []
            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
            for item in message.parts:
                if isinstance(item, TextPart):
                    texts.append(item.content)
                elif isinstance(item, ToolCallPart):
                    tool_calls.append(_map_tool_call(item))
                else:
                    assert_never(item)
            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
            if texts:
                # Note: model responses from this model should only have one text item, so the following
                # shouldn't merge multiple texts into one unless you switch models between runs:
                message_param['content'] = '\n\n'.join(texts)
            if tool_calls:
                message_param['tool_calls'] = tool_calls
            yield message_param
        else:
            assert_never(message)

    @classmethod
    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
            elif isinstance(part, ToolReturnPart):
                yield chat.ChatCompletionToolMessageParam(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part, model_source='DeepSeek'),
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):
                if part.tool_name is None:
                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
                else:
                    yield chat.ChatCompletionToolMessageParam(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part, model_source='DeepSeek'),
                        content=part.model_response(),
                    )
            else:
                assert_never(part)


@dataclass
class DeepSeekStreamTextResponse(StreamTextResponse):
    """Implementation of `StreamTextResponse` for DeepSeek models."""

    _first: str | None
    _response: AsyncStream[ChatCompletionChunk]
    _timestamp: datetime
    _usage: result.Usage
    _buffer: list[str] = field(default_factory=list, init=False)

    async def __anext__(self) -> None:
        if self._first is not None:
            self._buffer.append(self._first)
            self._first = None
            return None

        chunk = await self._response.__anext__()
        self._usage += _map_usage(chunk)
        try:
            choice = chunk.choices[0]
        except IndexError:
            raise StopAsyncIteration()

        # we don't raise StopAsyncIteration on the last chunk because usage comes after this
        if choice.finish_reason is None:
            assert choice.delta.content is not None, f'Expected delta with content, invalid chunk: {chunk!r}'
        if choice.delta.content is not None:
            self._buffer.append(choice.delta.content)

    def get(self, *, final: bool = False) -> Iterable[str]:
        yield from self._buffer
        self._buffer.clear()

    def usage(self) -> Usage:
        return self._usage

    def timestamp(self) -> datetime:
        return self._timestamp


@dataclass
class DeepSeekStreamStructuredResponse(StreamStructuredResponse):
    """Implementation of `StreamStructuredResponse` for DeepSeek models."""

    _response: AsyncStream[ChatCompletionChunk]
    _delta_tool_calls: dict[int, ChoiceDeltaToolCall]
    _timestamp: datetime
    _usage: result.Usage

    async def __anext__(self) -> None:
        chunk = await self._response.__anext__()
        self._usage += _map_usage(chunk)
        try:
            choice = chunk.choices[0]
        except IndexError:
            raise StopAsyncIteration()

        if choice.finish_reason is not None:
            raise StopAsyncIteration()

        assert choice.delta.content is None, f'Expected tool calls, got content instead, invalid chunk: {chunk!r}'

        for new in choice.delta.tool_calls or []:
            if current := self._delta_tool_calls.get(new.index):
                if current.function is None:
                    current.function = new.function
                elif new.function is not None:
                    current.function.name = _utils.add_optional(current.function.name, new.function.name)
                    current.function.arguments = _utils.add_optional(current.function.arguments, new.function.arguments)
            else:
                self._delta_tool_calls[new.index] = new

    def get(self, *, final: bool = False) -> ModelResponse:
        items: list[ModelResponsePart] = []
        for c in self._delta_tool_calls.values():
            if f := c.function:
                if f.name is not None and f.arguments is not None:
                    items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))

        return ModelResponse(items, timestamp=self._timestamp)

    def usage(self) -> Usage:
        return self._usage

    def timestamp(self) -> datetime:
        return self._timestamp


def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
    return chat.ChatCompletionMessageToolCallParam(
        id=_guard_tool_call_id(t=t, model_source='DeepSeek'),
        type='function',
        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
    )


def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> result.Usage:
    usage = response.usage
    if usage is None:
        return result.Usage()
    else:
        details: dict[str, int] = {}
        if usage.completion_tokens_details is not None:
            details.update(usage.completion_tokens_details.model_dump(exclude_none=True))
        if usage.prompt_tokens_details is not None:
            details.update(usage.prompt_tokens_details.model_dump(exclude_none=True))
        return result.Usage(
            request_tokens=usage.prompt_tokens,
            response_tokens=usage.completion_tokens,
            total_tokens=usage.total_tokens,
            details=details,
        )

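A quick way to exercise the fallback branch of _process_response without calling the API is to validate a hand-built ChatCompletion shaped like DeepSeek's empty finish_reason="stop" response (a sketch; the id and timestamp are made up, and DeepSeekAgentModel is assumed to be imported from the module above):

# Sketch: check that an empty "stop" response yields the placeholder TextPart
# instead of raising UnexpectedModelBehavior.
from openai.types import chat

empty_stop = chat.ChatCompletion.model_validate({
    'id': 'chatcmpl-test',          # made-up id
    'object': 'chat.completion',
    'created': 1735862400,          # made-up timestamp
    'model': 'deepseek-chat',
    'choices': [{
        'index': 0,
        'finish_reason': 'stop',
        'logprobs': None,
        'message': {'role': 'assistant', 'content': None},  # no content, no tool calls
    }],
})

response = DeepSeekAgentModel._process_response(empty_stop)
print(response.parts[0])  # TextPart with the fallback message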
Result of running the example script:

🔢 ADD TOOL CALLED: 521312123123.2 + 321321321.2 = 521633444444.4
🔢 MEMORY: {'last_result': 521633444444.4}
Calc Result: Operation completed successfully, but no further output was provided.

📝 INPUT MESSAGE: What was the last result?
🔢 GET_LAST_RESULT TOOL CALLED: 521633444444.4
Memory: Operation completed successfully, but no further output was provided.

📝 INPUT MESSAGE: Multiply the last result by 2
🔢 GET_LAST_RESULT TOOL CALLED: 521633444444.4
Complex: Operation completed successfully, but no further output was provided.

📝 INPUT MESSAGE: What is 123.2 plus 321.2 times 423?
🔢 ADD TOOL CALLED: 123.2 + 321.2 = 444.4
🔢 MEMORY: {'last_result': 444.4}
🔢 MULTIPLY TOOL CALLED: 321.2 × 423.0 = 135867.6
🔢 ADD TOOL CALLED: 444.4 + 135867.6 = 136312.0
🔢 MEMORY: {'last_result': 136312.0}
Test: Operation completed successfully, but no further output was provided.

==================================================
📚 FETCHING HISTORY
==================================================

📥 Retrieved 108 messages
@sydney-runkle (Member) commented

Hi @georgiedekker,

Could you please provide a bit more context here? What is your goal / request here?

@georgiedekker (Author) commented

Hi @sydney-runkle,
I was testing pydantic-ai with DeepSeek v3 and it didn't quite work with the OpenAI model supplied by pydantic-ai, so I copied it and made some changes to the _process_response method, which made it work for the simple example included above. The request is to provide a model for DeepSeek, as is done for a number of other models, and in the meantime to share a working example as a temporary workaround.

@dmontagu (Contributor) commented Jan 3, 2025

I think we are open to changes/refactors to the OpenAIModel if they don't modify the behavior with OpenAI but make it more compatible with other models or make it easier to override in a useful way. Especially if the change you need is small or can be done in a way that is API-compatible with the actual OpenAI APIs, feel free to open a PR.

If it's too big of a change, you can also create a python package containing the DeepSeekModel and distribute it that way; if you do, we can reference it in our docs.

@georgiedekker (Author) commented

@dmontagu Thank you, David. Honestly, I have no idea how to create a Python package. As you can see in the code above, this is the only change I had to make for my testing purposes. I just wanted to see how well DeepSeek v3 compares to other models like llama3.2. I haven't played around with the OpenAI model itself, or its API, so this change might be completely in line with that API, or it might be a misinterpretation by the DeepSeek team.
Possibly I've implemented the tools wrong. I just wanted to give something back, as I've really been enjoying working with pydantic-ai after trying many other frameworks.

tl;dr: I noticed an error: pydantic_ai.exceptions.UnexpectedModelBehavior: Received empty model response

I asked ChatGPT:
The error indicates that the system is still treating the response as invalid because pydantic_ai's _handle_model_response function does not know how to process empty content combined with finish_reason="stop". The issue arises because the tool call (the calculation) completes, but no meaningful response is provided back to the assistant.

Analysis
1. DeepSeek Behavior:
• DeepSeek is processing the input, performing the tool operation (add), and returning a valid tool call, but not providing content or further instructions in the response.
• The finish_reason="stop" signifies that the model has concluded processing.
2. pydantic_ai Agent Expectation:
• The agent.run() method expects a valid ModelResponse with content or tool calls to proceed further. An empty response triggers the UnexpectedModelBehavior exception.

Fix
1. Tool Call Handling:
• Tool calls are explicitly processed into ToolCallPart.
2. Fallback for finish_reason="stop":
• A placeholder message is added if no content or tool calls exist, ensuring the response is still valid.
3. Strict Handling of Other Cases:
• An exception is raised for unexpected finish_reason values to catch unexpected behavior.

Implemented this and it worked for me.

Current OpenAI model in the latest pydantic-ai:

    @staticmethod
    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
        choice = response.choices[0]
        items: list[ModelResponsePart] = []
        if choice.message.content is not None:
            items.append(TextPart(choice.message.content))
        if choice.message.tool_calls is not None:
            for c in choice.message.tool_calls:
                items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
        return ModelResponse(items, timestamp=timestamp)

Adjusted version, to account for the empty message when the model finishes with finish_reason="stop":

    @staticmethod
    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
        """Process a non-streamed response and prepare a message to return."""
        # Ensure the response contains choices
        if not response.choices:
            raise UnexpectedModelBehavior(f'Received empty or invalid model response: {response}')

        # Extract the first choice
        choice = response.choices[0]
        timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
        items: list[ModelResponsePart] = []

        # Process tool calls if they exist
        if choice.message.tool_calls:
            for tool_call in choice.message.tool_calls:
                items.append(ToolCallPart.from_raw_args(
                    tool_call.function.name,
                    tool_call.function.arguments,
                    tool_call.id
                ))

        # If there's no content or tool calls, handle it gracefully
        if not items:
            if choice.finish_reason == "stop":
                items.append(TextPart("Operation completed successfully, but no further output was provided."))
            else:
                raise UnexpectedModelBehavior(
                    f"Unexpected finish_reason with no content or tool calls: {response}"
                )

        return ModelResponse(items, timestamp=timestamp)

@izzyacademy commented

@georgiedekker Based on the exchanges, it appears you are looking for the project to add support for a new model.

There are project guidelines on the threshold for new models to be added [1].

To add a new model with an extra dependency, that dependency needs >500k monthly downloads from PyPI consistently over 3 months or more; to add a new model which uses another model's logic internally and has no extra dependencies, that model's GitHub org needs >20k stars in total.

The new model [2] does not currently meet the threshold to be included in the main package.

It appears the remaining option is for you to release your own Python package, pydantic-ai-xxx, which depends on pydantic-ai-slim and implements a model that inherits from the base model.

You can take a look at [4] and [5] for how to create a Python package for pydantic-ai-deepseek; a minimal sketch follows.
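For illustration, the entry point of such a package could be as small as re-exporting the model. The package and module names here are hypothetical:

# Hypothetical layout for a pydantic-ai-deepseek package:
#
#   pydantic_ai_deepseek/
#       __init__.py   <- re-exports the model
#       model.py      <- the DeepSeekModel / DeepSeekAgentModel code from this issue
#
# pydantic_ai_deepseek/__init__.py
from .model import DeepSeekModel

__all__ = ['DeepSeekModel']

# Users would then depend on pydantic-ai-slim[openai] plus this package and write:
#   from pydantic_ai_deepseek import DeepSeekModel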

I hope this helps

References

@samuelcolvin (Member) commented

Closed by #613 I think.
