Skip to content

Commit

Permalink
Add nullables to candidate_subclass() (fix HazyResearch#496)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hiromu Hota committed Sep 1, 2020
1 parent afe12b0 commit 2ce00a9
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 8 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ Added
* `@wajdikhattel`_: Add multinary candidates.
(`#455 <https://github.com/HazyResearch/fonduer/issues/455>`_)
(`#456 <https://github.com/HazyResearch/fonduer/pull/456>`_)
* `@HiromuHota`_: Add ``nullables`` to :func:`candidate_subclass()` to allow NULL mention in a candidate.
(`#496 <https://github.com/HazyResearch/fonduer/issues/496>`_)
(`#497 <https://github.com/HazyResearch/fonduer/pull/497>`_)

Changed
^^^^^^^
Expand Down
15 changes: 10 additions & 5 deletions src/fonduer/candidates/candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,11 @@ def apply( # type: ignore
enumerate(
# a list of mentions for each mention subclass within a doc
getattr(doc, mention.__tablename__ + "s")
+ ([None] if nullable else [])
)
for mention, nullable in zip(
candidate_class.mentions, candidate_class.nullables
)
for mention in candidate_class.mentions
]
)
# Get a set of stable_ids of candidates.
Expand All @@ -286,15 +289,16 @@ def apply( # type: ignore

# TODO: Make this work for higher-order relations
if self.arities[i] == 2:
ai, a = (cand[0][0], cand[0][1].context)
bi, b = (cand[1][0], cand[1][1].context)
ai, a = (cand[0][0], cand[0][1].context if cand[0][1] else None)
bi, b = (cand[1][0], cand[1][1].context if cand[1][1] else None)

# Check for self-joins, "nested" joins (joins from context to
# its subcontext), and flipped duplicate "symmetric" relations
if not self.self_relations and a == b:
logger.debug(f"Skipping self-joined candidate {cand}")
continue
if not self.nested_relations and (a in b or b in a):
# Skip the check if either is None as None is not iterable.
if not self.nested_relations and (a and b) and (a in b or b in a):
logger.debug(f"Skipping nested candidate {cand}")
continue
if not self.symmetric_relations and ai > bi:
Expand All @@ -306,7 +310,8 @@ def apply( # type: ignore
candidate_args[arg_name] = cand[j][1]

stable_ids = tuple(
cand[j][1].context.get_stable_id() for j in range(self.arities[i])
cand[j][1].context.get_stable_id() if cand[j][1] else None
for j in range(self.arities[i])
)
# Skip if this (temporary) candidate is used by this candidate class.
if (
Expand Down
18 changes: 16 additions & 2 deletions src/fonduer/candidates/models/candidate.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def candidate_subclass(
table_name: Optional[str] = None,
cardinality: Optional[int] = None,
values: Optional[List[Any]] = None,
nullables: Optional[List[bool]] = None,
) -> Type[Candidate]:
"""Create new relation.
Expand All @@ -95,6 +96,10 @@ def candidate_subclass(
:param cardinality: The cardinality of the variable corresponding to the
Candidate. By default is 2 i.e. is a binary value, e.g. is or is not
a true mention.
:param values: A list of values a candidate can take as their label.
:param nullables: The number of nullables must match that of args.
If nullables[i]==True, a mention for ith mention subclass can be NULL.
If nullables=``None`` (by default), no mention can be NULL.
"""
if table_name is None:
table_name = camel_to_under(class_name)
Expand Down Expand Up @@ -124,6 +129,12 @@ def candidate_subclass(
elif cardinality is not None:
values = list(range(cardinality))

if nullables:
if len(nullables) != len(args):
raise ValueError("The number of nullables must match that of args.")
else:
nullables = [False] * len(args)

class_spec = (args, table_name, cardinality, values)
if class_name in candidate_subclasses:
if class_spec == candidate_subclasses[class_name][1]:
Expand Down Expand Up @@ -153,6 +164,7 @@ def candidate_subclass(
# Helper method to get argument names
"__argnames__": [_.__tablename__ for _ in args],
"mentions": args,
"nullables": nullables,
}
class_attribs["document_id"] = Column(
Integer, ForeignKey("document.id", ondelete="CASCADE")
Expand All @@ -166,10 +178,12 @@ def candidate_subclass(
# Create named arguments, i.e. the entity mentions comprising the
# relation mention.
unique_args = []
for arg in args:
for arg, nullable in zip(args, nullables):
# Primary arguments are constituent Contexts, and their ids
class_attribs[arg.__tablename__ + "_id"] = Column(
Integer, ForeignKey(arg.__tablename__ + ".id", ondelete="CASCADE")
Integer,
ForeignKey(arg.__tablename__ + ".id", ondelete="CASCADE"),
nullable=nullable,
)
class_attribs[arg.__tablename__] = relationship(
arg.__name__,
Expand Down
2 changes: 1 addition & 1 deletion src/fonduer/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_set_of_stable_ids(
set_of_stable_ids.update(
set(
[
tuple(m.context.get_stable_id() for m in c)
tuple(m.context.get_stable_id() for m in c) if c else None
for c in getattr(doc, candidate_class.__tablename__ + "s")
]
)
Expand Down
31 changes: 31 additions & 0 deletions tests/candidates/test_candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,3 +541,34 @@ def test_pickle_subclasses():
pickle.loads(pickle.dumps(part))
pickle.loads(pickle.dumps(temp))
pickle.loads(pickle.dumps(parttemp))


def test_candidate_with_nullable_mentions():
"""Test if mentions can be NULL."""
docs_path = "tests/data/html/112823.html"
pdf_path = "tests/data/pdf/112823.pdf"
doc = parse_doc(docs_path, "112823", pdf_path)

# Mention Extraction
MentionTemp = mention_subclass("MentionTemp")
temp_ngrams = MentionNgramsTemp(n_max=2)
mention_extractor_udf = MentionExtractorUDF(
[MentionTemp],
[temp_ngrams],
[temp_matcher],
)
doc = mention_extractor_udf.apply(doc)

assert len(doc.mention_temps) == 23

# Candidate Extraction
CandidateTemp = candidate_subclass("CandidateTemp", [MentionTemp], nullables=[True])
candidate_extractor_udf = CandidateExtractorUDF(
[CandidateTemp], [None], False, False, True
)

doc = candidate_extractor_udf.apply(doc, split=0)
# The number of extracted candidates should be that of mentions + 1 (NULL)
assert len(doc.candidate_temps) == len(doc.mention_temps) + 1
# Extracted candidates should include one with NULL mention.
assert None in [c[0] for c in doc.candidate_temps]

0 comments on commit 2ce00a9

Please sign in to comment.