Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Allow more subclassing of self-chat world #3955

Merged
merged 6 commits into from
Sep 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions parlai/core/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,21 @@ def reset(self):
"""
Reset all agents in the world, and world statistics.
"""
for a in self.agents:
a.reset()
self.reset_agents()
self.max_exs = None
self.total_exs = 0
self.total_epochs = 0
self.total_parleys = 0
self.time.reset()

def reset_agents(self):
"""
Reset all agents in the world.
"""
agents = self.get_agents()
for a in agents:
a.reset()

def reset_metrics(self):
"""
Reset metrics for all agents.
Expand Down
43 changes: 27 additions & 16 deletions parlai/tasks/self_chat/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,32 +122,33 @@ def episode_done(self):
def _get_seed_utt_acts(
self, episode_num: int, agents: List[Agent]
) -> List[Dict[str, Any]]:
"""
Return acts of any utterances to "seed" the conversation with.
"""

def make_agent_action(utterance: str, agent: Agent) -> Dict[str, Any]:
return {'text': utterance, 'episode_done': False, 'id': agent.id}

openers = self.get_openers(episode_num)
if not openers:
return []
return list(map(make_agent_action, openers, agents))
if self.turn_cnt == 0:
# Create the seed utterances from any openers
openers = self.get_openers(episode_num)
if not openers:
return []
return list(map(make_agent_action, openers, agents))
else:
# Just return the existing seed utterances, if any exist
return self.seed_utterances

def parley(self):
if self.episode_done():
self.turn_cnt = 0
self.episode_cnt += 1
self.contexts = None
self.seed_utterances = None
agents = self.get_agents()
for a in agents:
a.reset()
self._end_episode()

if self.turn_cnt == 0:
self.acts = [None, None]
# get the beginning of the conversation, which can include contexts
# and/or any number of starting messages
# get any context for the beginning of the conversation
self.contexts = self.get_contexts()
self.seed_utterances = self._get_seed_utt_acts(
self.episode_cnt, self.agents
)

self.seed_utterances = self._get_seed_utt_acts(self.episode_cnt, self.agents)

if self.contexts:
assert len(self.contexts) == 2
Expand Down Expand Up @@ -186,3 +187,13 @@ def parley(self):

self.update_counters()
self.turn_cnt += 1

def _end_episode(self):
"""
Apply logic to end the episode.
"""
self.turn_cnt = 0
self.episode_cnt += 1
self.contexts = None
self.seed_utterances = None
self.reset_agents()