Allow loss masking for defined spans of characters #113
base: main
Conversation
Looks good so far, but can you please add a short description and/or point to an issue?
for start, end in char_spans:
    if char_pos < start:
        curr_text = text[char_pos:start]
        tokenized_text = self._tokenizer.tokenize(curr_text, add_special_tokens=beginning_of_text)
This works only for tokenizers that have a BOS but not an EOS token. For those that come with both, can we control whether `tokenize` adds the BOS and EOS tokens independently? I'm worried that we are adding the EOS token at the end of the first segment and the BOS token at the beginning of the last segment.
Good catch! I'll make it explicitly add BOS only for the first segment.
Btw, most tokenizers (Llama-3.1, Mistral-Nemo-Base-2407, OLMoE-1B-7B-0924) do not add the EOS token with `add_special_tokens=True`. Does this mean we've been training the models without the EOS token?
In the future I think we should make this config-driven. The default behaviour would be to add both BOS and EOS tokens. It's important for pretraining with an attention mask, and especially for SFT.
Does this mean we've been training the models without the EOS token?
Indeed, we decided that adding both BOS and EOS tokens in pretraining was unnecessary, because they are redundant. Here though I think we need to add them because we need to teach the model to terminate a response with the EOS token so that generation can stop at the right moment. Btw, I think HF is not adding the EOS token by default because otherwise prompts would end with it.
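For reference, here is a minimal sketch of per-segment control over the special tokens, assuming a Hugging Face-style tokenizer exposing `bos_token_id`/`eos_token_id`; the helper names and flags are illustrative, not the Fast-LLM API:

```python
def tokenize_segment(tokenizer, text: str, *, add_bos: bool, add_eos: bool) -> list[int]:
    # Tokenize without any special tokens, then add BOS/EOS explicitly so that the
    # first segment can get only the BOS token and the last segment only the EOS token.
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if add_bos and tokenizer.bos_token_id is not None:
        token_ids = [tokenizer.bos_token_id] + token_ids
    if add_eos and tokenizer.eos_token_id is not None:
        token_ids = token_ids + [tokenizer.eos_token_id]
    return token_ids


def tokenize_segments(tokenizer, segments: list[str]) -> list[int]:
    # Concatenate the segments produced by splitting the text on the character spans.
    token_ids: list[int] = []
    for i, segment in enumerate(segments):
        token_ids += tokenize_segment(
            tokenizer, segment, add_bos=(i == 0), add_eos=(i == len(segments) - 1)
        )
    return token_ids
```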
exp_logits1 = exp_logits.scatter(
    1, target, exp_logits.gather(1, target) - target_mask * sum_exp_logits.unsqueeze(dim=-1)
)
exp_logits2 = exp_logits1.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1))
if logits_scale_factor != 1.0:
    exp_logits2 *= logits_scale_factor

grad = exp_logits2.to(logits.dtype)
grad.index_put_((mask,), exp_logits2.to(logits.dtype))

predicted_logits = (target_mask * logits_norm.gather(1, target)).squeeze(1)
all_reduce(predicted_logits, op=ReduceOp.SUM, group=group)
does the triton implementation support masking?
I think it doesn't: https://github.com/ServiceNow/Fast-LLM/blob/soham/loss-masking-spans/fast_llm/functional/triton/cross_entropy.py
We need to add it. Since this is the same for all loss functions, it would make sense to implement it before dispatching to specialized cross-entropy implementations:
def cross_entropy_forward_backward(
    logits,
    target,
    grad_output: float | None,
    group: ProcessGroup | None,
    implementation: CrossEntropyImpl = CrossEntropyImpl.fused,
    logits_scale_factor: float = 1.0,
    ignore_index: int = -100,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    ...
    mask = target != ignore_index
    target = target[mask]
    logits = logits[mask]
    ...
assert sample.shape[0] == sample_len
assert sample.ids.shape[0] == sample_len
return sample
Since this code will change the order of tokens in the sequence, we would need to change the masks accordingly to allow FIM with loss masking.
At this point, I think we should not do that, and instead fail if FIM is used with loss masking.
Throwing an error in this function now.
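For illustration, a minimal sketch of such a guard; the function and argument names are assumptions, not the actual Fast-LLM code:

```python
import numpy as np


def check_fim_compatible_with_loss_masking(
    fim_rate: float, loss_masking_spans: np.ndarray | None
) -> None:
    # FIM permutes tokens within the sequence, so spans recorded against the original
    # token order would no longer point at the right positions after the permutation.
    if fim_rate > 0.0 and loss_masking_spans is not None and len(loss_masking_spans) > 0:
        raise NotImplementedError("FIM is currently not supported together with loss-masking spans.")
```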
fast_llm/data/dataset/gpt/random.py
Outdated
    end = np.random.RandomState(np_seed).randint(start, len(ids))
    spans.append([start, end])
    prev_end = end
return GPTSample(ids=ids, spans=np.array(spans, dtype=np.int32).reshape(-1, 2))
That's a nice addition, though I'm not sure the random dataset actually needs spans... We use this only for benchmarking purposes, to measure training performance without being IO-bound.
I added it for the tests. I can make the spans empty if you think it'll mess with the benchmark numbers.
Spans really need to be None here, though we could add a config parameter to generate random spans.
Hi @sohamparikh, nice progress!
Is anything functional missing from this PR? Do the spans make it all the way to the loss function, and does packing work as expected with spans? Can we test this?
Functionally it's good to go now. How do we want to test this? I can test it on a bigger model with multiple nodes if that makes sense.
I reviewed the overall structure, looks good for the intended purpose but will need polishing. Main areas to focus on (see individual comments):
- Spans need to be opt-in, so that there is a negligible impact when they are not used (which is most of the time).
- Turning samples and friends into dataclasses is a good idea but goes a bit outside the present scope. Ideally it would go in a separate PR, but it's ok to include here if done well (see comments).
- Names need to follow our style guide (https://servicenow.github.io/Fast-LLM/contributing/style-guide/) a bit better. Please use self-descriptive names as much as possible, e.g. `spans` is a bit cryptic and could mean lots of things. (Not sure what to replace it with, but it should hint at loss masking in some way.)
fast_llm/data/config.py
Outdated
@@ -34,3 +41,8 @@ class TokenizerConfig(Config):
        desc="Path to the tokenizer file.",
        hint=FieldHint.core,
    )
    special_tokens_mode: SpecialTokensMode = Field(
This doesn't look self-descriptive enough. Could we think of a better name, and improve the description? (also not sure what this does?)
@@ -21,12 +22,17 @@
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class GPTSample:
This seems like a nice addition, but will have an impact far beyond the scope of the current PR, so we need to do it carefully (same for other similar dataclasses):
- Use meaningful field names (`ids` -> `token_ids`, `spans` -> ?).
- If we're going the dataclass way we should go for it all the way, i.e. inherit from a `Sample` base class and adjust type hints everywhere (a rough sketch is below). Same for the batch thing. No need to do it in this PR, but if not we'll need an issue to refer to.
- The custom model also needs to be adjusted since it inherits from the GPT one.
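A rough sketch of what that could look like (the names are suggestions only; `loss_masking_spans` is one candidate for the more descriptive span field discussed above):

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class Sample:
    """Hypothetical common base class for dataset samples."""


@dataclasses.dataclass
class GPTSample(Sample):
    token_ids: np.ndarray
    # Spans of positions to exclude from the loss; None when loss masking is not used,
    # so the common path carries no extra data.
    loss_masking_spans: np.ndarray | None = None
```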
fast_llm/data/dataset/gpt/memmap.py
Outdated
    spans: np.ndarray


@dataclasses.dataclass
Redundant with GPTSample (and breaks type hint)
fast_llm/data/dataset/gpt/memmap.py
Outdated
@@ -10,6 +11,18 @@
from fast_llm.utils import Assert, div


@dataclasses.dataclass
class GPTMemmapDocument:
GPTDocument (nothing to do with memmap), also not sure it belongs in this file.
Anyway, isn't this also the same as `GPTSample`?
fast_llm/data/dataset/gpt/memmap.py
Outdated
@@ -10,6 +11,18 @@
from fast_llm.utils import Assert, div


@dataclasses.dataclass
class GPTMemmapDocument:
    text: np.ndarray
Isn't this ids?
fast_llm/functional/cross_entropy.py
Outdated
    logits, target, grad_output, logits_scale_factor=logits_scale_factor
)
if grad_logits is not None:
This needs to go inside the implementation because each of them can be optimized in its own way: the torch implementation has `ignore_index` already, the compiled version can include this inside the compile block, and the triton kernels can include masking. For the triton part you can keep this if you don't know Triton (it's a really easy one though).
Also, `torch.where` would do a better job here.
(And as usual, masking needs to be opt-in.)
(And as usual, masking needs to be opt-in)
Do you mean an additional flag indicating whether loss masking should take place (using the config option for reading spans)?
I'm not clear why `ignore_index` isn't sufficient, since it wouldn't be set without the spans config flag anyway.
Yes. `ignore_index` isn't sufficient because it would slow things down when not in use.
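As a concrete illustration of the opt-in idea (a minimal sketch, not the actual implementation; the flag name is an assumption), the masking work is only done when spans are configured:

```python
import torch


def cross_entropy_with_optional_loss_mask(
    logits: torch.Tensor,
    target: torch.Tensor,
    use_loss_masking_spans: bool = False,  # opt-in flag, driven by the spans config
    ignore_index: int = -100,
) -> torch.Tensor:
    if use_loss_masking_spans:
        # Filter out masked positions once, before dispatching to a specialized kernel,
        # so the extra cost is only paid when spans are actually present.
        keep = target != ignore_index
        logits = logits[keep]
        target = target[keep]
    return torch.nn.functional.cross_entropy(logits, target)
```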
fast_llm/data/data/gpt/data.py
Outdated
@@ -23,6 +26,19 @@
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class GPTDataBatch:
`GPTBatch` is enough.
@@ -82,8 +83,8 @@ def get_test_data_and_samples(
    batch_config.setup(distributed_config)
    batch_config.validate()
    samples = {
        phase: [batch[0] for batch in data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0)]
        for phase, samples in samples_per_phase.items()
        phase: list(data.get_iterator(batch_config, phase, consumed_samples=consumed_samples, num_workers=0))
This makes the existing tests too complicated. Instead, please test spans with a small number of separate tests specifically targeting them. (Not sure we need full coverage for all cases; alternatively, you could make one complicated test case that indirectly tests many classes.)
@@ -36,6 +36,9 @@ def get_document_sizes(self) -> np.ndarray:
        # TODO: This can be really big.
        return self._dataset.get_document_sizes()[self._begin : self._end]

    def get_span_sizes(self) -> np.ndarray:
Is this only used in the tests? If so I'm not sure it's worth making a public method at this stage.
(And would need to be added to GPTIndexedDataset too)
✨ Description
Support loss masking for spans specified in the input data. This PR ensures that the loss is not computed on the specified spans. The biggest use case for this is instruction-tuning data, where we want to avoid training on the prompts.
Closes #109
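As an illustration of the intended use (the field names here are hypothetical, not the exact data format), a document would carry spans that should not contribute to the loss:

```python
document = {
    "text": "User: What is 2 + 2?\nAssistant: 4",
    # Character span covering the prompt (end-exclusive); tokens falling inside it are
    # excluded from the loss, so the model is only trained on the assistant response.
    "loss_masking_spans": [[0, 21]],
}
```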