Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
add msc mutator (#3966)
Browse files Browse the repository at this point in the history
  • Loading branch information
klshuster authored Sep 13, 2021
1 parent 60d1499 commit 136ffaa
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 0 deletions.
1 change: 1 addition & 0 deletions parlai/tasks/msc/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
UI_OPT,
COMMON_CONFIG,
)
import parlai.tasks.msc.mutators # type: ignore


NOPERSONA = '__NO__PERSONA__BEAM__MIN__LEN__20__'
Expand Down
30 changes: 30 additions & 0 deletions parlai/tasks/msc/mutators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import random
from typing import List
from parlai.core.message import Message
from parlai.core.mutators import register_mutator, ManyEpisodeMutator


@register_mutator("ltm_mutator")
class LongTermMemoryMutator(ManyEpisodeMutator):
"""
Replaces the episode labels in messages with "personal_knwowledge".
Episodes are flattened to ensure dialogue history is maintained appropriately.
"""

def many_episode_mutation(self, episode: List[Message]) -> List[List[Message]]:
history = []
for message in episode:
if message['text'] == '__SILENCE__':
continue
history.append(message.pop('text'))
message['text'] = '\n'.join(history)
labels = message.pop('labels')
message['labels'] = ["personal_knowledge"]
yield [message]
history.append(random.choice(labels))
13 changes: 13 additions & 0 deletions tests/test_mutators.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,19 @@ def test_word_shuffle(self):
assert set(ex3['text'].split()) == set(EXAMPLE3['text'].split())
assert set(ex4['text'].split()) == set(EXAMPLE4['text'].split())

def test_msc_ltm_mutator(self):
from parlai.tasks.msc.mutators import LongTermMemoryMutator

ex1, ex2, ex3, ex4 = self._apply_mutator(LongTermMemoryMutator)

assert (
ex1['labels']
== ex2['labels']
== ex3['labels']
== ex4['labels']
== ['personal_knowledge']
)


class TestMutatorStickiness(unittest.TestCase):
"""
Expand Down

0 comments on commit 136ffaa

Please sign in to comment.