Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Pydantic to v2 and redis-om to v0.3. #138

Merged
merged 11 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3,276 changes: 1,775 additions & 1,501 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ openai = "^1.11.0"
langchain = "~0.2.5"
rich = "^13.6.0"
PettingZoo = "1.24.3"
redis-om = "^0.2.1"
redis-om = "^0.3.1"
gin-config = "^0.5.0"
absl-py = "^2.0.0"
together = "^0.2.4"
pydantic = "1.10.17"
pydantic = "^2.8.2"
beartype = "^0.14.0"
langchain-openai = "~0.1.8"
litellm = ">=1.23.12,<1.41.0"
Expand Down
4 changes: 2 additions & 2 deletions sotopia/agents/redis_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ async def aact(
return AgentAction(action_type="leave", argument="")
action_string = sorted_message_list[-1][2]
try:
action = AgentAction.parse_raw(action_string)
action = AgentAction.model_validate_json(action_string)
return action
except pydantic.error_wrappers.ValidationError:
except pydantic.ValidationError:
logging.warn(
"Failed to parse action string {}. Fall back to speak".format(
action_string
Expand Down
6 changes: 3 additions & 3 deletions sotopia/database/annotators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pydantic import Field
from redis_om.model.model import Field
from redis_om import JsonModel


class Annotator(JsonModel):
name: str = Field(index=True, required=True)
email: str = Field(index=True, required=True)
name: str = Field(index=True)
email: str = Field(index=True)
33 changes: 15 additions & 18 deletions sotopia/database/logs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import Any
import sys

from pydantic import root_validator
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

from pydantic import model_validator
from redis_om import JsonModel
from redis_om.model.model import Field

Expand All @@ -14,29 +19,21 @@ class EpisodeLog(JsonModel):

environment: str = Field(index=True)
agents: list[str] = Field(index=True)
tag: str | None = Field(index=True)
models: list[str] | None = Field(index=True)
tag: str | None = Field(index=True, default="")
models: list[str] | None = Field(index=True, default=[])
messages: list[list[tuple[str, str, str]]] # Messages arranged by turn
ProKil marked this conversation as resolved.
Show resolved Hide resolved
reasoning: str
rewards: list[tuple[float, dict[str, float]] | float] # Rewards arranged by turn
rewards_prompt: str

@root_validator(skip_on_failure=True)
def agent_number_message_number_reward_number_turn_number_match(
cls, values: Any
) -> Any:
agents, _, _reasoning, rewards = (
values.get("agents"),
values.get("messages"),
values.get("reasoning"),
values.get("rewards"),
)
agent_number = len(agents)
@model_validator(mode="after")
def agent_number_message_number_reward_number_turn_number_match(self) -> Self:
agent_number = len(self.agents)

assert (
len(rewards) == agent_number
), f"Number of agents in rewards {len(rewards)} and agents {agent_number} do not match"
return values
len(self.rewards) == agent_number
), f"Number of agents in rewards {len(self.rewards)} and agents {agent_number} do not match"
return self

def render_for_humans(self) -> tuple[list[AgentProfile], list[str]]:
"""Generate a human readable version of the episode log.
Expand Down
21 changes: 13 additions & 8 deletions sotopia/database/persistent_profile.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from enum import IntEnum
from typing import Any
import sys

from pydantic import root_validator
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

from pydantic import model_validator
from redis_om import JsonModel
from redis_om.model.model import Field

Expand Down Expand Up @@ -87,15 +92,15 @@ class EnvironmentList(JsonModel):
agent_index: list[str] | None = Field(default_factory=lambda: None)

# validate the length of agent_index should be same as environments
@root_validator
def the_length_agent_index_matches_environments(cls, values: Any) -> Any:
@model_validator(mode="after")
def the_length_agent_index_matches_environments(self) -> Self:
environments, agent_index = (
values.get("environments"),
values.get("agent_index"),
self.environments,
self.agent_index,
)
if agent_index is None:
return values
return self
assert (
len(environments) == len(agent_index)
), f"Number of environments {len(environments)} and agent_index {len(agent_index)} do not match"
return values
return self
99 changes: 50 additions & 49 deletions sotopia/database/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import json
from typing import Any

from pydantic import BaseModel, Field
from pydantic import BaseModel
from redis_om.model.model import Field

from .env_agent_combo_storage import EnvAgentComboStorage
from .logs import EpisodeLog
Expand All @@ -14,66 +15,66 @@


class TwoAgentEpisodeWithScenarioBackgroundGoals(BaseModel):
episode_id: str = Field(required=True)
environment_id: str = Field(required=True)
agent_ids: list[str] = Field(required=True)
experiment_tag: str = Field(required=True)
experiment_model_name_pairs: list[str] = Field(required=True)
raw_messages: list[list[tuple[str, str, str]]] = Field(required=True)
raw_rewards: list[tuple[float, dict[str, float]] | float] = Field(required=True)
raw_rewards_prompt: str = Field(required=True)
scenario: str = Field(required=True)
codename: str = Field(required=True)
agents_background: dict[str, str] = Field(required=True)
social_goals: dict[str, str] = Field(required=True)
social_interactions: str = Field(required=True)
reasoning: str = Field(required=False)
rewards: list[tuple[float, dict[str, float]]] = Field(required=False)
episode_id: str = Field()
ProKil marked this conversation as resolved.
Show resolved Hide resolved
environment_id: str = Field()
agent_ids: list[str] = Field()
experiment_tag: str = Field()
experiment_model_name_pairs: list[str] = Field()
raw_messages: list[list[tuple[str, str, str]]] = Field()
raw_rewards: list[tuple[float, dict[str, float]] | float] = Field()
raw_rewards_prompt: str = Field()
scenario: str = Field()
codename: str = Field()
agents_background: dict[str, str] = Field()
social_goals: dict[str, str] = Field()
social_interactions: str = Field()
reasoning: str = Field()
rewards: list[tuple[float, dict[str, float]]] = Field()


class AgentProfileWithPersonalInformation(BaseModel):
agent_id: str = Field(required=True)
first_name: str = Field(required=True)
last_name: str = Field(required=True)
age: int = Field(required=True)
occupation: str = Field(required=True)
gender: str = Field(required=True)
gender_pronoun: str = Field(required=True)
public_info: str = Field(required=True)
big_five: str = Field(required=True)
moral_values: list[str] = Field(required=True)
schwartz_personal_values: list[str] = Field(required=True)
personality_and_values: str = Field(required=True)
decision_making_style: str = Field(required=True)
secret: str = Field(required=True)
mbti: str = Field(required=True)
model_id: str = Field(required=True)
agent_id: str = Field()
first_name: str = Field()
last_name: str = Field()
age: int = Field()
occupation: str = Field()
gender: str = Field()
gender_pronoun: str = Field()
public_info: str = Field()
big_five: str = Field()
moral_values: list[str] = Field()
schwartz_personal_values: list[str] = Field()
personality_and_values: str = Field()
decision_making_style: str = Field()
secret: str = Field()
mbti: str = Field()
model_id: str = Field()


class EnvironmentProfileWithTwoAgentRequirements(BaseModel):
env_id: str = Field(required=True)
codename: str = Field(required=True)
source: str = Field(required=True)
scenario: str = Field(required=True)
agent_goals: list[str] = Field(required=True)
relationship: str = Field(required=True)
age_constraint: str = Field(required=True)
occupation_constraint: str = Field(required=True)
agent_constraint: str | None = Field(required=False)
env_id: str = Field()
codename: str = Field()
source: str = Field()
scenario: str = Field()
agent_goals: list[str] = Field()
relationship: str = Field()
age_constraint: str = Field()
occupation_constraint: str = Field()
agent_constraint: str | None = Field()


class RelationshipProfileBetweenTwoAgents(BaseModel):
relationship_id: str = Field(required=True)
agent1_id: str = Field(required=True)
agent2_id: str = Field(required=True)
relationship: str = Field(required=True)
background_story: str = Field(required=True)
relationship_id: str = Field()
agent1_id: str = Field()
agent2_id: str = Field()
relationship: str = Field()
background_story: str = Field()


class EnvAgentComboStorageWithID(BaseModel):
combo_id: str = Field(default_factory=lambda: "", index=True)
env_id: str = Field(default_factory=lambda: "", index=True)
agent_ids: list[str] = Field(default_factory=lambda: [], index=True)
combo_id: str = Field(default="", index=True)
env_id: str = Field(default="", index=True)
agent_ids: list[str] = Field(default=[], index=True)


def _map_gender_to_adj(gender: str) -> str:
Expand Down
4 changes: 2 additions & 2 deletions sotopia/database/session_transaction.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import validator
from pydantic import field_validator
from redis_om import EmbeddedJsonModel, JsonModel
from redis_om.model.model import Field

Expand Down Expand Up @@ -30,7 +30,7 @@ class SessionTransaction(AutoExpireMixin, JsonModel):
"""
)

@validator("message_list")
@field_validator("message_list")
def validate_message_list(
cls, v: list[MessageTransaction]
) -> list[MessageTransaction]:
Expand Down
Loading
Loading