Commit: Match legacy
jlamypoirier committed Jan 22, 2025
1 parent 54e5fa5 commit 755c355
Showing 3 changed files with 35 additions and 14 deletions.
17 changes: 13 additions & 4 deletions fast_llm/data/dataset/config.py
@@ -145,8 +145,10 @@ class BlendedDatasetConfig(SampledDatasetConfig):
desc="The blending weight of each dataset.",
hint=FieldHint.core,
)
seed_shift: int = Field(
default=54783, desc="Shift the seed for each sub-dataset for extra randomness.", hint=FieldHint.feature
legacy: bool = Field(
default=False,
desc="Use the legacy formulas for sub-dataset seeds and sample sizes.",
hint=FieldHint.deprecated,
)

def _validate(self) -> None:
@@ -162,13 +164,20 @@ def build_and_sample(

# Build and sample the datasets.
# TODO: Vary the seed?
# Add 5 times the standard deviation (of a binomial distribution)
# so the probability of sampling more than this amount during blending is negligible.

sampled_datasets = [
dataset.build_and_sample(
# Blending is deterministic and the error will never be higher than 1.
dataclasses.replace(
config,
num_samples=math.ceil(weight * config.num_samples) + 1,
seed=config.seed + i * self.seed_shift,
num_samples=(
math.ceil(weight * (config.num_samples + 5 * (config.num_samples * (1 - weight)) ** 0.5))
if self.legacy
else math.ceil(weight * config.num_samples) + 1
),
seed=config.seed + i * (0 if self.legacy else 697),
),
)
for i, (dataset, weight) in enumerate(zip(self.datasets, self.weights, strict=True))
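
For readers comparing the two branches in the hunk above, the sample-count formulas are easier to see side by side outside the class. The sketch below is a minimal standalone restatement: the helper name and the example numbers are made up for illustration, but the two formulas and the 0-versus-697 seed shift are taken directly from the diff.

import math

def per_dataset_num_samples(num_samples: int, weight: float, legacy: bool) -> int:
    # Illustrative helper mirroring the two branches in BlendedDatasetConfig.build_and_sample.
    if legacy:
        # Legacy: pad by 5 binomial standard deviations so the probability of the blend
        # requesting more than this amount is negligible.
        return math.ceil(weight * (num_samples + 5 * (num_samples * (1 - weight)) ** 0.5))
    # New: blending is deterministic, so the overshoot is never more than 1 sample.
    return math.ceil(weight * num_samples) + 1

print(per_dataset_num_samples(1000, 0.25, legacy=True))   # 285: generous safety margin
print(per_dataset_num_samples(1000, 0.25, legacy=False))  # 251: ceil(250) + 1

# The seed follows the same flag: sub-dataset i is sampled with
# config.seed + i * (0 if legacy else 697).
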
2 changes: 1 addition & 1 deletion fast_llm/data/dataset/gpt/config.py
@@ -352,7 +352,7 @@ def build_and_sample(self, config: GPTSamplingConfig) -> SampledDataset:
name="blended",
datasets=dataset_configs,
weights=dataset_weights,
seed_shift=0,
legacy=True,
)
if len(dataset_configs) > 1
else dataset_configs[0]
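
The one-line change above keeps the blended dataset built from the legacy GPT data configuration on the old formulas, so existing setups reproduce the same samples, while BlendedDatasetConfig used directly defaults to legacy=False and gets the new behavior. Purely as an illustration, a directly specified blended config could opt back into the old behavior through the new field; the dict below mirrors the shape used in tests/test_dataset.py further down, and the paths are placeholders.

# Hypothetical config dict; "legacy": True restores the old seed and sample-size formulas.
blended = {
    "datasets": [
        {"type": "memmap", "path": "/data/dataset_0"},
        {"type": "memmap", "path": "/data/dataset_1"},
    ],
    "weights": [0.75, 0.25],
    "legacy": True,
}
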
30 changes: 21 additions & 9 deletions tests/test_dataset.py
@@ -365,11 +365,11 @@ def test_gpt_slice_data_legacy():

GPT_BLENDED_EXPECTED_SAMPLES = [
[1725, 74, 207, 1635, 4440, 2774],
[328, 80, 263, 890, 1797, 88],
[2066, 207, 6436, 2360, 2210, 6633],
[359, 489, 4266, 2052, 5351, 80],
[374, 7534, 87, 1073, 79, 480],
[8008, 498, 71, 727, 80, 315],
[1852, 71, 776, 7878, 7390, 80],
[555, 3042, 83, 207, 498, 3373],
[2210, 8179, 73, 2582, 897, 1178],
[409, 5091, 328, 1378, 5483, 88],
]
@@ -387,11 +387,11 @@ def test_gpt_blended():
{"type": "memmap", "path": DATASET_PREFIX_MIX_1},
],
"weights": [0.75, 0.25],
"seed_shift": 0,
},
GPTBlendedDatasetConfig,
).build_and_sample(get_sampling_config(8, sequence_length=5))
Assert.eq(len(sampled), 8)
print(np.stack([sampled[i] for i in range(8)]).tolist())
Assert.all_equal(
np.stack([sampled[i] for i in range(8)]),
np.array(GPT_BLENDED_EXPECTED_SAMPLES),
@@ -411,7 +411,6 @@ def test_gpt_blended_data():
{"type": "memmap", "path": DATASET_PREFIX_MIX_1},
],
"weights": [0.75, 0.25],
"seed_shift": 0,
}
}
},
@@ -424,6 +423,18 @@
)


GPT_BLENDED_LEGACY_EXPECTED_SAMPLES = [
[1725, 74, 207, 1635, 4440, 2774],
[328, 80, 263, 890, 1797, 88],
[359, 489, 4266, 2052, 5351, 80],
[374, 7534, 87, 1073, 79, 480],
[8008, 498, 71, 727, 80, 315],
[1852, 71, 776, 7878, 7390, 80],
[2210, 8179, 73, 2582, 897, 1178],
[409, 5091, 328, 1378, 5483, 88],
]


def test_gpt_blended_data_legacy():
get_test_dataset()
get_test_dataset_1()
@@ -438,18 +449,18 @@ def test_gpt_blended_data_legacy():
)
Assert.all_equal(
np.stack(samples[PhaseType.training]),
np.array(GPT_BLENDED_EXPECTED_SAMPLES),
np.array(GPT_BLENDED_LEGACY_EXPECTED_SAMPLES),
)


GPT_BLENDED_MIXED_EXPECTED_SAMPLES = [
[1725, 74, 207, 1635, 4440, 2774],
[5291, 3692, 4158, 503, 2201, 2587],
[916, 6683, 7685, 1277, 5106, 378],
[359, 489, 4266, 2052, 5351, 80],
[5558, 4833, 2889, 7476, 1588, 226],
[3359, 6803, 780, 4561, 669, 7878],
[374, 7534, 87, 1073, 79, 480],
[8008, 498, 71, 727, 80, 315],
[786, 3161, 8179, 2300, 6160, 2531],
[6920, 2218, 2921, 3963, 7606, 6904],
[2210, 8179, 73, 2582, 897, 1178],
]

@@ -467,8 +478,9 @@ def test_gpt_blended_mixed():
"weights": [0.6, 0.4],
},
GPTBlendedDatasetConfig,
).build_and_sample(get_sampling_config(8, sequence_length=5, seed=109766))
).build_and_sample(get_sampling_config(8, sequence_length=5))
Assert.eq(len(sampled), 8)
print(np.stack([sampled[i] for i in range(8)]).tolist())
Assert.all_equal(
np.stack([sampled[i] for i in range(8)]),
np.array(GPT_BLENDED_MIXED_EXPECTED_SAMPLES),
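
The updated expected-sample lists reflect both halves of the change at once: the per-dataset seed shift (0 in legacy mode, 697 otherwise) and the tighter sample-count bound. That bound leans on the in-code comment that blending is deterministic with an error of at most 1. The sketch below shows one standard deterministic blending rule with that property; it is an assumption for illustration only, not necessarily the algorithm used by the blended dataset, and every name in it is invented.

def blend_order(weights: list[float], num_samples: int) -> list[int]:
    # Greedy deterministic blending: at each step, draw from the dataset that is
    # furthest behind its target share. Dataset i is then used at most about
    # weight_i * num_samples + 1 times, matching ceil(weight * num_samples) + 1.
    taken = [0] * len(weights)
    order = []
    for k in range(num_samples):
        errors = [w * (k + 1) - t for w, t in zip(weights, taken)]
        i = errors.index(max(errors))
        taken[i] += 1
        order.append(i)
    return order

order = blend_order([0.75, 0.25], 8)
print(order.count(0), order.count(1))  # 6 and 2 with these weights, each within the bound
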
