Skip to content

Commit

Permalink
Add distillation tests with max cut size
Browse files Browse the repository at this point in the history
This commit also fixes an endless loop that occurred when the max cut size is 0 or 1.
  • Loading branch information
danieldk committed Dec 8, 2023
1 parent e2591cd commit 42fe4ed
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
2 changes: 1 addition & 1 deletion spacy/pipeline/transition_parser.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ cdef class Parser(TrainablePipe):
# batch uniform length. Since we do not have a gold standard
# sequence, we use the teacher's predictions as the gold
# standard.
- max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
+ max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
states = self._init_batch(teacher_step_model, student_docs, max_moves)
else:
states = self.moves.init_batch(student_docs)
Expand Down
5 changes: 4 additions & 1 deletion spacy/tests/parser/test_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,9 @@ def test_is_distillable():
assert ner.is_distillable


- def test_distill():
+ @pytest.mark.slow
+ @pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
+ def test_distill(max_moves):
teacher = English()
teacher_ner = teacher.add_pipe("ner")
train_examples = []
Expand All @@ -642,6 +644,7 @@ def test_distill():

student = English()
student_ner = student.add_pipe("ner")
+ student_ner.cfg["update_with_oracle_cut_size"] = max_moves
student_ner.initialize(
get_examples=lambda: train_examples, labels=teacher_ner.label_data
)
Expand Down
5 changes: 4 additions & 1 deletion spacy/tests/parser/test_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,9 @@ def test_is_distillable():
assert parser.is_distillable


- def test_distill():
+ @pytest.mark.slow
+ @pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
+ def test_distill(max_moves):
teacher = English()
teacher_parser = teacher.add_pipe("parser")
train_examples = []
Expand All @@ -420,6 +422,7 @@ def test_distill():

student = English()
student_parser = student.add_pipe("parser")
+ student_parser.cfg["update_with_oracle_cut_size"] = max_moves
student_parser.initialize(
get_examples=lambda: train_examples, labels=teacher_parser.label_data
)
Expand Down

0 comments on commit 42fe4ed

Please sign in to comment.