
Allow loss masking for defined spans of characters #113

Open

sohamparikh wants to merge 26 commits into base: main from soham/loss-masking-spans

Commits (26)
9367fcd
convert character spans to token spans
sohamparikh Jan 14, 2025
515dcb5
handle null spans
sohamparikh Jan 14, 2025
3457ba2
handle spans in data iterator, fix test
sohamparikh Jan 15, 2025
c7373b9
bump dataset version
sohamparikh Jan 16, 2025
0699e0f
create a document class
sohamparikh Jan 16, 2025
419acd7
make loss masking work for prepare and training
sohamparikh Jan 24, 2025
acad1e4
merge main
sohamparikh Jan 24, 2025
daa2ad7
bos and eos options for tokenizer
sohamparikh Jan 25, 2025
bb175bf
loss masking for triton cross entropy
sohamparikh Jan 27, 2025
0e7ad8b
fix random data tests
sohamparikh Jan 28, 2025
989a8f8
revert precommit versions
sohamparikh Jan 28, 2025
9633f88
fix memmap dataset test
sohamparikh Jan 28, 2025
4f955ff
fix remaining dataset tests
sohamparikh Jan 28, 2025
70e40e8
Merge branch 'main' into soham/loss-masking-spans
sohamparikh Jan 28, 2025
1ac5052
compose tests
sohamparikh Jan 28, 2025
aebb5a0
handle special tokens from config
sohamparikh Jan 28, 2025
d8e3ae1
fix fim to handle bos and eos
sohamparikh Jan 28, 2025
a887dd6
address review comments
sohamparikh Jan 28, 2025
40a80f6
fix memmap tests
sohamparikh Jan 28, 2025
e908303
fix fim tests
sohamparikh Jan 28, 2025
20ffae8
special tokens mode -> sequence delimiters
sohamparikh Jan 29, 2025
753e731
GPTDataBatch -> GPTBatch
sohamparikh Jan 29, 2025
cce0701
GPTMemmapDocument, GPTMemmapSample -> GPTSample
sohamparikh Jan 29, 2025
0583dec
make loss masking opt-in in cross-entropy
sohamparikh Jan 30, 2025
7c40bf2
make spans opt-in during prepare
sohamparikh Jan 30, 2025
1998b9f
make spans opt-in for train
sohamparikh Jan 30, 2025
13 changes: 13 additions & 0 deletions fast_llm/data/config.py
@@ -16,6 +16,14 @@ class MultiprocessingContext(str, enum.Enum):
TokenizerFromFile = "TokenizerFromFile"


class SequenceDelimiters(str, enum.Enum):
tokenizer_default = "tokenizer_default"
bos_only = "bos_only"
eos_only = "eos_only"
bos_eos = "bos_eos"
no_delimiters = "no_delimiters"


@config_class()
class TokenizerConfig(Config):
"""
@@ -34,3 +42,8 @@ class TokenizerConfig(Config):
desc="Path to the tokenizer file.",
hint=FieldHint.core,
)
sequence_delimiters: SequenceDelimiters = Field(
default=SequenceDelimiters.bos_only,
desc="Boundary tokens (bos/eos) to use for tokenizing sequences",
hint=FieldHint.core,
)
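For intuition, a minimal sketch of how these modes could map to bos/eos insertion during tokenization. This helper is hypothetical, not part of the PR; the fim.py changes below suggest the actual Tokenizer.tokenize gained beginning_of_text/end_of_text flags:

# Hypothetical sketch: translate a SequenceDelimiters mode into bos/eos usage.
# `hf_tokenizer` is assumed to be a Hugging Face tokenizer with bos/eos tokens.
def tokenize_with_delimiters(text: str, hf_tokenizer, mode: SequenceDelimiters) -> list[int]:
    if mode == SequenceDelimiters.tokenizer_default:
        return hf_tokenizer.encode(text)  # whatever the tokenizer does natively
    use_bos = mode in (SequenceDelimiters.bos_only, SequenceDelimiters.bos_eos)
    use_eos = mode in (SequenceDelimiters.eos_only, SequenceDelimiters.bos_eos)
    token_ids = hf_tokenizer.encode(text, add_special_tokens=False)
    if use_bos:
        token_ids = [hf_tokenizer.bos_token_id, *token_ids]
    if use_eos:
        token_ids = [*token_ids, hf_tokenizer.eos_token_id]
    return token_ids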
5 changes: 5 additions & 0 deletions fast_llm/data/data/gpt/config.py
@@ -42,6 +42,11 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
desc="Multiprocessing context. Do not touch.",
hint=FieldHint.expert,
)
use_loss_masking_spans: bool = Field(
default=False,
desc="Read and use loss masking spans from the dataset, if present.",
hint=FieldHint.feature,
)

def _validate(self) -> None:
if not self.datasets:
19 changes: 19 additions & 0 deletions fast_llm/data/data/gpt/data.py
@@ -1,8 +1,10 @@
import dataclasses
import logging
import pathlib
import typing
import warnings

import numpy as np
import torch
import torch.utils.data

@@ -11,6 +13,7 @@
from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.gpt.config import GPTSamplingConfig
from fast_llm.data.dataset.gpt.sampled import GPTSample
from fast_llm.data.dataset.monitor import DatasetMonitor
from fast_llm.data.iterator import SampledDatasetIterator
from fast_llm.data.tokenizer import Tokenizer
@@ -23,6 +26,20 @@
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class GPTBatch:
token_ids: torch.Tensor
loss_masking_spans: list[torch.Tensor] | None


def gpt_data_collate_fn(batch: list[GPTSample]) -> GPTBatch:
stacked_ids = np.stack([sample.token_ids for sample in batch])
stacked_spans = None
if batch[0].loss_masking_spans is not None:
stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch]
return GPTBatch(token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans)
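As a usage illustration (not part of the diff): token ids are stacked into a single tensor, while loss-masking spans stay a ragged per-sample list, since documents can carry different numbers of spans. Assuming GPTSample is a dataclass with the two fields used above:

# Toy batch of two samples with different span counts (illustrative values).
batch = [
    GPTSample(token_ids=np.array([5, 6, 7], dtype=np.int64),
              loss_masking_spans=np.array([[0, 1]], dtype=np.int32)),
    GPTSample(token_ids=np.array([8, 9, 10], dtype=np.int64),
              loss_masking_spans=np.array([[0, 0], [2, 2]], dtype=np.int32)),
]
gpt_batch = gpt_data_collate_fn(batch)
# gpt_batch.token_ids: tensor of shape (2, 3)
# gpt_batch.loss_masking_spans: list of tensors of shapes (1, 2) and (2, 2)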


class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]):
"""
A global class for all dataset needs, including loading, splitting, sampling and iteration.
@@ -82,6 +99,7 @@ def setup(
sequence_length=self._max_sequence_length,
vocab_size=self._vocab_size,
tokenizer=self._tokenizer,
use_loss_masking_spans=self._config.use_loss_masking_spans,
)
dataset = self._config.datasets[phase].build_and_sample(sampling_config)
self._datasets[phase] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
@@ -120,6 +138,7 @@ def get_iterator(
num_workers=num_workers,
prefetch_factor=prefetch_factor,
pin_memory=True,
collate_fn=gpt_data_collate_fn,
multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
)
)
10 changes: 8 additions & 2 deletions fast_llm/data/dataset/gpt/config.py
@@ -33,6 +33,7 @@ class GPTSamplingConfig(SamplingConfig):
sequence_length: int
vocab_size: int
tokenizer: "Tokenizer"
use_loss_masking_spans: bool = False


@config_class()
@@ -128,11 +129,16 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig):
desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.",
hint=FieldHint.core,
)
use_loss_masking_spans: bool = Field(
default=False,
desc="Read and use loss masking spans from the dataset, if present.",
hint=FieldHint.feature,
)

def build(self) -> "GPTMemmapDataset":
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset

return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path)
return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.use_loss_masking_spans)
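A hypothetical usage snippet (made-up path; config construction details may differ in Fast-LLM). Since the flag defaults to False, version-2 files behave like plain token datasets unless spans are requested:

config = GPTMemmapDatasetConfig(path="/data/my_dataset", use_loss_masking_spans=True)
dataset = config.build()  # GPTMemmapDataset that also reads spans from the .idx file
sample = dataset.get(0)   # GPTSample with token_ids and loss_masking_spans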


@config_class()
@@ -382,7 +388,7 @@ def build_and_sample(self, config: GPTSamplingConfig) -> SampledDataset:
dataset_configs = [
GPTDatasetSliceConfig(
# TODO: this duplicates memmap datasets for each phase.
dataset=GPTMemmapDatasetConfig(path=prefix),
dataset=GPTMemmapDatasetConfig(path=prefix, use_loss_masking_spans=config.use_loss_masking_spans),
begin=phase_splits[phase_index],
end=phase_splits[phase_index + 1],
)
33 changes: 19 additions & 14 deletions fast_llm/data/dataset/gpt/fim.py
@@ -2,6 +2,7 @@

from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingConfig
from fast_llm.data.dataset.gpt.sampled import GPTSample
from fast_llm.engine.distributed.config import MAX_SEED


@@ -42,12 +43,14 @@ def __getitem__(self, idx: int) -> np.ndarray:
def name(self) -> str:
return f"{self._dataset.name}_fim"

def _fim(self, sample: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray:
def _fim(self, sample: GPTSample, np_rng: np.random.RandomState) -> GPTSample:
# FIM
# TODO: permute segments in sample_list, before concatenating.
sample_len = sample.shape[0]
if self._config.rate > 0.0 and sample.loss_masking_spans is not None:
raise NotImplementedError("FIM is currently not compatible with loss masking.")
sample_len = sample.token_ids.shape[0]
eod = self._tokenizer.eod
segment_breaks = np.argwhere(sample == eod) # split sample by document
segment_breaks = np.argwhere(sample.token_ids == eod) # split sample by document

if segment_breaks.shape != (0, 1): # then there is an EOD token in this example
curr_start_position = 0
@@ -57,26 +60,26 @@ def _fim(self, sample: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray:
# Only permute non-empty segments.
if loc - curr_start_position > 0:
# permute {prefix, suffix, middle} or {suffix, prefix, middle}
permuted = self._fim_split_and_permute_sequence(sample[curr_start_position:loc], np_rng)
permuted = self._fim_split_and_permute_sequence(sample.token_ids[curr_start_position:loc], np_rng)
new_samples += [permuted, [eod]]

curr_start_position = loc + 1 # jump over the EOD token
# Permute the segment after the last EOD
permuted = self._fim_split_and_permute_sequence(sample[curr_start_position:], np_rng)
permuted = self._fim_split_and_permute_sequence(sample.token_ids[curr_start_position:], np_rng)
new_samples.append(permuted)

sample = np.concatenate(new_samples)
sample.token_ids = np.concatenate(new_samples)
else:
sample = self._fim_split_and_permute_sequence(sample, np_rng)
sample.token_ids = self._fim_split_and_permute_sequence(sample.token_ids, np_rng)

# Truncate or pad sequence to max-length
diff = sample.shape[0] - sample_len
diff = sample.token_ids.shape[0] - sample_len
if diff > 0: # too long
sample = sample[:sample_len]
sample.token_ids = sample.token_ids[:sample_len]
elif diff < 0: # too short
sample = np.concatenate([sample, np.full((-1 * diff), self._pad_tok_id)])
sample.token_ids = np.concatenate([sample.token_ids, np.full((-1 * diff), self._pad_tok_id)])

assert sample.shape[0] == sample_len
assert sample.token_ids.shape[0] == sample_len
return sample
Collaborator: Since this code will change the order of tokens in the sequence, we would need to change the masks accordingly to allow FIM with loss masking. At this point, I think we should not, and should instead fail if FIM is used with loss masking.

Member Author: Now throwing an error in this function.


def _fim_split_and_permute_sequence(self, sequence: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray:
@@ -150,9 +153,11 @@ def _fim_permute_sequence(
middle = contents[boundaries[0] : boundaries[1]]
suffix = contents[boundaries[1] :]

prefix = np.array([*self._tokenizer.tokenize(prefix)], dtype=np.int64)
middle = np.array([*self._tokenizer.tokenize(middle)], dtype=np.int64)
suffix = np.array([*self._tokenizer.tokenize(suffix)], dtype=np.int64)
prefix = np.array([*self._tokenizer.tokenize(prefix, end_of_text=False)], dtype=np.int64)
middle = np.array(
[*self._tokenizer.tokenize(middle, beginning_of_text=False, end_of_text=False)], dtype=np.int64
)
suffix = np.array([*self._tokenizer.tokenize(suffix, beginning_of_text=False)], dtype=np.int64)

# here we truncate each given segment to fit the same length as it was before
# A consequence is that we never reach the end of a file?
6 changes: 6 additions & 0 deletions fast_llm/data/dataset/gpt/indexed.py
@@ -36,6 +36,9 @@ def get_document_sizes(self) -> np.ndarray:
# TODO: This can be really big.
return self._dataset.get_document_sizes()[self._begin : self._end]

def get_span_sizes(self) -> np.ndarray:
Collaborator: Is this only used in the tests? If so, I'm not sure it's worth making a public method at this stage. (And it would need to be added to GPTIndexedDataset too.)

return self._dataset.get_span_sizes()[self._begin : self._end]


class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset](
ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset
@@ -45,3 +48,6 @@ class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset](
def get_document_sizes(self) -> np.ndarray:
# TODO: This can be really big.
return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets])

def get_span_sizes(self) -> np.ndarray:
return np.concatenate([dataset.get_span_sizes() for dataset in self._datasets])
80 changes: 68 additions & 12 deletions fast_llm/data/dataset/gpt/memmap.py
@@ -5,6 +5,7 @@
import numpy as np

from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset
from fast_llm.data.dataset.gpt.sampled import GPTSample
from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.utils import Assert, div
@@ -19,17 +20,22 @@ class GPTMemmapDataset(GPTIndexedDataset):
See https://github.com/NVIDIA/Megatron-LM?tab=readme-ov-file#data-preprocessing for more details.
"""

def __init__(self, name: str, prefix: pathlib.Path | str):
self._init(name, prefix)
def __init__(self, name: str, prefix: pathlib.Path | str, use_loss_masking_spans: bool = False):
self._init(name, prefix, use_loss_masking_spans)

def _init(self, name: str, prefix: pathlib.Path | str) -> None:
def _init(self, name: str, prefix: pathlib.Path | str, use_loss_masking_spans: bool = False) -> None:
super().__init__()
self._name = name
self._prefix = pathlib.Path(prefix)
self._read_spans = False

with self._prefix.with_suffix(".idx").open("rb") as stream:
Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER)
Assert.eq(struct.unpack("<Q", stream.read(8))[0], 1)
self._version = struct.unpack("<Q", stream.read(8))[0]
assert self._version in [1, 2], f"Unsupported version for gpt_memmap dataset: {self._version}."
if self._version == 2:
self._has_spans = struct.unpack("<B", stream.read(1))[0]
self._read_spans = use_loss_masking_spans and self._version == 2 and bool(self._has_spans)

self._dtype = MEMMAP_DTYPES[struct.unpack("<B", stream.read(1))[0]].numpy
self._num_documents = struct.unpack("<Q", stream.read(8))[0]
@@ -48,6 +54,16 @@ def _init(self, name: str, prefix: pathlib.Path | str) -> None:
offset=offset + self._document_sizes.nbytes,
)

if self._read_spans:
self._num_spans = np.frombuffer(
self._index_bin_buffer,
dtype=np.int32,
count=self._num_documents,
offset=offset + self._document_sizes.nbytes + self._pointers.nbytes,
)
self._span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + self._num_spans.nbytes
# Exclusive prefix sum: spans for document `idx` start `_num_spans_cumsum[idx]` pairs into the span region.
self._num_spans_cumsum = np.insert(np.cumsum(self._num_spans, dtype=np.int64), 0, 0)

self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
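To make the offset arithmetic concrete, a worked example with illustrative numbers (assuming the exclusive prefix sum above). Each span is an inclusive (start, end) pair of two int32 values:

num_spans = np.array([2, 0, 3], dtype=np.int32)  # spans per document
num_spans_cumsum = np.insert(np.cumsum(num_spans, dtype=np.int64), 0, 0)  # [0, 2, 2, 5]
itemsize = np.dtype(np.int32).itemsize  # 4 bytes
# Document 2's spans start at span_offset + 2 * 2 * itemsize = span_offset + 16 bytes
# and occupy 3 pairs * 2 * itemsize = 24 bytes.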

@@ -65,13 +81,26 @@ def __del__(self):
self._index_bin_buffer_mmap._mmap.close() # noqa
del self._index_bin_buffer_mmap

def get(self, idx, offset=0, length=None) -> np.ndarray:
return np.frombuffer(
def get(self, idx, offset=0, length=None) -> GPTSample:
token_ids = np.frombuffer(
self._bin_buffer,
dtype=self._dtype,
count=self._document_sizes[idx] - offset if length is None else length,
offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize,
)
spans = None
if self._read_spans:
spans = np.frombuffer(
self._index_bin_buffer,
dtype=np.int32,
count=self._num_spans[idx] * 2,
offset=self._span_offset + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize,
).reshape(-1, 2)
# adjust the spans for the offset and length
spans = spans[(spans[:, 0] < offset + len(token_ids)) & (spans[:, 1] >= offset)]
spans[:, 0] = np.maximum(spans[:, 0], offset) - offset
spans[:, 1] = np.minimum(spans[:, 1], offset + len(token_ids) - 1) - offset
return GPTSample(token_ids=token_ids, loss_masking_spans=spans)
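The span adjustment above clips inclusive (start, end) spans to the requested window and re-bases them at zero. A worked example with illustrative values:

spans = np.array([[0, 4], [6, 9], [12, 15]], dtype=np.int32)
offset, length = 5, 6  # request tokens 5..10 of the document
spans = spans[(spans[:, 0] < offset + length) & (spans[:, 1] >= offset)]
spans[:, 0] = np.maximum(spans[:, 0], offset) - offset
spans[:, 1] = np.minimum(spans[:, 1], offset + length - 1) - offset
# Only [6, 9] overlaps the window; it becomes [1, 4] relative to the slice.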

@property
def name(self) -> str:
@@ -92,14 +121,25 @@ def get_document_sizes(self) -> np.ndarray:
"""
return self._document_sizes

def get_span_sizes(self) -> np.ndarray:
"""
The number of spans in each document in the dataset.
The resulting array could be very large, so this method should be called cautiously,
and derived classes should try to avoid holding the whole array in memory.
"""
return self._num_spans

@classmethod
def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[np.ndarray]):
def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]):
# Initialize metadata
dtype = None
num_documents = 0
lengths = []
pointers = []
offset = 0
# number of spans for each document
num_spans = []
spans = []

prefix = pathlib.Path(prefix)
prefix.parent.mkdir(parents=True, exist_ok=True)
@@ -109,30 +149,42 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[np.ndarray]):
for document in documents:
# Infer dtype from the first document
if dtype is None:
dtype = document.dtype
dtype = document.token_ids.dtype
assert dtype is not None, "Document dtype could not be inferred from the data."

# Ensure all documents have the same dtype
assert document.dtype == dtype, f"Expected dtype {dtype}, got {document.dtype}."
assert document.token_ids.dtype == dtype, f"Expected dtype {dtype}, got {document.token_ids.dtype}."

# Write document to binary file
bin_stream.write(document.tobytes(order="C"))
bin_stream.write(document.token_ids.tobytes(order="C"))

# Update metadata
doc_length = len(document)
doc_length = len(document.token_ids)
lengths.append(doc_length)
pointers.append(offset)
if document.loss_masking_spans is not None:
num_spans.append(len(document.loss_masking_spans))
spans.append(document.loss_masking_spans)
offset += doc_length * np.dtype(dtype).itemsize
num_documents += 1

# Finalize metadata arrays
lengths = np.array(lengths, dtype=np.int32)
pointers = np.array(pointers, dtype=np.int64)
num_spans = np.array(num_spans, dtype=np.int32)
if len(spans) > 0:
spans = np.vstack(spans, dtype=np.int32)
else:
spans = np.array(spans, dtype=np.int32)

# Write the index file (.idx)
with prefix.with_suffix(".idx").open("wb") as idx_stream:
idx_stream.write(MEMMAP_INDEX_HEADER)
idx_stream.write(struct.pack("<Q", 1)) # Version
# Indicates the version
# Version 2 optionally adds loss-masking spans
idx_stream.write(struct.pack("<Q", 2))
# Flag to indicate whether loss-masking spans are present
idx_stream.write(struct.pack("<B", 1 if spans.size > 0 else 0))
# Data type
idx_stream.write(struct.pack("<B", MEMMAP_DTYPES_INV[DataType.from_numpy(dtype.type)]))
# "Number of sequences", same as documents in our case
@@ -143,5 +195,9 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[np.ndarray]):
idx_stream.write(lengths.tobytes(order="C"))
# Sequence (document) begin offsets in the bin file
idx_stream.write(pointers.tobytes(order="C"))
# Number of spans per document
idx_stream.write(num_spans.tobytes(order="C"))
# Span indices for each document
idx_stream.write(spans.tobytes(order="C"))
# Document indices, unused but needed for compatibility with Megatron-LM
idx_stream.write(np.arange(num_documents + 1, dtype=np.int64).tobytes(order="C"))
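For reference, the version-2 .idx layout implied by the reader and writer above (a reconstruction; the hunks elided from the diff may write additional counts between the dtype code and the arrays):

# MEMMAP_INDEX_HEADER        9 bytes
# version                    uint64 LE, = 2
# has_spans flag             uint8, 1 if any spans were written
# dtype code                 uint8 (MEMMAP_DTYPES_INV)
# number of documents        uint64 LE
# document lengths           int32[num_documents]
# document pointers          int64[num_documents]
# spans per document         int32[num_documents]
# span (start, end) pairs    int32[total_spans * 2]
# document indices           int64[num_documents + 1], unused (Megatron-LM compatibility)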