convert character spans to token spans
sohamparikh committed Jan 14, 2025
1 parent 07b1622 commit 9367fcd
Showing 2 changed files with 35 additions and 4 deletions.
3 changes: 3 additions & 0 deletions fast_llm/data/preparator/gpt_memmap/config.py
@@ -44,6 +44,9 @@ class GPTHuggingfaceDatasetConfig(Config):
desc="Field of the dataset to use.",
hint=FieldHint.optional,
)
spans_field: None | str = Field(
default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional
)
data_type: DataType | None = Field(
default=None,
desc="Data type of the dataset field."
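For context, a sketch of the kind of record this option expects, assuming a dataset whose text column is named "text" and whose span column is named "spans" (both column names are hypothetical; in practice they are supplied via `field` and `spans_field`). The preparator slices each span as text[start : end + 1], so end indices are inclusive character positions.

# Hypothetical example record; the "text" and "spans" column names are assumptions.
sample = {
    "text": "Question: What is 2+2? Answer: 4",
    # Character spans to mask for loss computation, with inclusive end indices.
    "spans": [[0, 21]],  # covers the prompt "Question: What is 2+2?"
}
start, end = sample["spans"][0]
print(sample["text"][start : end + 1])  # -> "Question: What is 2+2?"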
36 changes: 32 additions & 4 deletions fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -22,14 +22,37 @@ class GPTMemmapDatasetPreparator(DatasetPreparator):
     _tokenizer: Tokenizer
     _data_type: DataType

+    def _tokenize_with_spans(self, text, char_spans):
+        """
+        Perform span-aware tokenization and return the tokenized input_ids along with token spans.
+        Character spans are (start, end) pairs with inclusive end indices; the returned token spans
+        give inclusive (start, end) token indices for the same regions.
+        """
+        input_ids = []
+        token_spans = []
+        char_pos = 0
+        for start, end in char_spans:
+            # Tokenize any text between the previous span and this one.
+            if char_pos < start:
+                curr_text = text[char_pos:start]
+                tokenized_text = self._tokenizer.tokenize(curr_text)
+                input_ids.extend(tokenized_text)
+            # Tokenize the span itself and record which token positions it occupies.
+            curr_text = text[start : end + 1]
+            tokenized_text = self._tokenizer.tokenize(curr_text)
+            token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1))
+            input_ids.extend(tokenized_text)
+            char_pos = end + 1
+        # Tokenize any text remaining after the last span.
+        if char_pos < len(text):
+            curr_text = text[char_pos:]
+            tokenized_text = self._tokenizer.tokenize(curr_text)
+            input_ids.extend(tokenized_text)
+        return np.array(input_ids, dtype=self._data_type.numpy), token_spans
+
     def _tokenize_batch(self, batch):
-        input_ids = [
-            np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy)
-            for text in batch[self._config.dataset.field]
-        ]
+        # `batch` is a dict of columns (as passed by datasets.map with batched=True), so pair up
+        # the text and spans columns rather than iterating over the dict itself.
+        texts = batch[self._config.dataset.field]
+        if self._config.dataset.spans_field is None:
+            # No spans configured: plain tokenization with empty token spans.
+            input_ids = [np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) for text in texts]
+            token_spans = [[] for _ in input_ids]
+        else:
+            input_ids, token_spans = zip(
+                *[
+                    self._tokenize_with_spans(text, char_spans)
+                    for text, char_spans in zip(texts, batch[self._config.dataset.spans_field])
+                ]
+            )
         num_tokens = [len(x) for x in input_ids]
         return {
             "input_ids": input_ids,
+            "token_spans": token_spans,
             "num_tokens": num_tokens,
         }

@@ -126,6 +149,11 @@ def run(self):
         )
         if self._config.dataset.field not in dataset.column_names:
             raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.")
+        if (
+            self._config.dataset.spans_field is not None
+            and self._config.dataset.spans_field not in dataset.column_names
+        ):
+            raise ValueError(f"Dataset does not have spans field '{self._config.dataset.spans_field}'.")

         # Tokenize the dataset in parallel
         tokenized_dataset = dataset.map(
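To illustrate the conversion `_tokenize_with_spans` performs, here is a standalone sketch that substitutes a toy whitespace tokenizer for the real `Tokenizer` (the toy tokenizer and the sample data are assumptions; only the span bookkeeping mirrors the code above). It also shows how the resulting token spans could feed a per-token loss mask, the stated purpose of `spans_field`.

import numpy as np

VOCAB = {}

def toy_tokenize(text):
    # Stand-in for Tokenizer.tokenize: one integer id per whitespace-separated piece.
    return [VOCAB.setdefault(piece, len(VOCAB)) for piece in text.split()]

def tokenize_with_spans(text, char_spans):
    # Same bookkeeping as _tokenize_with_spans above: tokenize the text outside and
    # inside each character span, recording the token index range each span occupies.
    input_ids, token_spans, char_pos = [], [], 0
    for start, end in char_spans:
        if char_pos < start:
            input_ids.extend(toy_tokenize(text[char_pos:start]))
        span_tokens = toy_tokenize(text[start : end + 1])
        token_spans.append((len(input_ids), len(input_ids) + len(span_tokens) - 1))
        input_ids.extend(span_tokens)
        char_pos = end + 1
    if char_pos < len(text):
        input_ids.extend(toy_tokenize(text[char_pos:]))
    return np.array(input_ids), token_spans

ids, spans = tokenize_with_spans("Question: What is 2+2? Answer: 4", [(0, 21)])
print(ids.size, spans)  # 6 [(0, 3)]: the masked span covers token indices 0 through 3

# Downstream, inclusive token spans can be turned into a per-token loss mask.
loss_mask = np.ones(ids.size, dtype=bool)
for s, e in spans:
    loss_mask[s : e + 1] = False  # exclude masked tokens from the loss
print(loss_mask)  # [False False False False  True  True]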
