Commit d7d60df
Fixing a pathological case for slow tokenizers (#14981)
* Fixing a pathological case for slow tokenizers

* Update src/transformers/tokenization_utils.py
Narsil authored Dec 30, 2021
1 parent d1ba56d commit d7d60df
Showing 2 changed files with 14 additions and 2 deletions.
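
The pathological case, as captured by the new test below: one token ("ABC") fully matches during the lookahead while shorter tokens ("B", "CD") start inside it. A minimal reproduction sketch, assuming the Trie class exported by transformers.tokenization_utils:

    from transformers.tokenization_utils import Trie

    trie = Trie()
    trie.add("ABC")
    trie.add("B")
    trie.add("CD")

    # Completing "B" triggers the lookahead, which discovers the longer,
    # earlier match "ABC". With this fix the main loop then skips the
    # characters "ABC" consumed instead of seeding a fresh partial match
    # for "CD" inside them.
    print(trie.split("ABCD"))  # ["ABC", "D"]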
src/transformers/tokenization_utils.py (7 additions, 2 deletions)
@@ -131,7 +131,7 @@ def split(self, text: str) -> List[str]:
         # This is used by the lookahead which needs to skip over
         # some text where the full match exceeded the place in the initial
         # for loop
-        skip = None
+        skip = 0
         # Main loop, Giving this algorithm O(n) complexity
         for current, current_char in enumerate(text):
             if skip and current < skip:
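
A note on the None -> 0 change (my reading, not stated in the commit): the guard added in the last hunk compares current >= skip, and comparing an int against None raises TypeError on Python 3. Since 0 is falsy, the existing `if skip and current < skip` check is unaffected:

    skip = 0
    current = 5
    if skip and current < skip:  # 0 is falsy, so this short-circuits exactly
        pass                     # as it did when skip was None
    current >= skip              # True; with skip = None this comparison
                                 # would raise TypeError on Python 3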
@@ -175,6 +175,11 @@ def split(self, text: str) -> List[str]:
                             lookahead_index = current
                             end = current
                         next_char = text[lookahead_index] if lookahead_index < len(text) else None
+                        if "" in looktrie_pointer:
+                            start = lookstart
+                            end = lookahead_index
+                            skip = lookahead_index
+
                         while next_char in looktrie_pointer:
                             looktrie_pointer = looktrie_pointer[next_char]
                             lookahead_index += 1
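
The added block relies on how this trie marks complete tokens: Trie.add stores tokens as nested dicts terminated by an empty-string key, so `"" in looktrie_pointer` means the lookahead pointer already sits on a complete token. Recording start/end/skip here captures that match before the while loop tries to extend it; previously, as far as the diff shows, a terminal pointer was only noticed after consuming at least one more character, so an already-complete "ABC" could lose out to the shorter "B". A simplified sketch of the storage layout (not the exact implementation):

    # Build the same shape Trie.add produces for "ABC", "B", "CD".
    data = {}
    for token in ("ABC", "B", "CD"):
        ref = data
        for char in token:
            ref = ref.setdefault(char, {})  # descend, creating nodes as needed
        ref[""] = 1  # empty-string key marks a complete token

    print(data)
    # {'A': {'B': {'C': {'': 1}}}, 'B': {'': 1}, 'C': {'D': {'': 1}}}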
@@ -219,7 +224,7 @@ def split(self, text: str) -> List[str]:
 
             # If this character is a starting character within the trie
             # start keeping track of this partial match.
-            if current_char in self.data:
+            if current >= skip and current_char in self.data:
                 states[current] = self.data[current_char]
 
         # We have a cut at the end with states.
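
This guard pairs with the lookahead change above: skip now points past everything the winning match consumed, and `current >= skip` stops the main loop from seeding a new partial match inside that span. A hypothetical trace of trie.split("ABCD") under the fix (my annotation, not from the commit):

    # current=0 ('A'): starts a partial match for "ABC"      -> states[0]
    # current=1 ('B'): advances "ABC"; starts "B"            -> states[1]
    # current=2 ('C'): "B" is complete, triggering the lookahead, which
    #                  finds that states[0] has fully matched "ABC",
    #                  records the cut (0, 3), and sets skip = 3.
    #                  Guard: 2 >= 3 is False, so no partial match for
    #                  "CD" is seeded at index 2.
    # current=3 ('D'): 'D' starts no token; the split is ["ABC", "D"].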
tests/test_tokenization_common.py (7 additions, 0 deletions)
@@ -3687,6 +3687,13 @@ def test_trie_suffix_tokens(self):
         trie.add("C")
         self.assertEqual(trie.split("ABC"), ["AB", "C"])
 
+    def test_trie_skip(self):
+        trie = Trie()
+        trie.add("ABC")
+        trie.add("B")
+        trie.add("CD")
+        self.assertEqual(trie.split("ABCD"), ["ABC", "D"])
+
     def test_cut_text_hardening(self):
         # Even if the offsets are wrong, we necessarily output correct string
         # parts.
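
The regression test can be run on its own, e.g. with pytest:

    python -m pytest tests/test_tokenization_common.py -k test_trie_skip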