This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[BlenderBot2] clean up reply during interactive mode #4379

Merged 7 commits on Mar 25, 2022
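This PR adds an `--add-cleaned-reply-to-history` flag so that, in interactive mode, BlenderBot2's own replies are stripped of special tokens (e.g. `_POTENTIALLY_UNSAFE__`) before being written back into its dialogue history. A minimal usage sketch, assuming the standard BB2 zoo model path and a reachable search server (both are illustrative placeholders, not part of this diff):

from parlai.core.agents import create_agent_from_model_file

# Sketch only: the model path and search server are illustrative placeholders.
agent = create_agent_from_model_file(
    'zoo:blenderbot2/blenderbot2_400M/model',
    opt_overrides={
        'search_server': '<your-search-server>',
        'add_cleaned_reply_to_history': True,
    },
)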
62 changes: 61 additions & 1 deletion projects/blenderbot2/agents/blenderbot2.py
@@ -13,6 +13,8 @@
The Memory Decoder examines the context and generates memories to write to
the long-term memory module.
"""
from abc import abstractmethod
import re
import copy
import torch
import torch.nn
@@ -32,7 +34,7 @@
from parlai.core.metrics import AverageMetric
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
from parlai.core.torch_agent import Batch
from parlai.core.torch_agent import Batch, History
from parlai.tasks.wizard_of_internet.constants import (
    SELECTED_DOCS,
    SELECTED_DOCS_TITLES,
@@ -57,6 +59,54 @@
ZOO_MEMORY_DECODER = 'zoo:blenderbot2/memory_decoder/model'


class HistoryCleanReply(History):
    def __init__(
        self,
        opt,
        field='text',
        maxlen=None,
        size=-1,
        p1_token='__p1__',
        p2_token='__p2__',
        dict_agent=None,
    ):
        super().__init__(
            opt,
            field=field,
            maxlen=maxlen,
            size=size,
            p1_token=p1_token,
            p2_token=p2_token,
            dict_agent=dict_agent,
        )
        self.add_cleaned_reply_to_history = opt.get(
            'add_cleaned_reply_to_history', False
        )

    @abstractmethod
    def _clean_text(self, txt):
        """
        Clean text; to be overridden with custom logic.
        """

    def add_reply(self, text):
        clean_text = text
        if self.add_cleaned_reply_to_history:
            clean_text = self._clean_text(text)
        super().add_reply(clean_text)


class HistoryCleanUnsafeToken(HistoryCleanReply):
    """
    Override the history's _clean_text to filter out special tokens like
    _POTENTIALLY_UNSAFE__.
    """

    def _clean_text(self, txt):
        cleaned_txt = re.sub(r'_[\S]*unsafe_*', '', txt, flags=re.IGNORECASE)
        return cleaned_txt.strip()


class BlenderBot2ModelTypeMixin(RagModelInterface):
    """
    Override Normal RAG Model Types, in case we retrieve from both memory and search.
@@ -339,6 +389,12 @@ def add_cmdline_args(
            hidden=True,
            help='model file for memory writer',
        )
        bb2_group.add_argument(
            '--add-cleaned-reply-to-history',
            type=bool,
            default=False,
            help='whether to add the cleaned BB2-generated text, without any special tokens, to its history',
        )
        memory_decoder = parser.add_argument_group('BlenderBot2 Memory Decoder Args')
        memory_decoder.add_argument(
            '--memory-decoder-key',
@@ -391,6 +447,10 @@ def add_cmdline_args(
        )
        return parser

    @classmethod
    def history_class(cls):
        return HistoryCleanUnsafeToken

    @property
    def rag_model_type(self) -> str:
        return self._rag_model_type
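The cleaning logic above is a single regex pass. A small, self-contained sketch of the same pattern (the sample string is hypothetical):

import re

def clean_unsafe_tokens(txt):
    # Same pattern as HistoryCleanUnsafeToken._clean_text above: remove any
    # underscore-led token ending in "unsafe" plus trailing underscores.
    return re.sub(r'_[\S]*unsafe_*', '', txt, flags=re.IGNORECASE).strip()

print(clean_unsafe_tokens("Don't have a cow, Man! _POTENTIALLY_UNSAFE__"))
# -> Don't have a cow, Man!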
34 changes: 34 additions & 0 deletions tests/nightly/gpu/test_bb2.py
@@ -6,6 +6,7 @@
import copy
import torch.cuda
import unittest
from parlai.core.agents import create_agent

import parlai.utils.testing as testing_utils

@@ -203,6 +204,39 @@ def test_rag(self):
        )


@testing_utils.skipUnlessGPU
@unittest.skipIf(LOCAL, "Skipping test because it's slow and memory intensive")
class TestBB2CleanText(unittest.TestCase):
    SPECIAL_TOKEN = '_POTENTIALLY_UNSAFE__'

    def test_bb2_history(self):
        """
        Test that BB2 strips special tokens from its own replies before
        writing them to history.
        """
        opt = copy.deepcopy(common_opt)
        opt.update(
            {
                'model_file': ZOO_BB2,
                'override': {
                    'search_server': SEARCH_SERVER,
                    'add_cleaned_reply_to_history': True,
                },
            }
        )
        bb2 = create_agent(opt)

        text_with_safety_token = f"Don't have a cow, Man! {self.SPECIAL_TOKEN}"
        obs = {'text': text_with_safety_token}
        bb2.observe(obs)
        assert self.SPECIAL_TOKEN in bb2.history.get_history_str()

        bb2.history.reset()
        obs = {'text': "I am Groot"}
        bb2.observe(obs)
        bb2.history.add_reply(text_with_safety_token)
        assert self.SPECIAL_TOKEN not in bb2.history.get_history_str()


@testing_utils.skipUnlessGPU
@unittest.skipIf(LOCAL, "Skipping test because it's slow and memory intensive")
class TestBB2AdditionalTruncation(unittest.TestCase):
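For context, the end-to-end effect during interactive chat can be sketched as follows. This is illustrative, not part of the diff: it assumes `agent` was created with `add_cleaned_reply_to_history` enabled (as in the earlier sketch) and that the agent writes its own reply back into history after each act():

# Hedged sketch: assumes `agent` is a BB2 agent loaded with
# add_cleaned_reply_to_history=True.
agent.observe({'text': 'I am Groot', 'episode_done': False})
reply = agent.act()
# reply['text'] may still carry a marker such as _POTENTIALLY_UNSAFE__, but
# the copy written back into agent.history has it stripped, so subsequent
# turns condition on clean context.
print(agent.history.get_history_str())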