diff --git a/parlai/tasks/style_gen/agents.py b/parlai/tasks/style_gen/agents.py
index 9c8eb799768..5e8d10aaa1a 100644
--- a/parlai/tasks/style_gen/agents.py
+++ b/parlai/tasks/style_gen/agents.py
@@ -124,3 +124,51 @@ def _edit_action(self, act: Message) -> Message:
             assert 'text' not in act and act['episode_done'] is True
         act.force_set('episode_done', True)  # Clear the dialogue history
         return act
+
+
+class CurrUttOnlyStyleTeacher(AbstractWrapperTeacher):
+    """
+    Serving examples for use with projects.style_gen.classifier:ClassifierAgent.
+
+    This teacher will replace message['text'] with message['labels'][0], and it will
+    replace message['labels'] with [message['personality']]. This is to allow for
+    training/evaluation of projects.style_gen.classifier:ClassifierAgent in the case
+    where we want to classify the style of an utterance given only that utterance as
+    context.
+
+    Because the dialogue history is effectively overwritten by this action, all episodes
+    will be flattened into one example each.
+    """
+
+    def _edit_action(self, act: Message) -> Message:
+        """
+        Edit the fields of the action manually.
+
+        If the act has labels, its single label becomes the new 'text' and the act's
+        'personality' becomes the new sole label. 'episode_done' is always forced to
+        True so that no dialogue history carries over between examples.
+
+        :param act: the message to edit; it is modified in place via force_set.
+        :return: the same message object after editing.
+        :raises ValueError: if the act does not have exactly one label, or if that
+            label contains a newline.
+        """
+        if 'labels' in act:
+            labels = act['labels']
+            if len(labels) != 1:
+                raise ValueError(
+                    f'{type(self).__name__} can only be used with one label!'
+                )
+            if '\n' in labels[0]:
+                # Classifier will not expect any newlines in the context. Raise
+                # explicitly rather than assert so the check survives `python -O`.
+                raise ValueError(
+                    f'{type(self).__name__} cannot be used with multi-line labels!'
+                )
+            act.force_set('text', labels[0])
+            act.force_set('labels', [act['personality']])
+        else:
+            # An act without labels is expected to be an empty, episode-ending act
+            assert 'text' not in act and act['episode_done'] is True
+        act.force_set('episode_done', True)  # Clear the dialogue history
+        return act
diff --git a/parlai/zoo/model_list.py b/parlai/zoo/model_list.py
index 238c01962af..356f3043a2e 100644
--- a/parlai/zoo/model_list.py
+++ b/parlai/zoo/model_list.py
@@ -1417,6 +1417,23 @@
 accuracy bleu-4 ctpb ctps exps exs f1 gpu_mem loss lr ltpb ltps tpb tps
 .9973 .01745 38.08 604.1 15.86 5651 .9973 .1622 2.129 5e-10 5.633 89.36 43.71 693.4""",
     },
+    {
+        "title": "Style-controlled generation: current-utterance-only classifier",
+        "id": "style_gen",
+        "path": "zoo:style_gen/curr_only_classifier/model",
+        "agent": "projects.style_gen.classifier:ClassifierAgent",
+        "task": "style_gen:LabeledBlendedSkillTalk",
+        "project": 'https://github.com/facebookresearch/ParlAI/tree/main/projects/style_gen',
+        "description": "Classifier trained on Image-Chat turns 2 and 3 to classify the personality of an example given that utterance as the sole context.",
+        "example": "parlai eval_model --task style_gen:CurrUttOnlyStyle --wrapper-task style_gen:LabeledBlendedSkillTalk --model-file zoo:style_gen/curr_only_classifier/model --model projects.style_gen.classifier:ClassifierAgent --classes-from-file image_chat_personalities_file",
+        "result": """
+16:46:41 | Finished evaluating tasks ['style_gen:CurrUttOnlyStyle'] using datatype valid
+ accuracy bleu-4 clen ctpb ctps ctrunc ctrunclen exps exs f1 gpu_mem llen loss lr ltpb ltps ltrunc ltrunclen tpb tps \
+ .4311 .004642 19.18 19.18 425.9 0 0 22.2 5651 .4363 .1586 5.633 2.658 5e-10 5.633 125.1 0 0 24.82 550.9
+ weighted_f1
+ .4319
+""",  # The accuracy is low here because the task was labeled using a different classifier, zoo:style_gen/prev_curr_classifier/model
+    },
     {
         "title": "Faster-R-CNN Detectron Features",
         "id": "detectron",
diff --git a/parlai/zoo/style_gen/curr_only_classifier.py b/parlai/zoo/style_gen/curr_only_classifier.py
new file mode 100644
index 00000000000..04a0c7ab19e
--- /dev/null
+++ b/parlai/zoo/style_gen/curr_only_classifier.py
@@ -0,0 +1,32 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Classifier trained on Image-Chat turns 2 and 3 to classify the personality of an example
+given only the current utterance as context.
+"""
+
+from parlai.core.build_data import download_models
+
+
+def download(datapath):
+    """
+    Fetch the v1.1 archive of this model by delegating to download_models.
+
+    :param datapath: stored under the 'datapath' key of the opt passed to
+        download_models.
+    """
+    model_type = 'curr_only_classifier'
+    # v1.1 vacuumed the model file to be smaller
+    version = 'v1.1'
+    opt = {'datapath': datapath, 'model_type': model_type}
+    fnames = [f'{version}.tar.gz']
+    download_models(
+        opt=opt,
+        fnames=fnames,
+        model_folder='style_gen',
+        version=version,
+        use_model_type=True,
+    )
diff --git a/tests/nightly/gpu/test_style_gen.py b/tests/nightly/gpu/test_style_gen.py
index 4f45b544efd..b7dacde90be 100644
--- a/tests/nightly/gpu/test_style_gen.py
+++ b/tests/nightly/gpu/test_style_gen.py
@@ -41,7 +41,7 @@ def test_simple(self):
         self.assertEqual(valid['accuracy'], 1.0)
         self.assertEqual(test['accuracy'], 1.0)

-    def test_accuracy(self):
+    def test_prev_curr_accuracy(self):
         """
         Test the accuracy of the classifier trained on previous and current utterances.

@@ -63,6 +63,28 @@ def test_accuracy(self):
         )
         self.assertAlmostEqual(test['accuracy'], 1.0, delta=0.0)

+    def test_curr_only_accuracy(self):
+        """
+        Test the accuracy of the classifier trained on current utterances only.
+
+        The accuracy is low here because the task was labeled using a different
+        classifier, zoo:style_gen/prev_curr_classifier/model.
+        """
+        _, test = testing_utils.eval_model(
+            opt={
+                'batchsize': 4,
+                'fp16': True,
+                'num_examples': 16,
+                'model_file': 'zoo:style_gen/curr_only_classifier/model',
+                'model': 'projects.style_gen.classifier:ClassifierAgent',
+                'classes_from_file': 'image_chat_personalities_file',
+                'task': 'style_gen:CurrUttOnlyStyle',
+                'wrapper_task': 'style_gen:LabeledBlendedSkillTalk',
+            },
+            skip_valid=True,
+        )
+        self.assertAlmostEqual(test['accuracy'], 0.4375, delta=0.0)
+

 class TestStyleGen(unittest.TestCase):
     def test_perplexities(self):