diff --git a/pyproject.toml b/pyproject.toml index ccc5c65b..f97ec59e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dolma" -version = "1.0.5" +version = "1.0.6" description = "Data filters" license = { text = "Apache-2.0" } readme = "README.md" diff --git a/python/dolma/taggers/gopher.py b/python/dolma/taggers/gopher.py index 3417f625..6a5a9c84 100644 --- a/python/dolma/taggers/gopher.py +++ b/python/dolma/taggers/gopher.py @@ -1,4 +1,5 @@ import logging +import re from collections import Counter from dataclasses import dataclass from statistics import median @@ -135,7 +136,7 @@ def as_spans(self) -> List[Span]: return spans -def get_attributes(text: str) -> GopherAttributes: +def get_attributes(text: str, ignore_empty_lines: bool = False) -> GopherAttributes: attrs = GopherAttributes([], []) attrs.character_count = len(text) if attrs.character_count == 0: @@ -173,7 +174,11 @@ def get_attributes(text: str) -> GopherAttributes: ) / max(ng_char_count, 1) attrs.fraction_of_characters_in_duplicate_ngrams.append((n, value)) - lines = text.split("\n") + if ignore_empty_lines: + lines = re.split(r"\n+", text) + else: + lines = text.split("\n") + line_count = len(lines) for line in lines: if any(line.startswith(s) for s in BULLET_POINTS): @@ -218,3 +223,11 @@ def predict(self, doc: Document) -> DocResult: attrs = get_attributes(doc.text) result = DocResult(doc=doc, spans=attrs.as_spans()) return result + + +@TaggerRegistry.add("gopher_v2") +class GopherTaggerV2(GopherTagger): + def predict(self, doc: Document) -> DocResult: + attrs = get_attributes(doc.text, ignore_empty_lines=True) + result = DocResult(doc=doc, spans=attrs.as_spans()) + return result