Graph Support #528

Merged
merged 57 commits into from
Jan 15, 2025

Conversation

samuelcolvin
Member

@samuelcolvin samuelcolvin commented Dec 23, 2024

TODO:

  • nodes via decorator impossible without HKT
  • infer graph name
  • tests
  • docs
  • examples

This is a work in progress, it's the result of a lot of discussion with @dmontagu.

The idea is to provide a graph/state machine library to use with PydanticAI that is as type-safe as possible in python.

NOTE: the vast majority of multi-agent examples I've seen do not need a graph or state machine, and would be more complex to write and understand if written using one. You should only use this functionality if:

  1. you understand how to use agents as tools
  2. you've tried using standard programming techniques to link agents
  3. after that, you're still sure you need a graph library and state machine

In particular this means we define edges (which nodes in a graph can be reached from any given node) using type annotations, rather than some separate set_edges mechanism.

To do this we define nodes as types (that must inherit from BaseNode). To route the graph to (say) NodeB, a node's run method returns an instance of NodeB, which holds the input data for NodeB. Similarly, to end a run, nodes should return End.

We inspect the return annotation of the run method on nodes to build the graph.

Here's a minimal example:

Code
from __future__ import annotations as _annotations

from dataclasses import dataclass

from pydantic_graph import Graph, BaseNode, End


@dataclass
class NodeA(BaseNode[None]):
    apple: int

    async def run(self, ctx) -> NodeB:
        return NodeB(self.apple / 2)


@dataclass
class NodeB(BaseNode[None]):
    banana: float

    async def run(self, ctx) -> NodeC:
        return NodeC((int(self.banana), int(self.banana) + 4))


@dataclass
class NodeC(BaseNode[None, float]):
    pair: tuple[int, int]

    async def run(self, ctx) -> NodeA | End[float]:
        v1, v2 = self.pair
        if v1 + v2 > 10:
            return End((v1 + v2) / 3)
        else:
            return NodeA(v1 + v2)


graph = Graph(nodes=(NodeA, NodeB, NodeC))
print(graph.mermaid_code(start_node=NodeA))


result, history = graph.run_sync(None, NodeA(8))
print(result)

result, history = graph.run_sync(None, NodeA(7))
print(result)

The mermaid chart printed in the example looks like this:

---
title: graph
---
stateDiagram-v2
  [*] --> NodeA
  NodeA --> NodeB
  NodeB --> NodeC
  NodeC --> NodeA
  NodeC --> [*]

And the rest of the output is:

4.0
4.666666666666667

The graph library is completely independent of LLM use cases, but can relatively easily be used with pydantic-ai's Agent, see the examples/pydantic_ai_examples/email_extract_graph.py example.


cloudflare-workers-and-pages bot commented Dec 23, 2024

Deploying pydantic-ai with Cloudflare Pages

Latest commit: a2144d6
Status: ✅  Deploy successful!
Preview URL: https://2d84d435.pydantic-ai.pages.dev
Branch Preview URL: https://graph.pydantic-ai.pages.dev

View logs


@brettkromkamp

brettkromkamp commented Jan 7, 2025

In my opinion this feature is critical for adoption of the PydanticAI framework. Any timeframe for when this will land in main, @samuelcolvin?

I very much like the approach of using type annotations and returns instead of a separate set_edge mechanism (described above)... really gives me a LlamaIndex Workflows vibe as opposed to the more complex LangGraph approach.

@samuelcolvin
Member Author

@brettkromkamp we'll do our best to get something merged and released this week.

@samuelcolvin samuelcolvin force-pushed the graph branch 2 times, most recently from c52885b to e7b3949 Compare January 7, 2025 21:41
@samuelcolvin samuelcolvin mentioned this pull request Jan 8, 2025
@samuelcolvin
Member Author

samuelcolvin commented Jan 9, 2025

I've added support for Interrupt to interrupt a run and continue it from the right place.

Update: I've removed Interrupt and instead added a next() method which provides the same functionality, but is more flexible and easier to understand.

Here's an example using it:

from __future__ import annotations as _annotations

from dataclasses import dataclass
from typing import Annotated

import logfire

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage
from pydantic_ai.format_as_xml import format_as_xml
from pydantic_graph import Graph, BaseNode, End, GraphContext, AbstractState, Edge

logfire.configure()
ask_agent = Agent('openai:gpt-4o', result_type=str)


@dataclass
class QuestionState(AbstractState):
    ask_agent_messages: list[ModelMessage] | None = None

    def serialize(self) -> bytes | None:
        raise NotImplementedError('TODO')


@dataclass
class Ask(BaseNode[QuestionState]):
    """Generate a question to ask the user.

    Uses the GPT-4o model to generate a question.
    """
    async def run(self, ctx: GraphContext[QuestionState]) -> Annotated[Answer, Edge(label='ask the question')]:
        result = await ask_agent.run(
            'Ask a simple question with a single correct answer.', message_history=ctx.state.ask_agent_messages
        )
        if ctx.state.ask_agent_messages is None:
            ctx.state.ask_agent_messages = []
        ctx.state.ask_agent_messages += result.all_messages()
        return Answer(result.data)


@dataclass
class Answer(BaseNode[QuestionState]):
    question: str
    answer: str | None = None

    async def run(self, ctx: GraphContext[QuestionState]) -> Annotated[Evaluate, Edge(label='answer the question')]:
        assert self.answer is not None
        return Evaluate(self.question, self.answer)


@dataclass
class EvaluationResult:
    correct: bool
    comment: str


evaluate_agent = Agent(
    'openai:gpt-4o',
    result_type=EvaluationResult,
    system_prompt='Given a question and answer, evaluate if the answer is correct.',
    result_tool_name='evaluation',
)


@dataclass
class Evaluate(BaseNode[QuestionState]):
    question: str
    answer: str

    async def run(self, ctx: GraphContext[QuestionState]) -> Congratulate | Castigate:
        result = await evaluate_agent.run(format_as_xml({'question': self.question, 'answer': self.answer}))
        if result.data.correct:
            return Congratulate(result.data.comment)
        else:
            return Castigate(result.data.comment)


@dataclass
class Congratulate(BaseNode[QuestionState, None]):
    comment: str

    async def run(self, ctx: GraphContext[QuestionState]) -> End:
        print(f'Correct answer! {self.comment}')
        return End(None)


@dataclass
class Castigate(BaseNode[QuestionState]):
    comment: str

    async def run(self, ctx: GraphContext[QuestionState]) -> Ask:
        print(f'Comment: {self.comment}')
        return Ask()


graph = Graph(nodes=(Ask, Answer, Evaluate, Congratulate, Castigate))



print(graph.mermaid_code(start_node=Ask))
graph.mermaid_save('questions_graph.svg', start_node=Ask)


async def main():
    node = Ask()
    state = QuestionState()
    history = []
    with logfire.span('run questions graph'):
        while True:
            node = await graph.next(state, node, history)
            if isinstance(node, End):
                print('\n'.join(e.summary() for e in history))
                break
            elif isinstance(node, Answer):
                node.answer = input(f'{node.question} ')
            # otherwise just continue


if __name__ == '__main__':
    import asyncio

    asyncio.run(main())

Which has the following graph:

stateDiagram-v2
  [*] --> Ask
  Ask --> Answer: ask the question
  note right of Ask
    Generate a question to ask the user.
    Uses the GPT-4o model to generate a question.
  end note
  Answer --> Evaluate: answer the question
  Evaluate --> Congratulate
  Evaluate --> Castigate
  Congratulate --> [*]
  Castigate --> Ask

You'll see that the Answer node was drawn square, not rounded, to identify a point where the graph may restart after interruption. This is no longer included; I don't think that matters.

@brettkromkamp

Just wondering if the new interrupt mechanism can be used for HITL-purposes? Or, is it more for retrying steps in case of failures. It could also be a general mechanism for all kinds of purposes. I'll take a closer look... definitely exciting to see how this feature is developing, though.

@dmontagu
Contributor

dmontagu commented Jan 9, 2025

Just wondering if the new interrupt mechanism can be used for HITL-purposes? Or, is it more for retrying steps in case of failures. It could also be a general mechanism for all kinds of purposes. I'll take a closer look... definitely exciting to see how this feature is developing, though.

It is definitely explicitly and primarily intended for facilitating HITL; if it's useful for other purposes then of course that's great but most of the discussion we've been having about the feature has been oriented around how to use it for human feedback.

@@ -25,15 +25,19 @@ def serialize(self) -> bytes | None:
"""Serialize the state object."""
raise NotImplementedError
Contributor

@dmontagu dmontagu Jan 9, 2025

I wonder if we can eliminate this AbstractState type by moving the serialization and/or copying logic to be kwargs of the graph, and if not provided, use copy.deepcopy (or noop if None as you've done) for copying, and pydantic_core.to_json for serialization. That would let you use a typical basemodel/dataclass/typeddict as state with minimal boilerplate.

(Because the graph is aware of the state type, we can still use type hints on the kwargs like serializer: Callable[[StateT], bytes] to get the same type safety you'd get from a method.)
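The kwargs-based API suggested here could look something like the sketch below. This is purely hypothetical (the `SketchGraph` class and `default_serializer` are illustration names, not pydantic_graph API), and stdlib `json` stands in for `pydantic_core.to_json`:

```python
import copy
import json
from dataclasses import dataclass, asdict, is_dataclass
from typing import Callable, Generic, TypeVar

StateT = TypeVar('StateT')


def default_serializer(state) -> bytes:
    # stand-in for pydantic_core.to_json: serialize any dataclass state
    data = asdict(state) if is_dataclass(state) else state
    return json.dumps(data).encode()


@dataclass
class SketchGraph(Generic[StateT]):
    # defaults mean a plain dataclass works as state with zero boilerplate;
    # callers can override either callable for custom behaviour
    serializer: Callable[[StateT], bytes] = default_serializer
    copier: Callable[[StateT], StateT] = copy.deepcopy


@dataclass
class MyState:
    count: int = 0


g = SketchGraph[MyState]()
print(g.serializer(MyState(count=3)))  # b'{"count": 3}'
```

Because the graph knows its state type, the `Callable[[StateT], bytes]` hints keep the same type safety as the `AbstractState.serialize` method they would replace.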

@samuelcolvin
Member Author

Just wondering if the new interrupt mechanism can be used for HITL-purposes? Or, is it more for retrying steps in case of failures. It could also be a general mechanism for all kinds of purposes. I'll take a closer look... definitely exciting to see how this feature is developing, though.

@brettkromkamp I've removed Interrupt and replaced with what we think is a better API, I've updated my example above.

@ME-Msc
Contributor

ME-Msc commented Jan 10, 2025

Hi, team. I have some questions about graph support.

  1. It seems that a node's multiple out-edges refer to different transition conditions. Will it be possible to see the condition on the mermaid graph?
  2. Is there any support for running the subsequent nodes in parallel (mentioned in Functionality to Define Multi-Agent Graphs and Workflows #529 by @izzyacademy ) ?

Contributor

@hyperlint-ai hyperlint-ai bot left a comment

The style guide flagged several spelling errors that seemed like false positives. We skipped posting inline suggestions for the following words:

  • Pydantic

@samuelcolvin
Member Author

  1. It seems that a node's multiple out-edges refer to different transition conditions. Will it be possible to see the condition on the mermaid graph?

Hi @ME-Msc, I'm not exactly sure what you mean here?

I'm going to provide a way to label an edge, but you won't be able to "see the logic" that leads to an edge being followed, as that's just procedural python code.

  2. Is there any support for running the subsequent nodes in parallel (mentioned in Functionality to Define Multi-Agent Graphs and Workflows #529 by @izzyacademy )?

Not yet, we might add it in future.

@izzyacademy
Contributor

@samuelcolvin I think I have an idea of what he is asking. I had similar thoughts earlier.

It seems @ME-Msc is looking for a mechanism to annotate (within the docstring) the pydantic_graph.BaseNode.run() method with a small note/comment to indicate what condition causes this node to route to the next node returned by this node. It looks like we could parse the docstring for a special tag or something from BaseNode.run() to get a list of conditions and then inject this into the mermaid code generated so that it shows up in the graph image generated.

@ME-Msc I think the goal of the project is to avoid fancy syntax that does not give you visibility into how the parallel nodes/tasks are run. What I would recommend is for you to dedicate a node that can aggregate all the parallel tasks and then spin up async tasks in that node using regular python code that you have 100% visibility and control so that you can see the exceptions, cancellations etc without having to get stressed out when things deviate from the happy path. I hope this helps.

I am working on an example for this because I think many users will have similar questions/needs, based on how they are using other frameworks with custom syntax for routing to parallel nodes in graph transitions.
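The pattern described above (a dedicated node that fans out to parallel async tasks with plain Python) can be sketched like this. The node and task names are made up for illustration, and the example simulates sub-agents with `asyncio.sleep` rather than real pydantic-ai calls:

```python
import asyncio


async def sub_task(name: str, delay: float) -> str:
    # stand-in for an agent call or other async work
    await asyncio.sleep(delay)
    return f'{name} done'


async def aggregate_node() -> list[str]:
    # fan out to the "parallel branches" explicitly with asyncio.gather;
    # return_exceptions=True keeps one failure from cancelling the rest,
    # so you retain full visibility into errors on unhappy paths
    results = await asyncio.gather(
        sub_task('summarise', 0.01),
        sub_task('translate', 0.01),
        return_exceptions=True,
    )
    return [r for r in results if isinstance(r, str)]


print(asyncio.run(aggregate_node()))  # ['summarise done', 'translate done']
```

In a real graph this function's body would live inside a node's `run` method, which would then return the next node carrying the aggregated results.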

@cxbxmxcx

I've added support for Interrupt to interrupt a run and continue it from the right place.

I've removed Interrupt and instead added a next() method which provides the same functionality, but is more flexible and easier to understand. [...]

First off, just want to say this looks and runs great. Thanks for all the effort on this. I wrote a book called AI Agents In Action (Manning) and used Behavior Trees to control agents, but I also really like this approach.

Question: I pulled the latest pydantic-graph from PyPI (could not find a repo for this) and the Edge class is missing. I quickly had Claude generate one so I have things working. Is there a repo for this work?

@izzyacademy
Contributor

izzyacademy commented Jan 15, 2025

@cxbxmxcx the work is still in progress.

The latest versions in pydantic-ai-graph and pydantic-graph are out of date by a couple of days.

You can clone the repo and then switch to the graph branch

From there you can install hatch and hatchling and then use hatch to build the .whl file for local installation and tests

git clone [email protected]:pydantic/pydantic-ai.git 
cd pydantic-ai

git checkout graph

cd pydantic_graph
hatch build

I hope this helps.

@cxbxmxcx

@cxbxmxcx the work is still in progress. You can clone the repo and then switch to the graph branch [...]

Thanks for the help, yes it should.

Co-authored-by: David Montague <[email protected]>
Co-authored-by: Israel Ekpo <[email protected]>
@samuelcolvin samuelcolvin merged commit 53fcb50 into main Jan 15, 2025
17 checks passed
@samuelcolvin samuelcolvin deleted the graph branch January 15, 2025 19:47