From 0b184d3409bfe577c058c89512cbe62e30ae4716 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Fri, 17 Jan 2025 15:47:03 -0500
Subject: [PATCH] Dataset tests

---
 fast_llm/data/config.py              |   3 +-
 fast_llm/data/data/gpt/data.py       |   4 +-
 fast_llm/data/dataset/gpt/config.py  |  20 +++
 fast_llm/data/dataset/gpt/dummy.py   |   2 +-
 fast_llm/data/dataset/gpt/fim.py     |   2 +-
 fast_llm/data/dataset/gpt/indexed.py |   6 +-
 fast_llm/data/dataset/gpt/sampled.py | 107 +++++++++------
 fast_llm/data/dataset/indexed.py     |   2 +-
 fast_llm/utils.py                    |  10 +-
 tests/common.py                      |  51 +++++--
 tests/test_dataset.py                | 195 +++++++++++++++++++++++++++
 11 files changed, 341 insertions(+), 61 deletions(-)
 create mode 100644 tests/test_dataset.py

diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py
index 32675749..1586d370 100644
--- a/fast_llm/data/config.py
+++ b/fast_llm/data/config.py
@@ -1,4 +1,5 @@
 import enum
+import pathlib
 
 from fast_llm.config import Config, Field, FieldHint, check_field, config_class
 from fast_llm.utils import Assert
@@ -28,7 +29,7 @@ class TokenizerConfig(Config):
         hint=FieldHint.deprecated,
         valid=check_field(Assert.eq, TokenizerFromFile),
     )
-    path: str | None = Field(
+    path: pathlib.Path | None = Field(
         default=None,
         desc="Path to the tokenizer file.",
         hint=FieldHint.core,
diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py
index bf6ec573..6759c759 100644
--- a/fast_llm/data/data/gpt/data.py
+++ b/fast_llm/data/data/gpt/data.py
@@ -14,7 +14,7 @@
 from fast_llm.data.dataset.blended import BlendedDataset
 from fast_llm.data.dataset.gpt.config import DatasetSource, GPTSamplingConfig
 from fast_llm.data.dataset.gpt.dummy import GPTDummyDataset
-from fast_llm.data.dataset.gpt.fim import FimDataset
+from fast_llm.data.dataset.gpt.fim import GPTFimDataset
 from fast_llm.data.dataset.gpt.indexed import GPTDatasetSlice
 from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
 from fast_llm.data.dataset.monitor import DatasetMonitor
@@ -245,7 +245,7 @@ def _build_and_sample_gpt_dataset(self, name: str, sampling_configs: PhaseSplits
             datasets = SampledSplitDataset[GPTDatasetSlice](
                 "fim",
                 {
-                    phase: FimDataset(self.config.fim, dataset, sampling_configs[phase])
+                    phase: GPTFimDataset(self.config.fim, dataset, sampling_configs[phase])
                     for phase, dataset in datasets.items()
                 },
             )
diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py
index ff15e55d..4743c809 100644
--- a/fast_llm/data/dataset/gpt/config.py
+++ b/fast_llm/data/dataset/gpt/config.py
@@ -68,6 +68,26 @@ class FimConfig(Config):
         desc="TODO.",
         hint=FieldHint.feature,
     )
+    prefix_token: str = Field(
+        default="<fim_prefix>",
+        desc="TODO.",
+        hint=FieldHint.feature,
+    )
+    middle_token: str = Field(
+        default="<fim_middle>",
+        desc="TODO.",
+        hint=FieldHint.feature,
+    )
+    pad_token: str = Field(
+        default="<fim_pad>",
+        desc="TODO.",
+        hint=FieldHint.feature,
+    )
+    suffix_token: str = Field(
+        default="<fim_suffix>",
+        desc="TODO.",
+        hint=FieldHint.feature,
+    )
 
     def _validate(self):
         super()._validate()
diff --git a/fast_llm/data/dataset/gpt/dummy.py b/fast_llm/data/dataset/gpt/dummy.py
index 2aa868f8..a637ef93 100644
--- a/fast_llm/data/dataset/gpt/dummy.py
+++ b/fast_llm/data/dataset/gpt/dummy.py
@@ -29,7 +29,7 @@ def __len__(self) -> int:
         return self._config.num_samples
 
     def __getitem__(self, idx) -> np.ndarray:
-        return np.random.RandomState(self._config.seed + 4857643).randint(
+        return np.random.RandomState(self._config.seed + 48576439 + 74593 * idx).randint(
             0, self._config.vocab_size, size=(self._config.sequence_length + 1,), dtype=np.int64
         )
diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py
index 0ed76d80..323953c8 100644
--- a/fast_llm/data/dataset/gpt/fim.py
+++ b/fast_llm/data/dataset/gpt/fim.py
@@ -10,7 +10,7 @@
 FIM_SUFFIX = "<fim_suffix>"
 
 
-class FimDataset(SampledDataset):
+class GPTFimDataset(SampledDataset):
     """
     An implementation of FIM (fill in the middle) post-processing of GPT datasets.
     Adapted from https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py
diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py
index 8a0951a8..11c2dae2 100644
--- a/fast_llm/data/dataset/gpt/indexed.py
+++ b/fast_llm/data/dataset/gpt/indexed.py
@@ -28,7 +28,7 @@ def sample(self, config: GPTSamplingConfig) -> "GPTSampledIndexedDataset":
         return GPTSampledIndexedDataset(self, config)
 
 
-class GPTDatasetSlice(DatasetSlice, GPTIndexedDataset):
+class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset):
     """
     A GPT dataset, which reads samples from (a split of) a `MMapIndexedDataset` pointing to a GPT dataset.
     """
@@ -56,7 +56,9 @@ def from_splits(cls, dataset: GPTIndexedDataset, phase_split: dict[PhaseType, fl
         )
 
 
-class GPTConcatenatedDataset(ConcatenatedDataset, GPTIndexedDataset):
+class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset](
+    ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset
+):
     _datasets: list[GPTIndexedDataset]
 
     def get_document_sizes(self) -> np.ndarray:
diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py
index 943873a9..94544529 100644
--- a/fast_llm/data/dataset/gpt/sampled.py
+++ b/fast_llm/data/dataset/gpt/sampled.py
@@ -1,3 +1,4 @@
+import logging
 import math
 import pathlib
 import typing
@@ -18,6 +19,8 @@
 except ImportError:
     _extension_available = False
 
+logger = logging.getLogger(__name__)
+
 
 class GPTSampledIndexedDataset(SampledDataset):
     """
@@ -35,33 +38,41 @@ def __init__(
         assert isinstance(sampling_config, GPTSamplingConfig)
         self._indexed_dataset = indexed_dataset
 
-        cache_prefix = (
-            f"{self.name}_ns_{sampling_config.num_samples}_sl_{sampling_config.sequence_length}"
-            f"_s_{sampling_config.seed}"
-        )
-        # TODO: Any way to combine into a single file? (Memmap is harder)
-        self._doc_idx_filename = sampling_config.cache_directory / (cache_prefix + "_doc_idx.npy")
-        self._sample_idx_filename = sampling_config.cache_directory / (cache_prefix + "_sample_idx.npy")
-        self._shuffle_idx_filename = sampling_config.cache_directory / (cache_prefix + "_shuffle_idx.npy")
         group = sampling_config.distributed.world_group
-        # Build the indexed mapping if it doesn't exist.
-        # TODO: This only works if the dataset location is accessible by all job.
-        if (group is None or group.rank() == 0) and not (
-            self._doc_idx_filename.is_file()
-            and self._sample_idx_filename.is_file()
-            and self._shuffle_idx_filename.is_file()
-        ):
-            if sampling_config.verbose:
+
+        if sampling_config.cache_directory is None:
+            log_main_rank(
+                " > No dataset cache directory provided, building the index map on all ranks."
+                "This may be very inefficient...",
+                log_fn=logger.warning,
+            )
+            self._doc_idx, self._sample_idx, self._shuffle_idx = self._sample(sampling_config)
+        else:
+            cache_prefix = (
+                f"{self.name}_ns_{sampling_config.num_samples}_sl_{sampling_config.sequence_length}"
+                f"_s_{sampling_config.seed}"
+            )
+            # TODO: Any way to combine into a single file? (Memmap is harder)
+            self._doc_idx_filename = sampling_config.cache_directory / (cache_prefix + "_doc_idx.npy")
+            self._sample_idx_filename = sampling_config.cache_directory / (cache_prefix + "_sample_idx.npy")
+            self._shuffle_idx_filename = sampling_config.cache_directory / (cache_prefix + "_shuffle_idx.npy")
+
+            # Build the indexed mapping if it doesn't exist.
+            # TODO: This only works if the dataset location is accessible by all job.
+            if (group is None or group.rank() == 0) and not (
+                self._doc_idx_filename.is_file()
+                and self._sample_idx_filename.is_file()
+                and self._shuffle_idx_filename.is_file()
+            ):
                 log_main_rank(" > Building the index map on rank 0 ...")
-            doc_idx, sample_idx, shuffle_idx = self._sample(sampling_config)
-            sampling_config.cache_directory.mkdir(parents=True, exist_ok=True)
-            np.save(self._doc_idx_filename, doc_idx)
-            np.save(self._sample_idx_filename, sample_idx)
-            np.save(self._shuffle_idx_filename, shuffle_idx)
+                doc_idx, sample_idx, shuffle_idx = self._sample(sampling_config)
+                sampling_config.cache_directory.mkdir(parents=True, exist_ok=True)
+                np.save(self._doc_idx_filename, doc_idx)
+                np.save(self._sample_idx_filename, sample_idx)
+                np.save(self._shuffle_idx_filename, shuffle_idx)
 
         safe_barrier(group, self._indexed_dataset.name)
-        self._load_mappings(sampling_config.verbose)
+        self._load_mappings(True)
 
     def _sample(self, sampling_config: GPTSamplingConfig) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """
@@ -100,7 +111,7 @@ def _sample(self, sampling_config: GPTSamplingConfig) -> tuple[np.ndarray, np.nd
             sampling_config.sequence_length,
             num_epochs,
             num_tokens,
-            sampling_config.verbose,
+            True,
         )
 
         # shuffle-idx.
@@ -121,24 +132,44 @@ def _sample(self, sampling_config: GPTSamplingConfig) -> tuple[np.ndarray, np.nd
         # TODO: The doc and sample idx are way bigger than needed when sampling for << 1 epoch.
         return doc_idx, sample_idx, shuffle_idx[: sampling_config.num_samples]
 
-    def __getstate__(self) -> tuple[GPTIndexedDataset, pathlib.Path, pathlib.Path, pathlib.Path]:
-        return (
-            self._indexed_dataset,
-            self._doc_idx_filename,
-            self._sample_idx_filename,
-            self._shuffle_idx_filename,
-        )
+    def __getstate__(
+        self,
+    ) -> tuple[GPTIndexedDataset, pathlib.Path | np.ndarray, pathlib.Path | np.ndarray, pathlib.Path | np.ndarray]:
+        if hasattr(self, "_doc_idx_filename"):
+            return (
+                self._indexed_dataset,
+                self._doc_idx_filename,
+                self._sample_idx_filename,
+                self._shuffle_idx_filename,
+            )
+        else:
+            return (
+                self._indexed_dataset,
+                self._doc_idx,
+                self._sample_idx,
+                self._shuffle_idx,
+            )
 
     def __setstate__(self, state: tuple[GPTIndexedDataset, pathlib.Path, pathlib.Path, pathlib.Path]) -> None:
-        (
-            self._indexed_dataset,
-            self._doc_idx_filename,
-            self._sample_idx_filename,
-            self._shuffle_idx_filename,
-        ) = state
+        if isinstance(state[1], pathlib.Path):
+            (
+                self._indexed_dataset,
+                self._doc_idx_filename,
+                self._sample_idx_filename,
+                self._shuffle_idx_filename,
+            ) = state
+        else:
+            (
+                self._indexed_dataset,
+                self._doc_idx,
+                self._sample_idx,
+                self._shuffle_idx,
+            ) = state
         self._load_mappings(False)
 
     def _load_mappings(self, verbose: bool) -> None:
+        if hasattr(self, "_doc_idx"):
+            return
         if verbose:
             log_main_rank(lambda: f" > loading doc-idx mapping from {self._doc_idx_filename}")
         self._doc_idx = np.load(self._doc_idx_filename, mmap_mode="r")
@@ -169,7 +200,7 @@ def __getitem__(self, idx: int) -> typing.Any:
         doc_l, offset_l = self._sample_idx[shuffled_idx + 1]
         sample_list = [
             self._indexed_dataset.get(
-                self._doc_idx[doc],
+                self._doc_idx[doc].item(),
                 offset=(doc == doc_f) * offset_f,
                 length=offset_l + 1 - (doc == doc_f) * offset_f if doc == doc_l else None,
             )
diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py
index b9226724..8a652dda 100644
--- a/fast_llm/data/dataset/indexed.py
+++ b/fast_llm/data/dataset/indexed.py
@@ -24,7 +24,7 @@ def __len__(self) -> int:
         """
 
 
-class DatasetSlice(IndexedDataset):
+class DatasetSlice[IndexedDatasetType: IndexedDataset](IndexedDataset):
 
     def __init__(
         self,
diff --git a/fast_llm/utils.py b/fast_llm/utils.py
index 44e2d586..31d4c93d 100644
--- a/fast_llm/utils.py
+++ b/fast_llm/utils.py
@@ -144,6 +144,10 @@ def rms_close(x, y, threshold):
     def all_equal(x, y):
         import torch
 
+        # Make it work for numpy arrays.
+        x = torch.as_tensor(x)
+        y = torch.as_tensor(y)
+
         neq = x != y
         if neq.any().item():  # noqa
             index = torch.where(neq)  # noqa
@@ -156,9 +160,13 @@ def all_different(x, y):
         import torch
 
+        # Make it work for numpy arrays.
+        x = torch.as_tensor(x)
+        y = torch.as_tensor(y)
+
         eq = x == y
         if eq.any().item():  # noqa
-            index = torch.where(eq)  # noqa
+            index = torch.where(torch.as_tensor(eq))  # noqa
             raise AssertionError(
                 f"Tensors have {index[0].numel()} unexpected matching entries out of "
                 f"{x.numel()}: {x[index]} != {y[index]} at index {torch.stack(index, -1)}"
diff --git a/tests/common.py b/tests/common.py
index 6c9d11d7..9494fe14 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -17,7 +17,6 @@
     MixtralGPTHuggingfaceCheckpointFormat,
     Starcoder2GPTHuggingfaceCheckpointFormat,
 )
-from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM
 from fast_llm.tools.train import CliTrainingConfig
 from tests.compare_tensor_logs import CompareConfig, compare_tensor_logs
 
@@ -34,10 +33,14 @@
 
 ARTIFACT_PATH = "runs/0/artifacts"
 
-TOKENIZER_PATH = TEST_RESULTS_PATH / "data" / "tokenizer"
+TOKENIZER_PATH = TEST_RESULTS_PATH / "tokenizer" / "common"
 TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json"
 
-DATASET_PREFIX = TEST_RESULTS_PATH / "data" / "dataset/data"
+DATASET_PREFIX = TEST_RESULTS_PATH / "dataset" / "common"
+TEST_VOCAB_SIZE = 8192
+# Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6%
+TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n"
+TEST_DATASET_TOKENS = 1000000
 
 CONFIG_BASE_FAST_LLM = [
     "training.logs.interval=1",
@@ -47,7 +50,7 @@
     "model.base_model.transformer.hidden_size=256",
     "model.base_model.transformer.num_attention_heads=8",
    "model.base_model.transformer.init_method_std=0.022",
-    "model.base_model.vocab_size=8192",
+    f"model.base_model.vocab_size={TEST_VOCAB_SIZE}",
     f"model.multi_stage.debug_param_init={_LOG_LEVEL}",
     f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}",
     f"model.multi_stage.debug_layer_gradients={_LOG_LEVEL}",
@@ -58,7 +61,20 @@
     "training.num_workers=0",
     "batch.batch_size=8",
     "batch.sequence_length=512",
-    f"data.path={DATASET_PREFIX}",
+    "data.datasets.Training.type=slice",
+    "data.datasets.Training.end=0.969",
+    "data.datasets.Training.dataset.type=memmap",
+    f"data.datasets.Training.dataset.path={DATASET_PREFIX}",
+    "data.datasets.Validation.type=slice",
+    "data.datasets.Validation.begin=0.969",
+    "data.datasets.Validation.end=0.999",
+    "data.datasets.Validation.dataset.type=memmap",
+    f"data.datasets.Validation.dataset.path={DATASET_PREFIX}",
+    "data.datasets.Test.type=slice",
+    "data.datasets.Test.begin=0.999",
+    "data.datasets.Test.end=1",
+    "data.datasets.Test.dataset.type=memmap",
+    f"data.datasets.Test.dataset.path={DATASET_PREFIX}",
     "optimizer.learning_rate.base=0.0001",
 ]
 CONFIG_BASE_MEGATRON = [
@@ -84,7 +100,7 @@
     "--valid-num-workers=0",
     "--tokenizer-type=NullTokenizer",
     # Megatron messes with the vocab size, so we have to subtract 1.
- "--vocab-size=8191", + f"--vocab-size={TEST_VOCAB_SIZE-1}", f"--data-path={DATASET_PREFIX}", "--lr-decay-style=constant", # Initialization is set up to match MCore models (MCore inverts self-attn qkv and dense layers compared to original Megatron) @@ -148,7 +164,7 @@ _CONFIGS = { "gpt2": ("gpt", CONFIG_GPT2_FAST_LLM, CONFIG_GPT2_MEGATRON, CONFIG_GPT2_COMMON, None), - "sc1": ("gpt", HuggingfaceGPTModelForCausalLM, CONFIG_SC1_FAST_LLM, CONFIG_SC1_MEGATRON, CONFIG_SC1_COMMON, None), + "sc1": ("gpt", CONFIG_SC1_FAST_LLM, CONFIG_SC1_MEGATRON, CONFIG_SC1_COMMON, None), "starcoder2": ( "gpt", CONFIG_SC2_FAST_LLM, @@ -193,21 +209,28 @@ requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") -def get_test_data(): +def get_test_dataset( + prefix=DATASET_PREFIX, + seed=1234, + num_tokens=TEST_DATASET_TOKENS, + characters=TEST_CHARACTERS, + vocab_size=TEST_VOCAB_SIZE, +): if not TOKENIZER_FILE.is_file(): import transformers transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) - if not (DATASET_PREFIX.with_suffix(".idx").is_file() and DATASET_PREFIX.with_suffix(".bin").is_file()): + if not (prefix.with_suffix(".idx").is_file() and prefix.with_suffix(".bin").is_file()): import transformers - characters = (string.ascii_lowercase) * 5 + " " * 30 + "\n" - documents = "".join(random.Random(1234).choices(characters, k=1000000)).splitlines() + documents = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) - documents = [np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % 8192 for document in documents] - GPTMemmapDataset.write_dataset(DATASET_PREFIX, documents) + documents = [ + np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size for document in documents + ] + GPTMemmapDataset.write_dataset(prefix, documents) def run_test_script( @@ -264,7 +287,7 @@ def run_test_script( if skip: print("Reusing existing run.") else: - get_test_data() + get_test_dataset() if num_gpus == 1 and not is_megatron: CliTrainingConfig.parse_and_run(script) else: diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 00000000..18219dca --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,195 @@ +import pathlib + +import numpy +import numpy as np +import pytest + +from fast_llm.data.config import TokenizerConfig +from fast_llm.data.dataset.blended import BlendedDataset +from fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingConfig +from fast_llm.data.dataset.gpt.dummy import GPTDummyDataset +from fast_llm.data.dataset.gpt.fim import GPTFimDataset +from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset +from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.tokenizer import Tokenizer +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.utils import Assert +from tests.common import DATASET_PREFIX, TEST_RESULTS_PATH, TEST_VOCAB_SIZE, TOKENIZER_PATH, get_test_dataset + +DATASET_CACHE = TEST_RESULTS_PATH / "dataset" / "cache" + + +def get_sampling_config( + num_samples: int, + *, + seed: int = 95733, + cache_directory: pathlib.Path | None = None, + distributed: Distributed = Distributed(DistributedConfig(), use_cpu=True), + sequence_length: int = 512, + vocab_size=TEST_VOCAB_SIZE, + tokenizer: Tokenizer | None = None, +) 
+    # Config with convenient defaults.
+    return GPTSamplingConfig(
+        num_samples=num_samples,
+        seed=seed,
+        cache_directory=cache_directory,
+        distributed=distributed,
+        sequence_length=sequence_length,
+        vocab_size=vocab_size,
+        tokenizer=tokenizer,
+        verbose=True,
+    )
+
+
+def test_gpt_dummy_dataset():
+    # Make sure the dummy dataset works and check for unintended changes in behavior.
+    sampled = GPTDummyDataset("dummy").sample(get_sampling_config(4, sequence_length=7))
+    Assert.eq(len(sampled), 4)
+    Assert.all_equal(
+        numpy.stack([sampled[i] for i in range(4)]),
+        np.array(
+            [
+                [3954, 4105, 6766, 859, 5494, 1675, 1303, 6913],
+                [1654, 5701, 32, 1662, 7053, 3487, 1861, 1502],
+                [5409, 6240, 5504, 7458, 7667, 3955, 3151, 3912],
+                [5640, 6131, 7750, 2699, 1349, 2585, 7113, 6981],
+            ]
+        ),
+    )
+
+
+# Most documents are too long to write here, we test a few known short ones.
+_MEMMAP_DATASET_EXPECTED_LENGTH = 6153
+_MEMMAP_DATASET_EXPECTED_TOKENS = 508327
+_MEMMAP_DATASET_EXPECTED_SAMPLES = {
+    9: [],
+    10: [80, 85, 4295, 4182, 489, 727, 84, 698, 1197, 583],
+    13: [78, 727, 74, 317, 1358, 89],
+    15: [78],
+}
+
+
+@pytest.mark.parametrize("cache_directory", (None, pathlib.Path(DATASET_CACHE) / "test_memmap"))
+def test_gpt_memmap(cache_directory):
+    # Make sure the memmap dataset works and check for unintended changes in behavior.
+    get_test_dataset()
+    dataset = GPTMemmapDataset("memmap", DATASET_PREFIX)
+    Assert.eq(len(dataset), _MEMMAP_DATASET_EXPECTED_LENGTH)
+    sizes = dataset.get_document_sizes()
+    Assert.eq(sizes.sum(), _MEMMAP_DATASET_EXPECTED_TOKENS)
+    Assert.all_equal([len(dataset.get(i)) for i in range(100)], sizes[:100])
+    for i, sample in _MEMMAP_DATASET_EXPECTED_SAMPLES.items():
+        Assert.all_equal(dataset.get(i), np.array(sample, dtype=numpy.uint16))
+
+
+def test_gpt_concatenate():
+    # Make sure the dataset concatenation works and check for unintended changes in behavior.
+    get_test_dataset()
+    dataset = GPTConcatenatedDataset[GPTIndexedDataset](
+        "concatenated", [GPTMemmapDataset("memmap", DATASET_PREFIX) for _ in range(3)]
+    )
+    Assert.eq(len(dataset), 3 * _MEMMAP_DATASET_EXPECTED_LENGTH)
+    sizes = dataset.get_document_sizes()
+    Assert.eq(sizes.sum(), 3 * _MEMMAP_DATASET_EXPECTED_TOKENS)
+    for i in range(3):
+        begin = i * _MEMMAP_DATASET_EXPECTED_LENGTH
+        Assert.all_equal([len(dataset.get(begin + i)) for i in range(100)], sizes[begin : begin + 100])
+        for i, sample in _MEMMAP_DATASET_EXPECTED_SAMPLES.items():
+            Assert.all_equal(dataset.get(begin + i), np.array(sample, dtype=numpy.uint16))
+
+
+def test_gpt_slice():
+    # Make sure dataset splitting works and check for unintended changes in behavior.
+    get_test_dataset()
+    # samples[9:18]
+    dataset = GPTDatasetSlice("slice", GPTMemmapDataset("memmap", DATASET_PREFIX), 9, 18)
+    Assert.eq(len(dataset), 9)
+    sizes = dataset.get_document_sizes()
+    Assert.all_equal([len(dataset.get(i)) for i in range(9)], sizes[:9])
+    for i, sample in _MEMMAP_DATASET_EXPECTED_SAMPLES.items():
+        Assert.all_equal(dataset.get(i - 9), np.array(sample, dtype=numpy.uint16))
+
+
+def test_gpt_sampling():
+    # Make sure sampling works and check for unintended changes in behavior.
+    get_test_dataset()
+    sampled = GPTMemmapDataset("memmap", DATASET_PREFIX).sample(get_sampling_config(8, sequence_length=5))
+    Assert.eq(len(sampled), 8)
+    Assert.all_equal(
+        np.stack([sampled[i] for i in range(8)]),
+        np.array(
+            [
+                [1725, 74, 207, 1635, 4440, 2774],
+                [359, 489, 4266, 2052, 5351, 80],
+                [374, 7534, 87, 1073, 79, 480],
+                [8008, 498, 71, 727, 80, 315],
+                [2210, 8179, 73, 2582, 897, 1178],
+                [409, 5091, 328, 1378, 5483, 88],
+                [83, 4457, 3316, 333, 489, 317],
+                [330, 155, 2449, 1136, 1106, 5370],
+            ]
+        ),
+    )
+
+
+def test_gpt_blended():
+    # Make sure dataset blending works and check for unintended changes in behavior.
+    get_test_dataset()
+    sampled = BlendedDataset(
+        "blended",
+        [
+            GPTMemmapDataset("memmap", DATASET_PREFIX).sample(get_sampling_config(5, sequence_length=5)),
+            GPTDummyDataset("dummy").sample(get_sampling_config(3, sequence_length=5, seed=150516)),
+        ],
+        [0.6, 0.4],
+        get_sampling_config(8, sequence_length=5),
+    )
+    Assert.eq(len(sampled), 8)
+    Assert.all_equal(
+        np.stack([sampled[i] for i in range(8)]),
+        np.array(
+            [
+                [1725, 74, 207, 1635, 4440, 2774],
+                [5291, 3692, 4158, 503, 2201, 2587],
+                [359, 489, 4266, 2052, 5351, 80],
+                [5558, 4833, 2889, 7476, 1588, 226],
+                [374, 7534, 87, 1073, 79, 480],
+                [8008, 498, 71, 727, 80, 315],
+                [786, 3161, 8179, 2300, 6160, 2531],
+                [2210, 8179, 73, 2582, 897, 1178],
+            ]
+        ),
+    )
+
+
+def test_gpt_fim():
+    # Make sure the FIM wrapper works in a simple case and check for unintended changes in behavior.
+    get_test_dataset()
+    # The test tokenizer doesn't have fim tokens, so we work around it.
+    sampling_config = get_sampling_config(
+        8, sequence_length=5, tokenizer=Tokenizer(TokenizerConfig.from_dict({"path": TOKENIZER_PATH}))
+    )
+    sampled = GPTFimDataset(
+        FimConfig(rate=0.5, prefix_token="w", middle_token="x", pad_token="y", suffix_token="z"),
+        GPTMemmapDataset("memmap", DATASET_PREFIX).sample(sampling_config),
+        sampling_config,
+    )
+    Assert.eq(len(sampled), 8)
+    # TODO: Does this output make sense?
+    Assert.all_equal(
+        np.stack([sampled[i] for i in range(8)]),
+        np.array(
+            [
+                [1725, 74, 207, 1635, 4440, 2774],
+                [359, 489, 4266, 2052, 5351, 80],
+                [86, 89, 22255, 1073, 79, 480],
+                [8008, 498, 71, 727, 80, 315],
+                [2210, 8179, 73, 2582, 897, 1178],
+                [86, 89, 88, 87, 409, 70],
+                [86, 83, 744, 89, 64, 333],
+                [86, 89, 1461, 87, 330, 7876],
+            ]
+        ),
+    )