Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[Style-Controlled Generation] Open-source a second style classifier #4380

Merged
merged 6 commits into from
Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions parlai/tasks/style_gen/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,37 @@ def _edit_action(self, act: Message) -> Message:
assert 'text' not in act and act['episode_done'] is True
act.force_set('episode_done', True) # Clear the dialogue history
return act


class CurrUttOnlyStyleTeacher(AbstractWrapperTeacher):
"""
Serving examples for use with projects.style_gen.classifier:ClassifierAgent.

This teacher will replace message['text'] with message['labels'][0], and it will
replace message['labels'] with [message['personality']]. This is to allow for
training/evaluation of projects.style_gen.classifier:ClassifierAgent in the case
where we want to classify the style of an utterance given only that utterance as
context.

Because the dialogue history is effectively overwritten by this action, all episodes
will be flattened into one example each.
"""

def _edit_action(self, act: Message) -> Message:
"""
Edit the fields of the action manually.
"""
if 'labels' in act:
labels = act['labels']
if len(labels) != 1:
raise ValueError(
f'{type(self).__name__} can only be used with one label!'
)
assert '\n' not in labels[0]
# Classifier will not expect any newlines in the context
act.force_set('text', labels[0])
act.force_set('labels', [act['personality']])
else:
assert 'text' not in act and act['episode_done'] is True
act.force_set('episode_done', True) # Clear the dialogue history
return act
17 changes: 17 additions & 0 deletions parlai/zoo/model_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,6 +1417,23 @@
accuracy bleu-4 ctpb ctps exps exs f1 gpu_mem loss lr ltpb ltps tpb tps
.9973 .01745 38.08 604.1 15.86 5651 .9973 .1622 2.129 5e-10 5.633 89.36 43.71 693.4""",
},
{
"title": "Style-controlled generation: current-utterance-only classifier",
"id": "style_gen",
"path": "zoo:style_gen/curr_only_classifier/model",
"agent": "projects.style_gen.classifier:ClassifierAgent",
"task": "style_gen:LabeledBlendedSkillTalk",
"project": 'https://github.com/facebookresearch/ParlAI/tree/main/projects/style_gen',
"description": "Classifier trained on Image-Chat turns 2 and 3 to classify the personality of an example given that utterance as the sole context.",
"example": "parlai eval_model --task style_gen:CurrUttOnlyStyle --wrapper-task style_gen:LabeledBlendedSkillTalk --model-file zoo:style_gen/curr_only_classifier/model --model projects.style_gen.classifier:ClassifierAgent --classes-from-file image_chat_personalities_file",
"result": """
16:46:41 | Finished evaluating tasks ['style_gen:CurrUttOnlyStyle'] using datatype valid
accuracy bleu-4 <PER_CLASS_METRICS_SNIPPED> clen ctpb ctps ctrunc ctrunclen exps exs f1 gpu_mem llen loss lr ltpb ltps ltrunc ltrunclen tpb tps \
.4311 .004642 19.18 19.18 425.9 0 0 22.2 5651 .4363 .1586 5.633 2.658 5e-10 5.633 125.1 0 0 24.82 550.9
weighted_f1
.4319
""", # The accuracy is low here because the task was labeled using a different classifier, zoo:style_gen/prev_curr_classifier/model
},
{
"title": "Faster-R-CNN Detectron Features",
"id": "detectron",
Expand Down
26 changes: 26 additions & 0 deletions parlai/zoo/style_gen/curr_only_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Classifier trained on Image-Chat turns 2 and 3 to classify the personality of an example
given only the current utterance as context.
"""

from parlai.core.build_data import download_models


def download(datapath):
model_type = 'curr_only_classifier'
# v1.1 vacuumed the model file to be smaller
version = 'v1.1'
opt = {'datapath': datapath, 'model_type': model_type}
fnames = [f'{version}.tar.gz']
download_models(
opt=opt,
fnames=fnames,
model_folder='style_gen',
version=version,
use_model_type=True,
)
24 changes: 23 additions & 1 deletion tests/nightly/gpu/test_style_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_simple(self):
self.assertEqual(valid['accuracy'], 1.0)
self.assertEqual(test['accuracy'], 1.0)

def test_accuracy(self):
def test_prev_curr_accuracy(self):
"""
Test the accuracy of the classifier trained on previous and current utterances.

Expand All @@ -63,6 +63,28 @@ def test_accuracy(self):
)
self.assertAlmostEqual(test['accuracy'], 1.0, delta=0.0)

def test_curr_only_accuracy(self):
"""
Test the accuracy of the classifier trained on current utterances only.

The accuracy is low here because the task was labeled using a different
classifier, zoo:style_gen/prev_curr_classifier/model.
"""
_, test = testing_utils.eval_model(
opt={
'batchsize': 4,
'fp16': True,
'num_examples': 16,
'model_file': 'zoo:style_gen/curr_only_classifier/model',
'model': 'projects.style_gen.classifier:ClassifierAgent',
'classes_from_file': 'image_chat_personalities_file',
'task': 'style_gen:CurrUttOnlyStyle',
'wrapper_task': 'style_gen:LabeledBlendedSkillTalk',
},
skip_valid=True,
)
self.assertAlmostEqual(test['accuracy'], 0.4375, delta=0.0)


class TestStyleGen(unittest.TestCase):
def test_perplexities(self):
Expand Down