diff --git a/python/packages/autogen-ext/src/autogen_ext/agents/openai/_openai_assistant_agent.py b/python/packages/autogen-ext/src/autogen_ext/agents/openai/_openai_assistant_agent.py
index 0419d2954889..6da6755efc00 100644
--- a/python/packages/autogen-ext/src/autogen_ext/agents/openai/_openai_assistant_agent.py
+++ b/python/packages/autogen-ext/src/autogen_ext/agents/openai/_openai_assistant_agent.py
@@ -11,6 +11,7 @@
     Iterable,
     List,
     Literal,
+    Mapping,
     Optional,
     Sequence,
     Set,
@@ -36,6 +37,7 @@
 from autogen_core.models._model_client import ChatCompletionClient
 from autogen_core.models._types import FunctionExecutionResult
 from autogen_core.tools import FunctionTool, Tool
+from pydantic import BaseModel, Field
 
 from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI, NotGiven
 from openai.pagination import AsyncCursorPage
@@ -77,6 +79,15 @@ def _convert_tool_to_function_param(tool: Tool) -> "FunctionToolParam":
     return FunctionToolParam(type="function", function=function_def)
 
 
+class OpenAIAssistantAgentState(BaseModel):
+    type: str = Field(default="OpenAIAssistantAgentState")
+    assistant_id: Optional[str] = None
+    thread_id: Optional[str] = None
+    initial_message_ids: List[str] = Field(default_factory=list)
+    vector_store_id: Optional[str] = None
+    uploaded_file_ids: List[str] = Field(default_factory=list)
+
+
 class OpenAIAssistantAgent(BaseChatAgent):
     """An agent implementation that uses the Assistant API to generate responses.
 
@@ -666,3 +677,21 @@ async def delete_vector_store(self, cancellation_token: CancellationToken) -> No
             self._vector_store_id = None
         except Exception as e:
             event_logger.error(f"Failed to delete vector store: {str(e)}")
+
+    async def save_state(self) -> Mapping[str, Any]:
+        state = OpenAIAssistantAgentState(
+            assistant_id=self._assistant.id if self._assistant else self._assistant_id,
+            thread_id=self._thread.id if self._thread else self._init_thread_id,
+            initial_message_ids=list(self._initial_message_ids),
+            vector_store_id=self._vector_store_id,
+            uploaded_file_ids=self._uploaded_file_ids,
+        )
+        return state.model_dump()
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        agent_state = OpenAIAssistantAgentState.model_validate(state)
+        self._assistant_id = agent_state.assistant_id
+        self._init_thread_id = agent_state.thread_id
+        self._initial_message_ids = set(agent_state.initial_message_ids)
+        self._vector_store_id = agent_state.vector_store_id
+        self._uploaded_file_ids = agent_state.uploaded_file_ids
diff --git a/python/packages/autogen-ext/tests/test_openai_assistant_agent.py b/python/packages/autogen-ext/tests/test_openai_assistant_agent.py
index da55d860c674..e6512c9fb881 100644
--- a/python/packages/autogen-ext/tests/test_openai_assistant_agent.py
+++ b/python/packages/autogen-ext/tests/test_openai_assistant_agent.py
@@ -1,9 +1,14 @@
+import io
 import os
+from contextlib import asynccontextmanager
 from enum import Enum
-from typing import List, Literal, Optional, Union
+from pathlib import Path
+from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Union
+from unittest.mock import AsyncMock, MagicMock
 
+import aiofiles
 import pytest
-from autogen_agentchat.messages import TextMessage
+from autogen_agentchat.messages import ChatMessage, TextMessage
 from autogen_core import CancellationToken
 from autogen_core.tools._base import BaseTool, Tool
 from autogen_ext.agents.openai import OpenAIAssistantAgent
@@ -57,14 +62,104 @@ async def run(self, args: DisplayQuizArgs, cancellation_token: CancellationToken
         return QuizResponses(responses=responses)
 
 
+class FakeText:
+    def __init__(self, value: str):
+        self.value = value
+
+
+class FakeTextContent:
+    def __init__(self, text: str):
+        self.type = "text"
+        self.text = FakeText(text)
+
+
+class FakeMessage:
+    def __init__(self, id: str, text: str):
+        self.id = id
+        # The agent expects content to be a list of objects with a "type" attribute.
+        self.content = [FakeTextContent(text)]
+
+
+class FakeCursorPage:
+    def __init__(self, data: List[ChatMessage | FakeMessage]) -> None:
+        self.data = data
+
+    def has_next_page(self) -> bool:
+        return False
+
+
+def create_mock_openai_client() -> AsyncAzureOpenAI:
+    # Create the base client as an AsyncMock.
+    client = AsyncMock(spec=AsyncAzureOpenAI)
+
+    # Create a "beta" attribute with the required nested structure.
+    beta = MagicMock()
+    client.beta = beta
+
+    # Setup beta.assistants with dummy create/retrieve/update/delete.
+    beta.assistants = MagicMock()
+    beta.assistants.create = AsyncMock(return_value=MagicMock(id="assistant-mock"))
+    beta.assistants.retrieve = AsyncMock(return_value=MagicMock(id="assistant-mock"))
+    beta.assistants.update = AsyncMock(return_value=MagicMock(id="assistant-mock"))
+    beta.assistants.delete = AsyncMock(return_value=None)
+
+    # Setup beta.threads with create and retrieve.
+    beta.threads = MagicMock()
+    beta.threads.create = AsyncMock(return_value=MagicMock(id="thread-mock", tool_resources=None))
+    beta.threads.retrieve = AsyncMock(return_value=MagicMock(id="thread-mock", tool_resources=None))
+
+    # Setup beta.threads.messages with create, list, and delete.
+    beta.threads.messages = MagicMock()
+    beta.threads.messages.create = AsyncMock(return_value=MagicMock(id="msg-mock", content="mock content"))
+
+    # Default fake messages – these may be overridden in individual tests.
+    name_message = FakeMessage("msg-mock", "Your name is John, you are a software engineer.")
+
+    def mock_list(thread_id: str, **kwargs: Dict[str, Any]) -> FakeCursorPage:
+        # Default behavior returns the "name" message.
+        if thread_id == "thread-mock":
+            return FakeCursorPage([name_message])
+        return FakeCursorPage([FakeMessage("msg-mock", "Default response")])
+
+    beta.threads.messages.list = AsyncMock(side_effect=mock_list)
+    beta.threads.messages.delete = AsyncMock(return_value=MagicMock(deleted=True))
+
+    # Setup beta.threads.runs with create, retrieve, and submit_tool_outputs.
+    beta.threads.runs = MagicMock()
+    beta.threads.runs.create = AsyncMock(return_value=MagicMock(id="run-mock", status="completed"))
+    beta.threads.runs.retrieve = AsyncMock(return_value=MagicMock(id="run-mock", status="completed"))
+    beta.threads.runs.submit_tool_outputs = AsyncMock(return_value=MagicMock(id="run-mock", status="completed"))
+
+    # Setup beta.vector_stores with create, delete, and file_batches.
+    beta.vector_stores = MagicMock()
+    beta.vector_stores.create = AsyncMock(return_value=MagicMock(id="vector-mock"))
+    beta.vector_stores.delete = AsyncMock(return_value=None)
+    beta.vector_stores.file_batches = MagicMock()
+    beta.vector_stores.file_batches.create_and_poll = AsyncMock(return_value=None)
+
+    # Setup client.files with create and delete.
+    client.files = MagicMock()
+    client.files.create = AsyncMock(return_value=MagicMock(id="file-mock"))
+    client.files.delete = AsyncMock(return_value=None)
+
+    return client
+
+
+# Fixture for the mock client.
+@pytest.fixture
+def mock_openai_client() -> AsyncAzureOpenAI:
+    return create_mock_openai_client()
+
+
 @pytest.fixture
 def client() -> AsyncAzureOpenAI:
     azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
     api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview")
     api_key = os.getenv("AZURE_OPENAI_API_KEY")
 
-    if not azure_endpoint:
-        pytest.skip("Azure OpenAI endpoint not found in environment variables")
+    # Return mock client if credentials not available
+    if not azure_endpoint or not api_key:
+        return create_mock_openai_client()
 
     # Try Azure CLI credentials if API key not provided
     if not api_key:
@@ -76,7 +171,7 @@ def client() -> AsyncAzureOpenAI:
                 azure_endpoint=azure_endpoint, api_version=api_version, azure_ad_token_provider=token_provider
             )
         except Exception:
-            pytest.skip("Failed to get Azure CLI credentials and no API key provided")
+            return create_mock_openai_client()
 
     # Fall back to API key auth if provided
     return AsyncAzureOpenAI(azure_endpoint=azure_endpoint, api_version=api_version, api_key=api_key)
@@ -105,10 +200,38 @@ def cancellation_token() -> CancellationToken:
     return CancellationToken()
 
 
+# A fake aiofiles.open to bypass filesystem access.
+@asynccontextmanager
+async def fake_aiofiles_open(*args: Any, **kwargs: Dict[str, Any]) -> AsyncGenerator[io.BytesIO, None]:
+    yield io.BytesIO(b"dummy file content")
+
+
 @pytest.mark.asyncio
-async def test_file_retrieval(agent: OpenAIAssistantAgent, cancellation_token: CancellationToken) -> None:
-    file_path = r"C:\Users\lpinheiro\Github\autogen-test\data\SampleBooks\jungle_book.txt"
-    await agent.on_upload_for_file_search(file_path, cancellation_token)
+async def test_file_retrieval(
+    agent: OpenAIAssistantAgent, cancellation_token: CancellationToken, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
+) -> None:
+    # Arrange: Define a fake async file opener that returns a file-like object with an async read() method.
+    class FakeAiofilesFile:
+        async def read(self) -> bytes:
+            return b"dummy file content"
+
+    @asynccontextmanager
+    async def fake_async_aiofiles_open(*args: Any, **kwargs: Dict[str, Any]) -> AsyncGenerator[FakeAiofilesFile, None]:
+        yield FakeAiofilesFile()
+
+    monkeypatch.setattr(aiofiles, "open", fake_async_aiofiles_open)
+
+    # We also override the messages.list to return a fake file search result.
+    fake_file_message = FakeMessage(
+        "msg-mock", "The first sentence of the jungle book is 'Mowgli was raised by wolves.'"
+    )
+    agent._client.beta.threads.messages.list = AsyncMock(return_value=FakeCursorPage([fake_file_message]))  # type: ignore
+
+    # Create a temporary file.
+    file_path = tmp_path / "jungle_book.txt"
+    file_path.write_text("dummy content")
+
+    await agent.on_upload_for_file_search(str(file_path), cancellation_token)
 
     message = TextMessage(source="user", content="What is the first sentence of the jungle scout book?")
     response = await agent.on_messages([message], cancellation_token)
@@ -123,7 +246,14 @@ async def test_file_retrieval(agent: OpenAIAssistantAgent, cancellation_token: C
 
 
 @pytest.mark.asyncio
-async def test_code_interpreter(agent: OpenAIAssistantAgent, cancellation_token: CancellationToken) -> None:
+async def test_code_interpreter(
+    agent: OpenAIAssistantAgent, cancellation_token: CancellationToken, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    # Arrange: For code interpreter, have the messages.list return a result with "x = 1".
+    agent._client.beta.threads.messages.list = AsyncMock(  # type: ignore
+        return_value=FakeCursorPage([FakeMessage("msg-mock", "x = 1")])
+    )
+
     message = TextMessage(source="user", content="I need to solve the equation `3x + 11 = 14`. Can you help me?")
     response = await agent.on_messages([message], cancellation_token)
 
@@ -136,25 +266,64 @@ async def test_code_interpreter(agent: OpenAIAssistantAgent, cancellation_token:
 
 
 @pytest.mark.asyncio
-async def test_quiz_creation(agent: OpenAIAssistantAgent, cancellation_token: CancellationToken) -> None:
+async def test_quiz_creation(
+    agent: OpenAIAssistantAgent, cancellation_token: CancellationToken, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    monkeypatch.setattr(DisplayQuizTool, "run_json", DisplayQuizTool.run)
+
+    # Create a fake tool call for display_quiz.
+    fake_tool_call = MagicMock()
+    fake_tool_call.type = "function"
+    fake_tool_call.id = "tool-call-1"
+    fake_tool_call.function = MagicMock()
+    fake_tool_call.function.name = "display_quiz"
+    fake_tool_call.function.arguments = (
+        '{"title": "Quiz Title", "questions": [{"question_text": "What is 2+2?", '
+        '"question_type": "MULTIPLE_CHOICE", "choices": ["3", "4", "5"]}]}'
+    )
+
+    # Create a run that requires action (tool call).
+    run_requires_action = MagicMock()
+    run_requires_action.id = "run-mock"
+    run_requires_action.status = "requires_action"
+    run_requires_action.required_action = MagicMock()
+    run_requires_action.required_action.submit_tool_outputs = MagicMock()
+    run_requires_action.required_action.submit_tool_outputs.tool_calls = [fake_tool_call]
+
+    # Create a completed run for the subsequent retrieval.
+    run_completed = MagicMock()
+    run_completed.id = "run-mock"
+    run_completed.status = "completed"
+    run_completed.required_action = None
+
+    # Set up the beta.threads.runs.retrieve mock to return these in sequence.
+    agent._client.beta.threads.runs.retrieve.side_effect = [run_requires_action, run_completed]  # type: ignore
+
+    # Also, set the messages.list call (after run completion) to return a quiz message.
+    quiz_tool_message = FakeMessage("msg-mock", "Quiz created: Q1) 2+2=? Answer: b) 4; Q2) Free: Sample free response")
+    agent._client.beta.threads.messages.list = AsyncMock(return_value=FakeCursorPage([quiz_tool_message]))  # type: ignore
+
+    # Create a user message to trigger the tool invocation.
     message = TextMessage(
         source="user",
         content="Create a short quiz about basic math with one multiple choice question and one free response question.",
     )
     response = await agent.on_messages([message], cancellation_token)
 
+    # Check that the final response has non-empty inner messages (i.e. tool call events).
     assert response.chat_message.content is not None
     assert isinstance(response.chat_message.content, str)
     assert len(response.chat_message.content) > 0
     assert isinstance(response.inner_messages, list)
-    assert any(tool_msg.content for tool_msg in response.inner_messages if hasattr(tool_msg, "content"))
+    # Ensure that at least one inner message has non-empty content.
+    assert any(hasattr(tool_msg, "content") and tool_msg.content for tool_msg in response.inner_messages)
 
     await agent.delete_assistant(cancellation_token)
 
 
 @pytest.mark.asyncio
 async def test_on_reset_behavior(client: AsyncAzureOpenAI, cancellation_token: CancellationToken) -> None:
-    # Create thread with initial message
+    # Arrange: Use the default behavior for reset.
     thread = await client.beta.threads.create()
     await client.beta.threads.messages.create(
         thread_id=thread.id,
@@ -162,7 +331,6 @@ async def test_on_reset_behavior(client: AsyncAzureOpenAI, cancellation_token: C
         role="user",
     )
 
-    # Create agent with existing thread
     agent = OpenAIAssistantAgent(
         name="assistant",
         instructions="Help the user with their task.",
@@ -172,19 +340,51 @@ async def test_on_reset_behavior(client: AsyncAzureOpenAI, cancellation_token: C
         thread_id=thread.id,
     )
 
-    # Test before reset
     message1 = TextMessage(source="user", content="What is my name?")
    response1 = await agent.on_messages([message1], cancellation_token)
     assert isinstance(response1.chat_message.content, str)
     assert "john" in response1.chat_message.content.lower()
 
-    # Reset agent state
     await agent.on_reset(cancellation_token)
 
-    # Test after reset
     message2 = TextMessage(source="user", content="What is my name?")
     response2 = await agent.on_messages([message2], cancellation_token)
     assert isinstance(response2.chat_message.content, str)
     assert "john" in response2.chat_message.content.lower()
 
     await agent.delete_assistant(cancellation_token)
+
+
+@pytest.mark.asyncio
+async def test_save_and_load_state(mock_openai_client: AsyncAzureOpenAI) -> None:
+    agent = OpenAIAssistantAgent(
+        name="assistant",
+        description="Dummy assistant for state testing",
+        client=mock_openai_client,
+        model="dummy-model",
+        instructions="dummy instructions",
+        tools=[],
+    )
+    agent._assistant_id = "assistant-123"  # type: ignore
+    agent._init_thread_id = "thread-456"  # type: ignore
+    agent._initial_message_ids = {"msg1", "msg2"}  # type: ignore
+    agent._vector_store_id = "vector-789"  # type: ignore
+    agent._uploaded_file_ids = ["file-abc", "file-def"]  # type: ignore
+
+    saved_state = await agent.save_state()
+
+    new_agent = OpenAIAssistantAgent(
+        name="assistant",
+        description="Dummy assistant for state testing",
+        client=mock_openai_client,
+        model="dummy-model",
+        instructions="dummy instructions",
+        tools=[],
+    )
+    await new_agent.load_state(saved_state)
+
+    assert new_agent._assistant_id == "assistant-123"  # type: ignore
+    assert new_agent._init_thread_id == "thread-456"  # type: ignore
+    assert new_agent._initial_message_ids == {"msg1", "msg2"}  # type: ignore
+    assert new_agent._vector_store_id == "vector-789"  # type: ignore
+    assert new_agent._uploaded_file_ids == ["file-abc", "file-def"]  # type: ignore
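
For reference, a minimal usage sketch of the save_state/load_state round-trip introduced by this diff (illustrative only, not part of the change): the client construction reuses the same Azure OpenAI environment variables as the test fixture, the model name and the file name assistant_state.json are hypothetical placeholders.

import asyncio
import json
import os

from autogen_ext.agents.openai import OpenAIAssistantAgent
from openai import AsyncAzureOpenAI


async def main() -> None:
    # Assumes the same environment variables used by the test fixture above.
    client = AsyncAzureOpenAI(
        azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
        api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"),
        api_key=os.environ["AZURE_OPENAI_API_KEY"],
    )
    agent = OpenAIAssistantAgent(
        name="assistant",
        description="Example assistant",
        client=client,
        model="gpt-4o",  # hypothetical deployment name
        instructions="Help the user with their task.",
        tools=[],
    )

    # ... interact with the agent via agent.on_messages(...) ...

    # Persist the assistant/thread/vector-store identifiers captured by save_state().
    state = await agent.save_state()
    with open("assistant_state.json", "w") as f:  # hypothetical path
        json.dump(dict(state), f)

    # Later (e.g. after a restart), rehydrate a fresh agent from the saved identifiers.
    restored = OpenAIAssistantAgent(
        name="assistant",
        description="Example assistant",
        client=client,
        model="gpt-4o",
        instructions="Help the user with their task.",
        tools=[],
    )
    with open("assistant_state.json") as f:
        await restored.load_state(json.load(f))


asyncio.run(main())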