Skip to content

Commit

Permalink
Merge pull request #5 from allenai/kylel/test-taggers
Browse files Browse the repository at this point in the history
Kylel/test taggers
  • Loading branch information
soldni authored Jul 3, 2023
2 parents 79ad7c0 + 702e5e2 commit 36eceb0
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 3 deletions.
86 changes: 86 additions & 0 deletions tests/python/test_taggers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""
Unit tests for taggers/*.py
@kylel
"""

from unittest import TestCase

from dolma.core.data_types import DocResult, Document, Span
from dolma.taggers import GopherTagger


class TestGopherTagger(TestCase):
def test_predict_short(self):
tagger = GopherTagger()
doc = Document(source="", version="", id="", text="This is a test.")
doc_result = tagger.predict(doc=doc)
d = doc_result.to_json()
self.assertEqual(len(d["spans"]), 13)
self.assertEqual(d["spans"][0], {"start": 0, "end": 15, "type": "fraction_of_characters_in_most_common_2grams", "score": 0.5, "mention": "This is a test."})
self.assertEqual(d["spans"][1], {"start": 0, "end": 15, "type": "fraction_of_characters_in_most_common_3grams", "score": 0.5833333333333334, "mention": "This is a test."})
self.assertEqual(d["spans"][2], {"start": 0, "end": 15, "type": "fraction_of_characters_in_most_common_4grams", "score": 1.0, "mention": "This is a test."})
self.assertEqual(d["spans"][3], {"start": 0, "end": 15, "type": "character_count", "score": 15.0, "mention": "This is a test."})
self.assertEqual(d["spans"][4], {"start": 0, "end": 15, "type": "word_count", "score": 4.0, "mention": "This is a test."})
self.assertEqual(d["spans"][5], {"start": 0, "end": 15, "type": "median_word_length", "score": 3.0, "mention": "This is a test."})
self.assertEqual(d["spans"][6], {"start": 0, "end": 15, "type": "symbol_to_word_ratio", "score": 0.0, "mention": "This is a test."})
self.assertEqual(d["spans"][7], {"start": 0, "end": 15, "type": "fraction_of_words_with_alpha_character", "score": 1.0, "mention": "This is a test."})
self.assertEqual(d["spans"][8], {"start": 0, "end": 15, "type": "required_word_count", "score": 0.0, "mention": "This is a test."})
self.assertEqual(d["spans"][9], {"start": 0, "end": 15, "type": "fraction_of_lines_starting_with_bullet_point", "score": 0.0, "mention": "This is a test."})
self.assertEqual(d["spans"][10], {"start": 0, "end": 15, "type": "fraction_of_lines_ending_with_ellipsis", "score": 0.0, "mention": "This is a test."})
self.assertEqual(d["spans"][11], {"start": 0, "end": 15, "type": "fraction_of_duplicate_lines", "score": 0.0, "mention": "This is a test."})
self.assertEqual(d["spans"][12], {"start": 0, "end": 15, "type": "fraction_of_characters_in_duplicate_lines", "score": 0.0, "mention": "This is a test."})

def test_predict_multiline(self):
tagger = GopherTagger()
text = "This is a sentence. \n \n This is another sentence.\n\n This is a third sentence."
doc = Document(source="", version="", id="", text=text)
doc_result = tagger.predict(doc=doc)
d = doc_result.to_json()
self.assertEqual(len(d['spans']), 19)
self.assertEqual(d['spans'][0], {'start': 0, 'end': 79, 'type': 'fraction_of_characters_in_most_common_2grams', 'score': 0.3050847457627119, 'mention': text})
self.assertEqual(d['spans'][1], {'start': 0, 'end': 79, 'type': 'fraction_of_characters_in_most_common_3grams', 'score': 0.23728813559322035, 'mention': text})
self.assertEqual(d['spans'][2], {'start': 0, 'end': 79, 'type': 'fraction_of_characters_in_most_common_4grams', 'score': 0.2711864406779661, 'mention': text})
self.assertEqual(d['spans'][3], {'start': 0, 'end': 79, 'type': 'fraction_of_characters_in_duplicate_5grams', 'score': 0.0, 'mention': text})
self.assertEqual(d['spans'][4], {'start': 0, 'end': 79, 'type': 'fraction_of_characters_in_duplicate_6grams', 'score': 0.0, 'mention': text})
self.assertEqual(d['spans'][5], {'start': 0, 'end': 79, 'type': 'fraction_of_characters_in_duplicate_7grams', 'score': 0.0, 'mention': text})
self.assertEqual(d['spans'][6], {'start': 0, 'end': 79, 'type': 'fraction_of_characters_in_duplicate_8grams', 'score': 0.0, 'mention': text})
self.assertEqual(d['spans'][7], {'start': 0, 'end': 79, 'type': 'fraction_of_characters_in_duplicate_9grams', 'score': 0.0, 'mention': text})
self.assertEqual(d['spans'][8], {'start': 0, 'end': 79, 'type': 'fraction_of_characters_in_duplicate_10grams', 'score': 0.0, 'mention': text})
self.assertEqual(d["spans"][9], {"start": 0, "end": 15, "type": "character_count", "score": 79.0, "mention": text})
self.assertEqual(d["spans"][10], {"start": 0, "end": 15, "type": "word_count", "score": 13.0, "mention": text})
self.assertEqual(d["spans"][11], {"start": 0, "end": 15, "type": "median_word_length", "score": 4.0, "mention": text})
self.assertEqual(d["spans"][12], {"start": 0, "end": 15, "type": "symbol_to_word_ratio", "score": 0.0, "mention": text})
self.assertEqual(d["spans"][13], {"start": 0, "end": 15, "type": "fraction_of_words_with_alpha_character", "score": 1.0, "mention": text})
self.assertEqual(d["spans"][14], {"start": 0, "end": 15, "type": "required_word_count", "score": 0.0, "mention": text})
self.assertEqual(d["spans"][15], {"start": 0, "end": 15, "type": "fraction_of_lines_starting_with_bullet_point", "score": 0.0, "mention": text})
self.assertEqual(d["spans"][16], {"start": 0, "end": 15, "type": "fraction_of_lines_ending_with_ellipsis", "score": 0.0, "mention": text})
self.assertEqual(d["spans"][17], {"start": 0, "end": 15, "type": "fraction_of_duplicate_lines", "score": 0.0, "mention": text})
self.assertEqual(d["spans"][18], {"start": 0, "end": 15, "type": "fraction_of_characters_in_duplicate_lines", "score": 0.0, "mention": text})

def test_word_count_is_whitespace_sep(self):
tagger = GopherTagger()
text = "T h i s \n \n\n\n isoneword !!!"
doc = Document(source="", version="", id="", text=text)
doc_result = tagger.predict(doc=doc)
d = doc_result.to_json()
self.assertEqual(d['spans'][6]['type'], 'word_count')
self.assertEqual(d['spans'][6]['score'], 6.0)

def test_required_word_count(self):
tagger = GopherTagger()
text = "The.and.that"
doc = Document(source="", version="", id="", text=text)
doc_result = tagger.predict(doc=doc)
d = doc_result.to_json()
self.assertEqual(d['spans'][5]['type'], 'required_word_count')
self.assertEqual(d['spans'][5]['score'], 0.0)

text = "The and that"
doc = Document(source="", version="", id="", text=text)
doc_result = tagger.predict(doc=doc)
d = doc_result.to_json()
self.assertEqual(d['spans'][7]['type'], 'required_word_count')
self.assertEqual(d['spans'][7]['score'], 2.0)
3 changes: 0 additions & 3 deletions tests/python/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,6 @@ def test_split_sentences(self):
self.assertEqual(text[sentences[0].start : sentences[0].end], sentences[0].text)
self.assertEqual(sentences[1].text, "This is another sentence.")
self.assertEqual(text[sentences[1].start : sentences[1].end], sentences[1].text)

sentences2 = split_sentences(text=text, remove_empty=False)
self.assertListEqual([s.text for s in sentences], [s.text for s in sentences2])

def test_split_sentences_empty(self):
text = ""
Expand Down

0 comments on commit 36eceb0

Please sign in to comment.