Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Speed up TestWizardModel (#3604)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenroller authored Apr 19, 2021
1 parent d83cd22 commit 334faae
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions tests/nightly/gpu/test_wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
'batchsize': 4,
'log_every_n_secs': 30,
'embedding_type': 'random',
'num_examples': 128,
}


Expand All @@ -35,6 +36,7 @@
'delimiter': ' __SOC__ ',
'n_positions': 1000,
'legacy': True,
'num_examples': 128,
}


Expand All @@ -44,27 +46,25 @@ class TestWizardModel(unittest.TestCase):
Checks that pre-trained Wizard models give the correct results.
"""

@classmethod
def setUpClass(cls):
# go ahead and download things here
parser = display_data.setup_args()
parser.set_defaults(**END2END_OPTIONS)
opt = parser.parse_args([])
opt['num_examples'] = 1
opt['verbose'] = True
display_data.display_data(opt)

def test_end2end(self):
valid, _ = testing_utils.eval_model(END2END_OPTIONS, skip_test=True)
self.assertAlmostEqual(valid['ppl'], 61.21, places=2)
self.assertAlmostEqual(valid['f1'], 0.1717, places=4)
self.assertAlmostEqual(valid['know_acc'], 0.2201, places=4)
# For full dataset, remove `num_examples` from the END2END_OPTIONS
# self.assertAlmostEqual(valid['ppl'], 61.21, places=2)
# self.assertAlmostEqual(valid['f1'], 0.1717, places=4)
# self.assertAlmostEqual(valid['know_acc'], 0.2201, places=4)
self.assertAlmostEqual(valid['ppl'], 71.49, places=2)
self.assertAlmostEqual(valid['f1'], 0.1741, places=4)
self.assertAlmostEqual(valid['know_acc'], 0.1797, places=4)

def test_retrieval(self):
_, test = testing_utils.eval_model(RETRIEVAL_OPTIONS, skip_valid=True)
self.assertAlmostEqual(test['accuracy'], 0.8631, places=4)
self.assertAlmostEqual(test['hits@5'], 0.9814, places=4)
self.assertAlmostEqual(test['hits@10'], 0.9917, places=4)
# for full dataset, remove `num_examples` from END2END_OPTIONS
# self.assertAlmostEqual(test['accuracy'], 0.8631, places=4)
# self.assertAlmostEqual(test['hits@5'], 0.9814, places=4)
# self.assertAlmostEqual(test['hits@10'], 0.9917, places=4)
self.assertAlmostEqual(test['accuracy'], 0.9141, places=4)
self.assertAlmostEqual(test['hits@5'], 1.0, places=4)
self.assertAlmostEqual(test['hits@10'], 1.0, places=4)


class TestKnowledgeRetriever(unittest.TestCase):
Expand Down

0 comments on commit 334faae

Please sign in to comment.