Dataset tests
jlamypoirier committed Jan 17, 2025
1 parent 5ba311c commit 0b184d3
Showing 11 changed files with 341 additions and 61 deletions.
3 changes: 2 additions & 1 deletion fast_llm/data/config.py
@@ -1,4 +1,5 @@
import enum
import pathlib

from fast_llm.config import Config, Field, FieldHint, check_field, config_class
from fast_llm.utils import Assert
@@ -28,7 +29,7 @@ class TokenizerConfig(Config):
hint=FieldHint.deprecated,
valid=check_field(Assert.eq, TokenizerFromFile),
)
path: str | None = Field(
path: pathlib.Path | None = Field(
default=None,
desc="Path to the tokenizer file.",
hint=FieldHint.core,
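Typing the field as pathlib.Path instead of a raw string lets callers use path methods on the tokenizer location directly. A minimal sketch of the idea with a plain dataclass (TokenizerConfigSketch is a hypothetical stand-in, not the Field-based config above):

import pathlib
from dataclasses import dataclass

@dataclass
class TokenizerConfigSketch:
    # Stand-in for the Field-based TokenizerConfig in the diff above.
    path: pathlib.Path | None = None

config = TokenizerConfigSketch(path=pathlib.Path("tokenizer.json"))
if config.path is not None and config.path.is_file():
    data = config.path.read_bytes()  # no str-to-Path conversion needed at the call site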
4 changes: 2 additions & 2 deletions fast_llm/data/data/gpt/data.py
@@ -14,7 +14,7 @@
from fast_llm.data.dataset.blended import BlendedDataset
from fast_llm.data.dataset.gpt.config import DatasetSource, GPTSamplingConfig
from fast_llm.data.dataset.gpt.dummy import GPTDummyDataset
from fast_llm.data.dataset.gpt.fim import FimDataset
from fast_llm.data.dataset.gpt.fim import GPTFimDataset
from fast_llm.data.dataset.gpt.indexed import GPTDatasetSlice
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
from fast_llm.data.dataset.monitor import DatasetMonitor
@@ -245,7 +245,7 @@ def _build_and_sample_gpt_dataset(self, name: str, sampling_configs: PhaseSplits
datasets = SampledSplitDataset[GPTDatasetSlice](
"fim",
{
phase: FimDataset(self.config.fim, dataset, sampling_configs[phase])
phase: GPTFimDataset(self.config.fim, dataset, sampling_configs[phase])
for phase, dataset in datasets.items()
},
)
20 changes: 20 additions & 0 deletions fast_llm/data/dataset/gpt/config.py
@@ -68,6 +68,26 @@ class FimConfig(Config):
desc="TODO.",
hint=FieldHint.feature,
)
prefix_token: str = Field(
default="<fim_prefix>",
desc="TODO.",
hint=FieldHint.feature,
)
middle_token: str = Field(
default="<fim_middle>",
desc="TODO.",
hint=FieldHint.feature,
)
pad_token: str = Field(
default="<fim_pad>",
desc="TODO.",
hint=FieldHint.feature,
)
suffix_token: str = Field(
default="<fim_suffix>",
desc="TODO.",
hint=FieldHint.feature,
)

def _validate(self):
super()._validate()
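The four new fields make the FIM sentinel tokens configurable rather than relying on the hardcoded FIM_* constants still visible in fim.py below. A rough sketch of how PSM-style (prefix, suffix, middle) fill-in-the-middle formatting uses these tokens; the function and cut points are illustrative, not the GPTFimDataset implementation:

def fim_psm_format(
    text: str,
    cut_a: int,
    cut_b: int,
    prefix_token: str = "<fim_prefix>",
    middle_token: str = "<fim_middle>",
    suffix_token: str = "<fim_suffix>",
) -> str:
    # Split the document at two cut points, then reorder so the model is
    # trained to generate the middle span after seeing prefix and suffix.
    prefix, middle, suffix = text[:cut_a], text[cut_a:cut_b], text[cut_b:]
    return f"{prefix_token}{prefix}{suffix_token}{suffix}{middle_token}{middle}"

example = fim_psm_format("def add(a, b):\n    return a + b\n", 15, 27)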
2 changes: 1 addition & 1 deletion fast_llm/data/dataset/gpt/dummy.py
@@ -29,7 +29,7 @@ def __len__(self) -> int:
return self._config.num_samples

def __getitem__(self, idx) -> np.ndarray:
return np.random.RandomState(self._config.seed + 4857643).randint(
return np.random.RandomState(self._config.seed + 48576439 + 74593 * idx).randint(
0, self._config.vocab_size, size=(self._config.sequence_length + 1,), dtype=np.int64
)

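Folding the sample index into the RNG seed makes __getitem__ deterministic per index while producing different tokens for different indices; the previous version seeded the generator identically on every call, so every sample was the same. A small sketch of the pattern (constants arbitrary):

import numpy as np

def dummy_sample(seed: int, idx: int, vocab_size: int, sequence_length: int) -> np.ndarray:
    # Deriving the RNG state from (seed, idx) gives reproducible, per-index samples
    # regardless of access order or number of worker processes.
    rng = np.random.RandomState(seed + 48576439 + 74593 * idx)
    return rng.randint(0, vocab_size, size=(sequence_length + 1,), dtype=np.int64)

assert (dummy_sample(0, 3, 100, 8) == dummy_sample(0, 3, 100, 8)).all()
assert (dummy_sample(0, 3, 100, 8) != dummy_sample(0, 4, 100, 8)).any()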
2 changes: 1 addition & 1 deletion fast_llm/data/dataset/gpt/fim.py
@@ -10,7 +10,7 @@
FIM_SUFFIX = "<fim_suffix>"


class FimDataset(SampledDataset):
class GPTFimDataset(SampledDataset):
"""
An implementation of FIM (fill in the middle) post-processing of GPT datasets.
Adapted from https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py
6 changes: 4 additions & 2 deletions fast_llm/data/dataset/gpt/indexed.py
@@ -28,7 +28,7 @@ def sample(self, config: GPTSamplingConfig) -> "GPTSampledIndexedDataset":
return GPTSampledIndexedDataset(self, config)


class GPTDatasetSlice(DatasetSlice, GPTIndexedDataset):
class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset):
"""
A GPT dataset, which reads samples from (a split of) a `MMapIndexedDataset` pointing to a GPT dataset.
"""
@@ -56,7 +56,9 @@ def from_splits(cls, dataset: GPTIndexedDataset, phase_split: dict[PhaseType, fl
)


class GPTConcatenatedDataset(ConcatenatedDataset, GPTIndexedDataset):
class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset](
ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset
):
_datasets: list[GPTIndexedDataset]

def get_document_sizes(self) -> np.ndarray:
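GPTDatasetSlice and GPTConcatenatedDataset now use Python 3.12 type-parameter syntax (PEP 695) with an upper bound, so the wrapped dataset type is preserved for type checkers. A standalone sketch of the same pattern with hypothetical class names:

class IndexedDatasetSketch:
    def get(self, index: int) -> list[int]:
        raise NotImplementedError

class DatasetSliceSketch[DatasetType: IndexedDatasetSketch]:
    # The bound restricts DatasetType to IndexedDatasetSketch subclasses, and
    # attributes annotated with DatasetType keep the concrete subclass interface.
    _dataset: DatasetType

    def __init__(self, dataset: DatasetType, begin: int, end: int):
        self._dataset = dataset
        self._begin, self._end = begin, end

    def get(self, index: int) -> list[int]:
        return self._dataset.get(self._begin + index)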
107 changes: 69 additions & 38 deletions fast_llm/data/dataset/gpt/sampled.py
@@ -1,3 +1,4 @@
import logging
import math
import pathlib
import typing
@@ -18,6 +19,8 @@
except ImportError:
_extension_available = False

logger = logging.getLogger(__name__)


class GPTSampledIndexedDataset(SampledDataset):
"""
@@ -35,33 +38,41 @@ def __init__(
assert isinstance(sampling_config, GPTSamplingConfig)
self._indexed_dataset = indexed_dataset

cache_prefix = (
f"{self.name}_ns_{sampling_config.num_samples}_sl_{sampling_config.sequence_length}"
f"_s_{sampling_config.seed}"
)
# TODO: Any way to combine into a single file? (Memmap is harder)
self._doc_idx_filename = sampling_config.cache_directory / (cache_prefix + "_doc_idx.npy")
self._sample_idx_filename = sampling_config.cache_directory / (cache_prefix + "_sample_idx.npy")
self._shuffle_idx_filename = sampling_config.cache_directory / (cache_prefix + "_shuffle_idx.npy")

group = sampling_config.distributed.world_group
# Build the indexed mapping if it doesn't exist.
# TODO: This only works if the dataset location is accessible by all job.
if (group is None or group.rank() == 0) and not (
self._doc_idx_filename.is_file()
and self._sample_idx_filename.is_file()
and self._shuffle_idx_filename.is_file()
):
if sampling_config.verbose:

if sampling_config.cache_directory is None:
log_main_rank(
" > No dataset cache directory provided, building the index map on all ranks."
"This may be very inefficient...",
log_fn=logger.warning,
)
self._doc_idx, self._sample_idx, self._shuffle_idx = self._sample(sampling_config)
else:
cache_prefix = (
f"{self.name}_ns_{sampling_config.num_samples}_sl_{sampling_config.sequence_length}"
f"_s_{sampling_config.seed}"
)
# TODO: Any way to combine into a single file? (Memmap is harder)
self._doc_idx_filename = sampling_config.cache_directory / (cache_prefix + "_doc_idx.npy")
self._sample_idx_filename = sampling_config.cache_directory / (cache_prefix + "_sample_idx.npy")
self._shuffle_idx_filename = sampling_config.cache_directory / (cache_prefix + "_shuffle_idx.npy")

# Build the indexed mapping if it doesn't exist.
# TODO: This only works if the dataset location is accessible by all job.
if (group is None or group.rank() == 0) and not (
self._doc_idx_filename.is_file()
and self._sample_idx_filename.is_file()
and self._shuffle_idx_filename.is_file()
):
log_main_rank(" > Building the index map on rank 0 ...")
doc_idx, sample_idx, shuffle_idx = self._sample(sampling_config)
sampling_config.cache_directory.mkdir(parents=True, exist_ok=True)
np.save(self._doc_idx_filename, doc_idx)
np.save(self._sample_idx_filename, sample_idx)
np.save(self._shuffle_idx_filename, shuffle_idx)
doc_idx, sample_idx, shuffle_idx = self._sample(sampling_config)
sampling_config.cache_directory.mkdir(parents=True, exist_ok=True)
np.save(self._doc_idx_filename, doc_idx)
np.save(self._sample_idx_filename, sample_idx)
np.save(self._shuffle_idx_filename, shuffle_idx)

safe_barrier(group, self._indexed_dataset.name)
self._load_mappings(sampling_config.verbose)
self._load_mappings(True)

def _sample(self, sampling_config: GPTSamplingConfig) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
@@ -100,7 +111,7 @@ def _sample(self, sampling_config: GPTSamplingConfig) -> tuple[np.ndarray, np.nd
sampling_config.sequence_length,
num_epochs,
num_tokens,
sampling_config.verbose,
True,
)

# shuffle-idx.
@@ -121,24 +132,44 @@ def _sample(self, sampling_config: GPTSamplingConfig) -> tuple[np.ndarray, np.nd
# TODO: The doc and sample idx are way bigger than needed when sampling for << 1 epoch.
return doc_idx, sample_idx, shuffle_idx[: sampling_config.num_samples]

def __getstate__(self) -> tuple[GPTIndexedDataset, pathlib.Path, pathlib.Path, pathlib.Path]:
return (
self._indexed_dataset,
self._doc_idx_filename,
self._sample_idx_filename,
self._shuffle_idx_filename,
)
def __getstate__(
self,
) -> tuple[GPTIndexedDataset, pathlib.Path | np.ndarray, pathlib.Path | np.ndarray, pathlib.Path | np.ndarray]:
if hasattr(self, "_doc_idx_filename"):
return (
self._indexed_dataset,
self._doc_idx,
self._sample_idx_filename,
self._shuffle_idx_filename,
)
else:
return (
self._indexed_dataset,
self._doc_idx,
self._sample_idx,
self._shuffle_idx,
)

def __setstate__(self, state: tuple[GPTIndexedDataset, pathlib.Path, pathlib.Path, pathlib.Path]) -> None:
(
self._indexed_dataset,
self._doc_idx_filename,
self._sample_idx_filename,
self._shuffle_idx_filename,
) = state
if isinstance(state[1], pathlib.Path):
(
self._indexed_dataset,
self._doc_idx_filename,
self._sample_idx_filename,
self._shuffle_idx_filename,
) = state
else:
(
self._indexed_dataset,
self._doc_idx,
self._sample_idx,
self._shuffle_idx,
) = state
self._load_mappings(False)

def _load_mappings(self, verbose: bool) -> None:
if hasattr(self, "_doc_idx"):
return
if verbose:
log_main_rank(lambda: f" > loading doc-idx mapping from {self._doc_idx_filename}")
self._doc_idx = np.load(self._doc_idx_filename, mmap_mode="r")
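The __getstate__/__setstate__ pair keeps pickling cheap when the dataset is shipped to DataLoader worker processes: if a cache exists, only the .npy paths are serialized and each worker re-opens the memory maps in _load_mappings (the isinstance(state[1], pathlib.Path) check in __setstate__ suggests the cached branch of __getstate__ is intended to return self._doc_idx_filename rather than self._doc_idx). A minimal sketch of the technique, not the class above:

import pathlib
import numpy as np

class MemmapBacked:
    def __init__(self, filename: pathlib.Path):
        self._filename = filename
        self._data = np.load(filename, mmap_mode="r")

    def __getstate__(self) -> pathlib.Path:
        # Serialize only the path; the array itself stays on disk.
        return self._filename

    def __setstate__(self, filename: pathlib.Path) -> None:
        # Each unpickling process (e.g. a DataLoader worker) re-opens the memory map.
        self.__init__(filename)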
@@ -169,7 +200,7 @@ def __getitem__(self, idx: int) -> typing.Any:
doc_l, offset_l = self._sample_idx[shuffled_idx + 1]
sample_list = [
self._indexed_dataset.get(
self._doc_idx[doc],
self._doc_idx[doc].item(),
offset=(doc == doc_f) * offset_f,
length=offset_l + 1 - (doc == doc_f) * offset_f if doc == doc_l else None,
)
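Indexing the doc-index array yields a numpy scalar; .item() converts it to a plain Python int before it is handed to get(), presumably to keep downstream offset arithmetic and interfaces free of numpy scalar types. For example:

import numpy as np

doc_idx = np.arange(10, dtype=np.int64)
assert isinstance(doc_idx[3], np.integer)   # numpy scalar
assert isinstance(doc_idx[3].item(), int)   # plain Python int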
2 changes: 1 addition & 1 deletion fast_llm/data/dataset/indexed.py
@@ -24,7 +24,7 @@ def __len__(self) -> int:
"""


class DatasetSlice(IndexedDataset):
class DatasetSlice[IndexedDatasetType: IndexedDataset](IndexedDataset):

def __init__(
self,
10 changes: 9 additions & 1 deletion fast_llm/utils.py
@@ -144,6 +144,10 @@ def rms_close(x, y, threshold):
def all_equal(x, y):
import torch

# Make it work for numpy arrays.
x = torch.as_tensor(x)
y = torch.as_tensor(y)

neq = x != y
if neq.any().item(): # noqa
index = torch.where(neq) # noqa
@@ -156,9 +160,13 @@ def all_equal(x, y):
def all_different(x, y):
import torch

# Make it work for numpy arrays.
x = torch.as_tensor(x)
y = torch.as_tensor(y)

eq = x == y
if eq.any().item(): # noqa
index = torch.where(eq) # noqa
index = torch.where(torch.as_tensor(eq)) # noqa
raise AssertionError(
f"Tensors have {index[0].numel()} unexpected matching entries out of "
f"{x.numel()}: {x[index]} != {y[index]} at index {torch.stack(index, -1)}"
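Wrapping the inputs with torch.as_tensor lets the assertion helpers accept numpy arrays (and plain lists) as well as torch tensors, without copying when the input is already a tensor. A short usage sketch:

import numpy as np
import torch

x = np.array([1, 2, 3])
y = torch.tensor([1, 2, 3])

# torch.as_tensor shares memory with the numpy array where possible, so the
# element-wise comparison works across both container types.
assert not (torch.as_tensor(x) != torch.as_tensor(y)).any().item()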