Skip to content

Commit

Permalink
[dotnet] Add mixin for easier state save/load apis (#5438)
Browse files Browse the repository at this point in the history
Co-authored-by: Ryan Sweet <[email protected]>
  • Loading branch information
jackgerrits and rysweet authored Feb 24, 2025
1 parent 213da85 commit 181925c
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 19 deletions.
2 changes: 1 addition & 1 deletion dotnet/src/Microsoft.AutoGen/Contracts/IAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace Microsoft.AutoGen.Contracts;
/// <summary>
/// Represents an agent within the runtime that can process messages, maintain state, and be closed when no longer needed.
/// </summary>
public interface IAgent : ISaveState<IAgent>
public interface IAgent : ISaveState
{
/// <summary>
/// Gets the unique identifier of the agent.
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Microsoft.AutoGen.Contracts;
/// <summary>
/// Defines the runtime environment for agents, managing message sending, subscriptions, agent resolution, and state persistence.
/// </summary>
public interface IAgentRuntime : ISaveState<IAgentRuntime>
public interface IAgentRuntime : ISaveState
{
/// <summary>
/// Sends a message to an agent and gets a response.
Expand Down
13 changes: 9 additions & 4 deletions dotnet/src/Microsoft.AutoGen/Contracts/ISaveState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ namespace Microsoft.AutoGen.Contracts;
/// Defines a contract for saving and loading the state of an object.
/// The state must be JSON serializable.
/// </summary>
/// <typeparam name="T">The type of the object implementing this interface.</typeparam>
public interface ISaveState<T>
public interface ISaveState
{
/// <summary>
/// Saves the current state of the object.
Expand All @@ -20,7 +19,10 @@ public interface ISaveState<T>
/// containing the saved state. The structure of the state is implementation-defined
/// but must be JSON serializable.
/// </returns>
public ValueTask<JsonElement> SaveStateAsync();
public virtual ValueTask<JsonElement> SaveStateAsync()
{
return new ValueTask<JsonElement>(JsonDocument.Parse("{}").RootElement);
}

/// <summary>
/// Loads a previously saved state into the object.
Expand All @@ -30,6 +32,9 @@ public interface ISaveState<T>
/// is implementation-defined but must be JSON serializable.
/// </param>
/// <returns>A task representing the asynchronous operation.</returns>
public ValueTask LoadStateAsync(JsonElement state);
public virtual ValueTask LoadStateAsync(JsonElement state)
{
return ValueTask.CompletedTask;
}
}

48 changes: 48 additions & 0 deletions dotnet/src/Microsoft.AutoGen/Contracts/ISaveStateMixin.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ISaveStateMixin.cs

using System.Text.Json;

namespace Microsoft.AutoGen.Contracts;

/// <summary>
/// Defines a contract for saving and loading the state of an object.
/// The state must be JSON serializable.
/// </summary>
/// <typeparam name="T">The type of the object implementing this interface.</typeparam>
///
public interface ISaveStateMixin<T> : ISaveState
{
/// <summary>
/// Saves the current state of the object.
/// </summary>
/// <returns>
/// A task representing the asynchronous operation, returning a dictionary
/// containing the saved state. The structure of the state is implementation-defined
/// but must be JSON serializable.
/// </returns>
async ValueTask<JsonElement> ISaveState.SaveStateAsync()
{
var state = await SaveStateImpl();
return JsonSerializer.SerializeToElement(state);
}

/// <summary>
/// Loads a previously saved state into the object.
/// </summary>
/// <param name="state">
/// A dictionary representing the saved state. The structure of the state
/// is implementation-defined but must be JSON serializable.
/// </param>
/// <returns>A task representing the asynchronous operation.</returns>
ValueTask ISaveState.LoadStateAsync(JsonElement state)
{
// Throw if failed to deserialize
var stateObject = JsonSerializer.Deserialize<T>(state) ?? throw new InvalidDataException();
return LoadStateImpl(stateObject);
}

protected ValueTask<T> SaveStateImpl();

protected ValueTask LoadStateImpl(T state);
}
10 changes: 0 additions & 10 deletions dotnet/src/Microsoft.AutoGen/Core/BaseAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

using System.Diagnostics;
using System.Reflection;
using System.Text.Json;
using Microsoft.AutoGen.Contracts;
using Microsoft.Extensions.Logging;

Expand Down Expand Up @@ -93,15 +92,6 @@ private Dictionary<Type, HandlerInvoker> ReflectInvokers()
return null;
}

public virtual ValueTask<JsonElement> SaveStateAsync()
{
return ValueTask.FromResult(JsonDocument.Parse("{}").RootElement);
}
public virtual ValueTask LoadStateAsync(JsonElement state)
{
return ValueTask.CompletedTask;
}

public ValueTask<object?> SendMessageAsync(object message, AgentId recepient, string? messageId = null, CancellationToken cancellationToken = default)
{
return this.Runtime.SendMessageAsync(message, recepient, sender: this.Id, messageId: messageId, cancellationToken: cancellationToken);
Expand Down
58 changes: 58 additions & 0 deletions dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentTests.cs
using System.Text.Json;
using FluentAssertions;
using Microsoft.AutoGen.Contracts;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -146,4 +147,61 @@ await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) =>

Assert.True(agent.ReceivedItems.Count == 1);
}

public class AgentState
{
public required string Name { get; set; }
public required int Value { get; set; }
}

public class StateAgent(AgentId id,
IAgentRuntime runtime,
AgentState state,
Logger<BaseAgent>? logger = null) : BaseAgent(id, runtime, "Test Agent", logger),
ISaveStateMixin<AgentState>

{
ValueTask<AgentState> ISaveStateMixin<AgentState>.SaveStateImpl()
{
return ValueTask.FromResult(_state);
}

ValueTask ISaveStateMixin<AgentState>.LoadStateImpl(AgentState state)
{
_state = state;
return ValueTask.CompletedTask;
}

private AgentState _state = state;
}

[Fact]
public async Task StateMixinTest()
{
var runtime = new InProcessRuntime();
await runtime.StartAsync();
await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) =>
{
return ValueTask.FromResult(new StateAgent(id, runtime, new AgentState { Name = "TestAgent", Value = 5 }));
});

var agentId = new AgentId("MyAgent", "default");

// Get the state
var state1 = await runtime.SaveAgentStateAsync(agentId);

Assert.Equal("TestAgent", state1.GetProperty("Name").GetString());
Assert.Equal(5, state1.GetProperty("Value").GetInt32());

// Change the state
var newState = new AgentState { Name = "TestAgent", Value = 100 };
var jsonState = JsonSerializer.SerializeToElement(newState);
await runtime.LoadAgentStateAsync(agentId, jsonState);

// Get the state
var state2 = await runtime.SaveAgentStateAsync(agentId);

Assert.Equal("TestAgent", state2.GetProperty("Name").GetString());
Assert.Equal(100, state2.GetProperty("Value").GetInt32());
}
}
6 changes: 3 additions & 3 deletions dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,21 @@ public SubscribedAgent(AgentId id,
}

[TypeSubscription("TestTopic")]
public class SubscribedSaveLoadAgent : TestAgent
public class SubscribedSaveLoadAgent : TestAgent, ISaveState
{
public SubscribedSaveLoadAgent(AgentId id,
IAgentRuntime runtime,
Logger<BaseAgent>? logger = null) : base(id, runtime, logger)
{
}

public override ValueTask<JsonElement> SaveStateAsync()
ValueTask<JsonElement> ISaveState.SaveStateAsync()
{
var jsonDoc = JsonSerializer.SerializeToElement(_receivedMessages);
return ValueTask.FromResult(jsonDoc);
}

public override ValueTask LoadStateAsync(JsonElement state)
ValueTask ISaveState.LoadStateAsync(JsonElement state)
{
_receivedMessages = JsonSerializer.Deserialize<Dictionary<string, object>>(state) ?? throw new InvalidOperationException("Failed to deserialize state");
return ValueTask.CompletedTask;
Expand Down

0 comments on commit 181925c

Please sign in to comment.