Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[BB2] LTM Mutator #3966

Merged
merged 1 commit into from
Sep 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions parlai/tasks/msc/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
UI_OPT,
COMMON_CONFIG,
)
import parlai.tasks.msc.mutators # type: ignore


NOPERSONA = '__NO__PERSONA__BEAM__MIN__LEN__20__'
Expand Down
30 changes: 30 additions & 0 deletions parlai/tasks/msc/mutators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import random
from typing import List
from parlai.core.message import Message
from parlai.core.mutators import register_mutator, ManyEpisodeMutator


@register_mutator("ltm_mutator")
class LongTermMemoryMutator(ManyEpisodeMutator):
"""
Replaces the episode labels in messages with "personal_knowledge".

Episodes are flattened to ensure dialogue history is maintained appropriately.
"""

def many_episode_mutation(self, episode: List[Message]) -> List[List[Message]]:
history = []
for message in episode:
if message['text'] == '__SILENCE__':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one nit: replace this '__SILENCE__' with DUMMY_TEXT

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to do this. I think I hesitated at first because it would require a nested import (to avoid circular dependencies), but that should be fine with me!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah didn't realize it would require a nested import. I'm fine either way.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, if you're fine with it, I'll avoid the nested import for now, then.

continue
history.append(message.pop('text'))
message['text'] = '\n'.join(history)
labels = message.pop('labels')
message['labels'] = ["personal_knowledge"]
yield [message]
history.append(random.choice(labels))
13 changes: 13 additions & 0 deletions tests/test_mutators.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,19 @@ def test_word_shuffle(self):
assert set(ex3['text'].split()) == set(EXAMPLE3['text'].split())
assert set(ex4['text'].split()) == set(EXAMPLE4['text'].split())

def test_msc_ltm_mutator(self):
    """
    Check that the LTM mutator overwrites the labels of every example.

    Each of the four mutated examples must carry the single label
    'personal_knowledge'.
    """
    from parlai.tasks.msc.mutators import LongTermMemoryMutator

    expected_labels = ['personal_knowledge']
    # Unpacking into exactly four names keeps the implicit check that the
    # helper returns four examples.
    first, second, third, fourth = self._apply_mutator(LongTermMemoryMutator)

    for mutated in (first, second, third, fourth):
        assert mutated['labels'] == expected_labels


class TestMutatorStickiness(unittest.TestCase):
"""
Expand Down