From 136ffaa9bdc35edadc9000b81fd438c15d749324 Mon Sep 17 00:00:00 2001 From: Kurt Shuster Date: Mon, 13 Sep 2021 13:34:01 -0400 Subject: [PATCH] add msc mutator (#3966) --- parlai/tasks/msc/agents.py | 1 + parlai/tasks/msc/mutators.py | 30 ++++++++++++++++++++++++++++++ tests/test_mutators.py | 13 +++++++++++++ 3 files changed, 44 insertions(+) create mode 100644 parlai/tasks/msc/mutators.py diff --git a/parlai/tasks/msc/agents.py b/parlai/tasks/msc/agents.py index 60ca0b77b56..a0978d9f488 100644 --- a/parlai/tasks/msc/agents.py +++ b/parlai/tasks/msc/agents.py @@ -31,6 +31,7 @@ UI_OPT, COMMON_CONFIG, ) +import parlai.tasks.msc.mutators # type: ignore NOPERSONA = '__NO__PERSONA__BEAM__MIN__LEN__20__' diff --git a/parlai/tasks/msc/mutators.py b/parlai/tasks/msc/mutators.py new file mode 100644 index 00000000000..158cf018476 --- /dev/null +++ b/parlai/tasks/msc/mutators.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import random +from typing import List +from parlai.core.message import Message +from parlai.core.mutators import register_mutator, ManyEpisodeMutator + + +@register_mutator("ltm_mutator") +class LongTermMemoryMutator(ManyEpisodeMutator): + """ + Replaces the episode labels in messages with "personal_knowledge". + + Episodes are flattened to ensure dialogue history is maintained appropriately. 
+ """ + + def many_episode_mutation(self, episode: List[Message]) -> List[List[Message]]: + history = [] + for message in episode: + if message['text'] == '__SILENCE__': + continue + history.append(message.pop('text')) + message['text'] = '\n'.join(history) + labels = message.pop('labels') + message['labels'] = ["personal_knowledge"] + yield [message] + history.append(random.choice(labels)) diff --git a/tests/test_mutators.py b/tests/test_mutators.py index c26a2712987..1e2a49e3a5e 100644 --- a/tests/test_mutators.py +++ b/tests/test_mutators.py @@ -227,6 +227,19 @@ def test_word_shuffle(self): assert set(ex3['text'].split()) == set(EXAMPLE3['text'].split()) assert set(ex4['text'].split()) == set(EXAMPLE4['text'].split()) + def test_msc_ltm_mutator(self): + from parlai.tasks.msc.mutators import LongTermMemoryMutator + + ex1, ex2, ex3, ex4 = self._apply_mutator(LongTermMemoryMutator) + + assert ( + ex1['labels'] + == ex2['labels'] + == ex3['labels'] + == ex4['labels'] + == ['personal_knowledge'] + ) + class TestMutatorStickiness(unittest.TestCase): """