Dataset tweaks (#118)
jlamypoirier authored Jan 16, 2025
1 parent 8e30926 commit fbffa0f
Showing 30 changed files with 620 additions and 520 deletions.
3 changes: 2 additions & 1 deletion fast_llm/data/auto.py
@@ -1,7 +1,8 @@
from fast_llm.data.preparator.config import DatasetPreparatorConfig
from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig
from fast_llm.utils import Registry

dataset_preparator_registry = Registry(
dataset_preparator_registry = Registry[str, DatasetPreparatorConfig](
"DatasetPreparator",
{
dataset_preparator.preparator_name: dataset_preparator
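The only functional change in this file is the typed registry: subscripting Registry[str, DatasetPreparatorConfig] tells type checkers the key and value types, so lookups come back as DatasetPreparatorConfig instead of Any. A minimal sketch of what such a generic registry can look like, assuming only the call shape visible above (illustrative only, not the actual fast_llm.utils.Registry):

# Illustrative sketch only, not the actual fast_llm.utils.Registry: a minimal generic
# registry with the same call shape as above (a display name plus a key -> value mapping).
class Registry[KeyType, ValueType]:
    def __init__(self, name: str, entries: dict[KeyType, ValueType]) -> None:
        self._name = name
        self._entries = dict(entries)

    def __getitem__(self, key: KeyType) -> ValueType:
        if key not in self._entries:
            raise KeyError(f"Unknown {self._name}: {key}")
        return self._entries[key]

    def keys(self) -> list[KeyType]:
        return list(self._entries)

With the subscripted form, a lookup in dataset_preparator_registry is inferred as a DatasetPreparatorConfig rather than Any.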
9 changes: 0 additions & 9 deletions fast_llm/data/config.py
@@ -12,15 +12,6 @@ class MultiprocessingContext(str, enum.Enum):
spawn = "spawn"


def _validate_split(value):
Assert.leq(len(value), 3)
return value + [0] * (len(value) - 3)


def _validate_path(value):
return [value] if isinstance(value, str) else value


TokenizerFromFile = "TokenizerFromFile"


35 changes: 29 additions & 6 deletions fast_llm/data/data/abstract.py
@@ -1,16 +1,39 @@
import abc
import pathlib
import typing

from fast_llm.engine.distributed.config import PhaseType
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.config import Configurable
from fast_llm.data.data.config import DataConfig
from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
from fast_llm.engine.schedule.config import BatchConfig

if typing.TYPE_CHECKING:
from fast_llm.engine.distributed.distributed import Distributed


class Data[ConfigType: DataConfig](Configurable[ConfigType], abc.ABC):
_distributed: "Distributed"
_samples_per_phase: dict[PhaseType, int]
_cache_directory: pathlib.Path | None

def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> None:
super().__init__(config)
self._distributed_config = distributed_config

class Data(abc.ABC):
# TODO: Improve interface
@abc.abstractmethod
def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int]):
pass
def setup(
self,
distributed: "Distributed",
samples_per_phase: dict[PhaseType, int],
cache_directory: pathlib.Path,
) -> None:
self._distributed = distributed
self._samples_per_phase = samples_per_phase
self._cache_directory = cache_directory

@property
def distributed(self):
return self._distributed

@abc.abstractmethod
def get_iterator(
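The abstract Data class now carries the state that subclasses previously duplicated: it is a Configurable built from a DataConfig and a DistributedConfig, and setup() records the distributed object, the per-phase sample counts and the cache directory. A minimal subclass sketch under that reading of the diff (the subclass name and body are hypothetical, and the remaining abstract methods such as get_iterator are omitted):

import pathlib
import typing

from fast_llm.data.data.abstract import Data
from fast_llm.data.data.config import DataConfig
from fast_llm.engine.distributed.config import PhaseType

if typing.TYPE_CHECKING:
    from fast_llm.engine.distributed.distributed import Distributed


class MyData(Data[DataConfig]):
    # Hypothetical subclass for illustration; a real subclass must also implement
    # the remaining abstract methods (get_iterator at least).
    def setup(
        self,
        distributed: "Distributed",
        samples_per_phase: dict[PhaseType, int],
        cache_directory: pathlib.Path,
    ) -> None:
        # The base class stores distributed, samples_per_phase and cache_directory;
        # a subclass layers its own dataset loading and sampling on top.
        super().setup(distributed, samples_per_phase, cache_directory)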
12 changes: 2 additions & 10 deletions fast_llm/data/data/config.py
@@ -1,15 +1,7 @@
import pathlib
import typing

from fast_llm.config import Config, Field, config_class


@config_class
class SamplingConfig(Config):
num_samples: int = Field(default=1, desc="Number of samples to generate.")
seed: int = Field(default=0, desc="Random seed.")
cache_directory: pathlib.Path | None = Field(default=None, desc="Path to the sampling cache directory.")
verbose: bool = Field(default=True, desc="Log sampling progress.")
from fast_llm.config import Config, config_class
from fast_llm.data.dataset.config import SamplingConfig


@config_class()
51 changes: 4 additions & 47 deletions fast_llm/data/data/gpt/config.py
@@ -1,27 +1,12 @@
import enum

from fast_llm.config import Field, FieldHint, check_field, config_class
from fast_llm.data.config import MultiprocessingContext, TokenizerConfig, _validate_path, _validate_split
from fast_llm.data.data.config import DataConfig, SamplingConfig
from fast_llm.data.dataset.gpt.fim.config import FimConfig
from fast_llm.data.config import MultiprocessingContext, TokenizerConfig
from fast_llm.data.data.config import DataConfig
from fast_llm.data.dataset.gpt.config import GPTLegacyConfig
from fast_llm.utils import Assert


class DatasetSource(str, enum.Enum):
"""
An enum for the different ways to load datasets.
TODO: Reduce the diversity?
TODO: Is this specific to GPT data?
"""

list = "list"
file = "file"
sample = "sample"
random = "random"


@config_class()
class GPTDataConfig(DataConfig):
class GPTDataConfig(DataConfig, GPTLegacyConfig):
"""
Configuration for the dataset(s), split and sampling.
Currently hard-coded to a GPT dataset.
@@ -35,29 +20,6 @@ class GPTDataConfig(DataConfig):
desc="Configuration for the tokenizer (for FIM).",
hint=FieldHint.feature,
)
fim: FimConfig = Field(
default_factory=FimConfig,
desc="Configuration for Fill In the Middle (FIM).",
hint=FieldHint.feature,
)
# TODO: set default to [1,0,0]?
split: list[float] = Field(
default_factory=lambda: [969, 30, 1],
desc="Split ratio for train, valid and test datasets.",
hint=FieldHint.core,
valid=_validate_split,
)
format: DatasetSource = Field(
default=DatasetSource.list,
desc="Format for the dataset definition.",
hint=FieldHint.core,
)
path: list[str] = Field(
default_factory=list,
desc="Path or list of paths and weights.",
hint=FieldHint.core,
valid=_validate_path,
)
data_sample_warn_time_ms: float = Field(
default=1000,
desc="Warn if a sample takes too long to load.",
@@ -69,8 +31,3 @@ class GPTDataConfig(DataConfig):
desc="Multiprocessing context. Do not touch.",
hint=FieldHint.expert,
)


@config_class
class GPTSamplingConfig(SamplingConfig):
sequence_length: int = Field(default=None, desc="Number of token in each sample.")
86 changes: 51 additions & 35 deletions fast_llm/data/data/gpt/data.py
@@ -9,15 +9,18 @@
import torch.utils.data

from fast_llm.data.data.abstract import Data
from fast_llm.data.data.gpt.config import DatasetSource, GPTDataConfig, GPTSamplingConfig
from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.dataset.abstract import CopySplitDataset, PhaseSplits, SampledSplitDataset
from fast_llm.data.dataset.blended import BlendedDataset
from fast_llm.data.dataset.gpt.dummy import DummyGPTDataset
from fast_llm.data.dataset.gpt.config import DatasetSource, GPTSamplingConfig
from fast_llm.data.dataset.gpt.dummy import GPTDummyDataset
from fast_llm.data.dataset.gpt.fim import FimDataset
from fast_llm.data.dataset.gpt.indexed import GPTDatasetSlice
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
from fast_llm.data.dataset.gpt.slice import GPTDatasetSlice
from fast_llm.data.dataset.monitor import DatasetMonitor
from fast_llm.data.iterator import SampledDatasetIterator
from fast_llm.data.tokenizer import Tokenizer
from fast_llm.engine.config_utils.run import get_run, log_main_rank
from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.schedule.config import BatchConfig
@@ -26,7 +29,7 @@
logger = logging.getLogger(__name__)


class GPTData(Data):
class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]):
"""
A global class for all dataset needs, including loading, splitting, sampling and iteration.
Currently hard-coded to a GPT dataset.
@@ -35,9 +38,6 @@ class GPTData(Data):

_datasets: SampledSplitDataset
_tokenizer: Tokenizer | None
_distributed: Distributed
_cache_directory: pathlib.Path | None
_samples_per_phase: dict[PhaseType, int]
_phases: typing.ClassVar[tuple[PhaseType, ...]] = (PhaseType.training, PhaseType.validation, PhaseType.test)
_is_setup: bool = False

@@ -52,8 +52,7 @@ def __init__(
Create the data and gather some basic information on the dataset(s).
Should be `setup` before use.
"""
self._config = config
self._distributed_config = distributed_config
super().__init__(config, distributed_config)
self._vocab_size = vocab_size
self._max_sequence_length = max_sequence_length
Assert.eq(len(self._config.split), len(self._phases))
@@ -114,22 +113,25 @@ def __init__(
}
self._dataset_weights = {name: weight for name, weight in zip(dataset_names, dataset_weights)}

def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int]):
def setup(
self,
distributed: "Distributed",
samples_per_phase: dict[PhaseType, int],
cache_directory: pathlib.Path,
) -> None:
"""
Load the datasets, and prepare or load the samplings.
This may take a while and a significant amount of cpu memory.
"""
run = get_run()
super().setup(distributed, samples_per_phase, cache_directory)
Assert.leq(set(samples_per_phase), set(self._phase_split))
log_main_rank(f"Preparing {self._num_datasets} datasets. This may take several minutes.")
self._tokenizer = Tokenizer(self._config.tokenizer) if self._config.fim.rate > 0 else None
self._distributed = distributed
self._samples_per_phase = samples_per_phase
if run.experiment_directory is None:
if self._cache_directory is None:
# TODO: Avoid this
warnings.warn(f"Using the dataset directory for the index cache.")
self._cache_directory = None
else:
self._cache_directory = run.experiment_directory / "dataset_cache"

datasets_and_weights = []
for i, (name, weight) in enumerate(self._dataset_weights.items()):
@@ -144,18 +146,22 @@ def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int
expected_samples
+ 5 * math.sqrt(expected_samples * self._dataset_weights[name] * (1 - self._dataset_weights[name]))
)

sampling_configs = PhaseSplits[GPTSamplingConfig](
{
phase: GPTSamplingConfig(
num_samples=dataset_samples_per_phase[phase],
sequence_length=self._max_sequence_length,
seed=self._distributed_config.seed,
cache_directory=(
self._dataset_prefixes[name].parent
if self._cache_directory is None and isinstance(self._dataset_prefixes[name], pathlib.Path)
else self._cache_directory
),
verbose=self._num_datasets <= 5,
distributed=self._distributed,
sequence_length=self._max_sequence_length,
vocab_size=self._vocab_size,
tokenizer=self._tokenizer,
)
for phase, num_samples in dataset_samples_per_phase.items()
if num_samples > 0
@@ -166,40 +172,41 @@ def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int
)

if len(datasets_and_weights) == 1:
self._datasets = datasets_and_weights[0][0]
datasets = datasets_and_weights[0][0]
else:
self._datasets = BlendedDataset.apply(
datasets = BlendedDataset.apply(
"blended",
datasets_and_weights,
PhaseSplits[GPTSamplingConfig](
{
phase: GPTSamplingConfig(
num_samples=samples_per_phase,
sequence_length=self._max_sequence_length,
seed=self._distributed_config.seed,
cache_directory=None if self._cache_directory is None else self._cache_directory,
verbose=self._num_datasets <= 5,
distributed=self._distributed,
sequence_length=self._max_sequence_length,
vocab_size=self._vocab_size,
tokenizer=self._tokenizer,
)
for phase, samples_per_phase in self._samples_per_phase.items()
}
),
self,
)
self._datasets = SampledSplitDataset[GPTDatasetSlice](
"monitor",
{
phase: DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
for phase, dataset in datasets.items()
},
)
self._is_setup = True

@property
def config(self):
return self._config

@property
def tokenizer(self):
def tokenizer(self) -> Tokenizer:
assert self._is_setup
return self._tokenizer

@property
def distributed(self):
return self._distributed

def get_iterator(
self,
batch_config: BatchConfig,
@@ -208,7 +215,7 @@ def get_iterator(
consumed_samples: int,
num_workers: int,
prefetch_factor: int | None = None,
):
) -> typing.Iterator[typing.Any]:
assert self._is_setup
Assert.incl(phase, self._datasets)
Assert.in_range_incl(batch_config.sequence_length, 1, self._max_sequence_length)
@@ -231,13 +238,22 @@ def get_iterator(
)

def _build_and_sample_gpt_dataset(self, name: str, sampling_configs: PhaseSplits[GPTSamplingConfig]):
return GPTDatasetSlice.from_splits(
datasets = GPTDatasetSlice.from_splits(
GPTMemmapDataset(name, self._dataset_prefixes[name]), self._phase_split
).sample(sampling_configs, self)
).sample(sampling_configs)
if self._config.fim.rate > 0:
datasets = SampledSplitDataset[GPTDatasetSlice](
"fim",
{
phase: FimDataset(self.config.fim, dataset, sampling_configs[phase])
for phase, dataset in datasets.items()
},
)
return datasets

def _build_and_sample_dummy_dataset(self, name: str, sampling_configs: PhaseSplits[GPTSamplingConfig]):
return CopySplitDataset(
f"{name}_split",
DummyGPTDataset(name, self._max_sequence_length, self._vocab_size),
GPTDummyDataset(name),
list(sampling_configs),
).sample(sampling_configs, self)
).sample(sampling_configs)
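One detail worth noting in setup() above is the over-sampling margin for blended datasets: each dataset is asked for its expected share of a phase's samples plus five times a square-root-scale term, so a randomly weighted blend cannot run an individual dataset dry. A standalone sketch of that calculation, mirroring the expression in the diff (the helper name and the way expected_samples is derived here are assumptions for illustration):

import math


def dataset_samples_with_margin(samples_per_phase: int, weight: float) -> int:
    # Expected share of this dataset in the phase (assumed derivation).
    expected_samples = math.ceil(samples_per_phase * weight)
    # Square-root-scale safety margin, same expression as in the diff above.
    margin = 5 * math.sqrt(expected_samples * weight * (1 - weight))
    return math.ceil(expected_samples + margin)


# Example: a dataset with weight 0.3 in a 1,000,000-sample phase is asked for
# 300,000 samples plus a margin of roughly 1,255.
print(dataset_samples_with_margin(1_000_000, 0.3))  # 301255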