Skip to content

Commit

Permalink
Adapt test_incremental.py to spaCy v2.2
Browse files Browse the repository at this point in the history
  • Loading branch information
Hiromu Hota committed Jun 2, 2020
1 parent 4df0b5b commit 24bba78
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions tests/e2e/test_incremental.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,16 +119,17 @@ def test_incremental():
featurizer = Featurizer(session, [PartTemp])

featurizer.apply(split=0, train=True, parallelism=PARALLEL)
assert session.query(Feature).count() == 70
assert session.query(FeatureKey).count() == 512
assert session.query(Feature).count() == len(train_cands[0])
assert session.query(FeatureKey).count() == 526
num_feature_keys = session.query(FeatureKey).count()

F_train = featurizer.get_feature_matrices(train_cands)
assert F_train[0].shape == (70, 512)
assert len(featurizer.get_keys()) == 512
assert F_train[0].shape == (len(train_cands[0]), num_feature_keys)
assert len(featurizer.get_keys()) == num_feature_keys

# Test Dropping FeatureKey
featurizer.drop_keys(["CORE_e1_LENGTH_1"])
assert session.query(FeatureKey).count() == 512
assert session.query(FeatureKey).count() == num_feature_keys

stg_temp_lfs = [
LF_storage_row,
Expand All @@ -142,12 +143,12 @@ def test_incremental():
labeler = Labeler(session, [PartTemp])

labeler.apply(split=0, lfs=[stg_temp_lfs], train=True, parallelism=PARALLEL)
assert session.query(Label).count() == 70
assert session.query(Label).count() == len(train_cands[0])

# Only 5 because LF_operating_row doesn't apply to the first test doc
assert session.query(LabelKey).count() == 5
L_train = labeler.get_label_matrices(train_cands)
assert L_train[0].shape == (70, 5)
assert L_train[0].shape == (len(train_cands[0]), 5)
assert len(labeler.get_keys()) == 5

docs_path = "tests/data/html/112823.html"
Expand Down Expand Up @@ -180,21 +181,22 @@ def test_incremental():
# Grab candidate lists
train_cands = candidate_extractor.get_candidates(split=0)
assert len(train_cands) == 1
assert len(train_cands[0]) == 1502
assert len(train_cands[0]) == 1501

# Test if existing candidates are skipped.
candidate_extractor.apply(new_docs, split=0, parallelism=PARALLEL, clear=False)
train_cands = candidate_extractor.get_candidates(split=0)
assert len(train_cands) == 1
assert len(train_cands[0]) == 1502
assert len(train_cands[0]) == 1501

# Update features
featurizer.update(new_docs, parallelism=PARALLEL)
assert session.query(Feature).count() == 1502
assert session.query(FeatureKey).count() == 2573
assert session.query(Feature).count() == len(train_cands[0])
assert session.query(FeatureKey).count() == 2526
num_feature_keys = session.query(FeatureKey).count()
F_train = featurizer.get_feature_matrices(train_cands)
assert F_train[0].shape == (1502, 2573)
assert len(featurizer.get_keys()) == 2573
assert F_train[0].shape == (len(train_cands[0]), num_feature_keys)
assert len(featurizer.get_keys()) == num_feature_keys

# Update LF_storage_row. Now it always returns ABSTAIN.
@labeling_function(name="LF_storage_row")
Expand All @@ -213,11 +215,12 @@ def LF_storage_row_updated(c):
# Update Labels
labeler.update(docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
labeler.update(new_docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
assert session.query(Label).count() == 1502
assert session.query(Label).count() == len(train_cands[0])
# Only 5 because LF_storage_row doesn't apply to any doc (always ABSTAIN)
assert session.query(LabelKey).count() == 5
num_label_keys = session.query(LabelKey).count()
L_train = labeler.get_label_matrices(train_cands)
assert L_train[0].shape == (1502, 5)
assert L_train[0].shape == (len(train_cands[0]), num_label_keys)

# Test clear
featurizer.clear(train=True)
Expand Down

0 comments on commit 24bba78

Please sign in to comment.