Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Allow more subclassing of self-chat world (#3955)
Browse files Browse the repository at this point in the history
* Generalize seed utterances logic

* Update seed utterance logic

* Modularize logic

* Reset agents method

* Streamline seed utterance logic
  • Loading branch information
EricMichaelSmith authored Sep 8, 2021
1 parent 7506a84 commit 376888e
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 18 deletions.
11 changes: 9 additions & 2 deletions parlai/core/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,21 @@ def reset(self):
"""
Reset all agents in the world, and world statistics.
"""
for a in self.agents:
a.reset()
self.reset_agents()
self.max_exs = None
self.total_exs = 0
self.total_epochs = 0
self.total_parleys = 0
self.time.reset()

def reset_agents(self):
"""
Reset all agents in the world.
"""
agents = self.get_agents()
for a in agents:
a.reset()

def reset_metrics(self):
"""
Reset metrics for all agents.
Expand Down
43 changes: 27 additions & 16 deletions parlai/tasks/self_chat/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,32 +122,33 @@ def episode_done(self):
def _get_seed_utt_acts(
self, episode_num: int, agents: List[Agent]
) -> List[Dict[str, Any]]:
"""
Return acts of any utterances to "seed" the conversation with.
"""

def make_agent_action(utterance: str, agent: Agent) -> Dict[str, Any]:
return {'text': utterance, 'episode_done': False, 'id': agent.id}

openers = self.get_openers(episode_num)
if not openers:
return []
return list(map(make_agent_action, openers, agents))
if self.turn_cnt == 0:
# Create the seed utterances from any openers
openers = self.get_openers(episode_num)
if not openers:
return []
return list(map(make_agent_action, openers, agents))
else:
# Just return the existing seed utterances, if any exist
return self.seed_utterances

def parley(self):
if self.episode_done():
self.turn_cnt = 0
self.episode_cnt += 1
self.contexts = None
self.seed_utterances = None
agents = self.get_agents()
for a in agents:
a.reset()
self._end_episode()

if self.turn_cnt == 0:
self.acts = [None, None]
# get the beginning of the conversation, which can include contexts
# and/or any number of starting messages
# get any context for the beginning of the conversation
self.contexts = self.get_contexts()
self.seed_utterances = self._get_seed_utt_acts(
self.episode_cnt, self.agents
)

self.seed_utterances = self._get_seed_utt_acts(self.episode_cnt, self.agents)

if self.contexts:
assert len(self.contexts) == 2
Expand Down Expand Up @@ -186,3 +187,13 @@ def parley(self):

self.update_counters()
self.turn_cnt += 1

def _end_episode(self):
"""
Apply logic to end the episode.
"""
self.turn_cnt = 0
self.episode_cnt += 1
self.contexts = None
self.seed_utterances = None
self.reset_agents()

0 comments on commit 376888e

Please sign in to comment.