
Allow loss masking for defined spans of characters #113

Open
wants to merge 25 commits into base: main

Conversation

@sohamparikh (Member) commented Jan 14, 2025

✨ Description

Support loss masking for spans specified in the input data. This PR ensures that loss is not computed on the specified spans. The main use case is instruction tuning data, where we want to avoid training on the prompts.

Closes #109

📝 Changes

List the key changes introduced in this PR:

  • Support character spans as inputs specified in the prepare command (see the illustrative example below)
  • Read the spans during training and apply masks to the cross-entropy loss
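For illustration, a raw document fed to the prepare command could carry the prompt region as character offsets over the text; the field names below are assumptions for the sketch, not the schema fixed by this PR.

# Hypothetical input record for `prepare`: the span marks the prompt characters
# to exclude from the loss, so training only happens on the assistant response.
text = "User: What is the capital of France?\nAssistant: Paris."
document = {
    "text": text,
    # Character span [start, end) over the raw text (field name is illustrative).
    "char_spans": [[0, text.index("Assistant:")]],
}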

🔍 Type of change

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

@jlamypoirier (Collaborator):

Looks good so far, but can you please add a short description and/or point to an issue?

@sohamparikh changed the title from "convert character spans to token spans" to "Allow loss masking for defined spans of characters" on Jan 24, 2025
for start, end in char_spans:
    if char_pos < start:
        curr_text = text[char_pos:start]
        tokenized_text = self._tokenizer.tokenize(curr_text, add_special_tokens=beginning_of_text)
Collaborator:

This works only for tokenizers that have a BOS but not an EOS token.
For those that come with both, can we control independently whether tokenize adds the BOS and EOS tokens? I'm worried that we are adding the EOS token at the end of the first segment and the BOS token at the beginning of the last segment.

Member Author:

Good catch! I'll make it explicitly add BOS only for the first segment.
Btw, most tokenizers (Llama-3.1, Mistral-Nemo-Base-2407, OLMoE-1B-7B-0924) do not add the EOS token with add_special_tokens=True. Does this mean we've been training the models without the EOS token?

In the future I think we should make this config-driven. The default behaviour would be to add both BOS and EOS tokens. That's important for pretraining with an attention mask, and especially for SFT.

Collaborator:

> Does this mean we've been training the models without the EOS token?

Indeed, we decided that adding both BOS and EOS tokens in pretraining was unnecessary, because they are redundant. Here, though, I think we need to add them: we need to teach the model to terminate a response with the EOS token so that generation can stop at the right moment. Btw, I think HF is not adding the EOS token by default because otherwise prompts would end with it.
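For reference, a minimal sketch of the behaviour being discussed, assuming a Hugging Face tokenizer and manual handling of the special tokens (a sketch of the idea, not the PR's implementation):

from transformers import AutoTokenizer

def tokenize_segments(segments: list[str], tokenizer) -> list[list[int]]:
    # Add BOS only before the first segment and EOS only after the last one,
    # so concatenating the segments yields a single well-formed document.
    token_segments = []
    for i, segment in enumerate(segments):
        ids = tokenizer.encode(segment, add_special_tokens=False)
        if i == 0 and tokenizer.bos_token_id is not None:
            ids = [tokenizer.bos_token_id] + ids
        if i == len(segments) - 1 and tokenizer.eos_token_id is not None:
            ids = ids + [tokenizer.eos_token_id]
        token_segments.append(ids)
    return token_segments

# Segments alternate between masked (prompt) and unmasked (response) text.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
print(tokenize_segments(["User: Hi\n", "Assistant: Hello!"], tokenizer))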

exp_logits1 = exp_logits.scatter(
    1, target, exp_logits.gather(1, target) - target_mask * sum_exp_logits.unsqueeze(dim=-1)
)
exp_logits2 = exp_logits1.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1))
if logits_scale_factor != 1.0:
    exp_logits2 *= logits_scale_factor

grad = exp_logits2.to(logits.dtype)
grad.index_put_((mask,), exp_logits2.to(logits.dtype))

predicted_logits = (target_mask * logits_norm.gather(1, target)).squeeze(1)
all_reduce(predicted_logits, op=ReduceOp.SUM, group=group)
@tscholak (Collaborator) commented on Jan 24, 2025:

Does the Triton implementation support masking?

Collaborator:

I think it doesn't: https://github.com/ServiceNow/Fast-LLM/blob/soham/loss-masking-spans/fast_llm/functional/triton/cross_entropy.py
We need to add it. Since this is the same for all loss functions, it would make sense to implement it before dispatching to specialized cross-entropy implementations:

def cross_entropy_forward_backward(
    logits,
    target,
    grad_output: float | None,
    group: ProcessGroup | None,
    implementation: CrossEntropyImpl = CrossEntropyImpl.fused,
    logits_scale_factor: float = 1.0,
    ignore_index: int = -100,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    ...
    mask = target != ignore_index
    target = target[mask]
    logits = logits[mask]
    ...

@sohamparikh sohamparikh marked this pull request as ready for review January 28, 2025 08:19
@sohamparikh sohamparikh marked this pull request as draft January 28, 2025 08:29

assert sample.shape[0] == sample_len
assert sample.ids.shape[0] == sample_len
return sample
Collaborator:

Since this code changes the order of tokens in the sequence, we would need to change the masks accordingly to allow FIM together with loss masking.
At this point, I think we should not do that, and instead fail if FIM is used with loss masking.

Member Author:

Raising an error in this function now (see the sketch below).
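Something along these lines, presumably; the function name and signature are placeholders, not the PR's actual code:

import numpy as np

def apply_fim(token_ids: np.ndarray, spans: np.ndarray | None) -> np.ndarray:
    # FIM reorders tokens, so spans computed on the original order would no
    # longer line up with the data; fail loudly instead of corrupting the masks.
    if spans is not None and len(spans) > 0:
        raise NotImplementedError("FIM is not supported together with loss-masking spans.")
    # ... the actual FIM transformation of token_ids would go here ...
    return token_ids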

end = np.random.RandomState(np_seed).randint(start, len(ids))
spans.append([start, end])
prev_end = end
return GPTSample(ids=ids, spans=np.array(spans, dtype=np.int32).reshape(-1, 2))
Collaborator:

That's a nice addition, though I'm not sure the random dataset actually needs spans... We use it only for benchmarking purposes, to measure training performance without being IO-bound.

Member Author:

I added it for the tests. I can make the spans empty if you think it'll mess with the benchmark numbers.

Collaborator:

Spans really need to be None here, though we could add a config parameter to generate random spans (see the sketch below).
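A possible shape for that suggestion, with an assumed opt-in flag for generating random spans (names here are illustrative, not actual config fields):

import dataclasses
import numpy as np

@dataclasses.dataclass
class GPTSample:  # minimal stand-in for the PR's dataclass
    ids: np.ndarray
    spans: np.ndarray | None = None

def build_random_sample(ids: np.ndarray, np_seed: int, use_loss_masking_spans: bool = False) -> GPTSample:
    # Spans stay None by default so benchmarks are unaffected; random spans
    # are only generated when explicitly requested (e.g. by tests).
    spans = None
    if use_loss_masking_spans:
        rng = np.random.RandomState(np_seed)
        start = int(rng.randint(0, len(ids)))
        end = int(rng.randint(start, len(ids)))
        spans = np.array([[start, end]], dtype=np.int32)
    return GPTSample(ids=ids, spans=spans)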

@tscholak (Collaborator) left a comment:

Hi @sohamparikh, nice progress!
Is anything functional missing from this PR? Do the spans make it all the way to the loss function, and does packing work as expected with spans? Can we test this?

@tscholak tscholak requested a review from jlamypoirier January 28, 2025 20:34
@sohamparikh (Member Author) commented Jan 28, 2025:

Functionally, it's good to go now.
I've tested prepare and train with SmolLM2-135M on a single GPU using a dummy dataset with spans. It seems to be working fine, including packing and the loss functions.

How do we want to test this? I can test it on a bigger model with multiple nodes if that makes sense.

@sohamparikh sohamparikh marked this pull request as ready for review January 28, 2025 22:33
@jlamypoirier (Collaborator) left a comment:

I reviewed the overall structure. It looks good for the intended purpose but will need polishing. Main areas to focus on (see individual comments):

  • Spans need to be opt-in, so that there is negligible impact when they are not used (which is most of the time).
  • Turning samples and friends into dataclasses is a good idea but goes a bit outside the present scope. Ideally it would go in a separate PR, but it's ok to include here if done well (see comments).
  • Names need to follow our style guide (https://servicenow.github.io/Fast-LLM/contributing/style-guide/) a bit better. Please use self-descriptive names as much as possible; e.g., spans is a bit cryptic and could mean lots of things. (Not sure what to replace it with, but it should hint at loss masking in some way.)

@@ -34,3 +41,8 @@ class TokenizerConfig(Config):
        desc="Path to the tokenizer file.",
        hint=FieldHint.core,
    )
    special_tokens_mode: SpecialTokensMode = Field(
@jlamypoirier (Collaborator) commented on Jan 28, 2025:

This doesn't look self-descriptive enough. Could we think of a better name, and improve the description? (also not sure what this does?)

@@ -21,12 +22,17 @@
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class GPTSample:
Collaborator:

This seems like a nice addition, but it will have an impact far beyond the scope of the current PR, so we need to do it carefully (same for other similar dataclasses):

  • Use meaningful field names (ids -> token_ids, spans -> ?).
  • If we're going the dataclass way we should go all the way, i.e. inherit from a Sample base class and adjust type hints everywhere (see the sketch below). Same for the batch thing. No need to do it in this PR, but if not we'll need an issue to refer to.
  • The custom model also needs to be adjusted since it inherits from GPT.
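A rough sketch of the suggested direction, with renamed fields and an assumed Sample base class (these names are suggestions, not decisions made in this PR):

import dataclasses
import numpy as np

@dataclasses.dataclass
class Sample:
    """Assumed common base class for model-specific samples."""

@dataclasses.dataclass
class GPTSample(Sample):
    token_ids: np.ndarray
    # Opt-in: None when loss masking is not used, so the common path stays cheap.
    loss_masking_spans: np.ndarray | None = None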

spans: np.ndarray


@dataclasses.dataclass
Collaborator:

Redundant with GPTSample (and breaks type hint)

@@ -10,6 +11,18 @@
from fast_llm.utils import Assert, div


@dataclasses.dataclass
class GPTMemmapDocument:
Collaborator:

GPTDocument (nothing to do with memmap); also not sure it belongs in this file.
Anyway, isn't this also the same as GPTSample?

@@ -10,6 +11,18 @@
from fast_llm.utils import Assert, div


@dataclasses.dataclass
class GPTMemmapDocument:
text: np.ndarray
Collaborator:

Isn't this ids?


logits, target, grad_output, logits_scale_factor=logits_scale_factor
)
if grad_logits is not None:
Collaborator:

This needs to go inside each implementation, because each of them can be optimized in its own way. The torch implementation has ignore_index already, the compiled version can include this inside the compile block, and the Triton kernels can include masking. For the Triton part you can keep this if you don't know Triton (it's a really easy one, though).

Also, torch.where would do a better job here.

(And as usual, masking needs to be opt-in.)

Member Author:

> (And as usual, masking needs to be opt-in.)

Do you mean an additional flag indicating whether loss masking should take place (using the config option for reading spans)?
I'm not clear on why ignore_index isn't sufficient, since it wouldn't be set without the spans config flag anyway.

Collaborator:

Yes. ignore_index isn't sufficient because it would slow things down when not in use (see the sketch below).
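For illustration, an opt-in version of the torch path could look roughly like this; the function and parameter names are placeholders, and a None mask keeps the existing fast path untouched (a sketch, not the PR's implementation):

import torch

def cross_entropy_with_optional_mask(
    logits: torch.Tensor,                   # (num_tokens, vocab_size)
    target: torch.Tensor,                   # (num_tokens,)
    loss_mask: torch.Tensor | None = None,  # bool, True where loss should be computed
    logits_scale_factor: float = 1.0,
) -> torch.Tensor:
    if logits_scale_factor != 1.0:
        logits = logits * logits_scale_factor
    if loss_mask is None:
        # Fast path: no masking overhead when spans are not in use.
        return torch.nn.functional.cross_entropy(logits, target)
    # torch.where maps masked positions to ignore_index, so they contribute
    # neither to the loss nor to the gradient.
    masked_target = torch.where(loss_mask, target, torch.full_like(target, -100))
    return torch.nn.functional.cross_entropy(logits, masked_target, ignore_index=-100)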

@@ -23,6 +26,19 @@
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class GPTDataBatch:
Collaborator:

GPTBatch is enough

@@ -82,8 +83,8 @@ def get_test_data_and_samples(
batch_config.setup(distributed_config)
batch_config.validate()
samples = {
phase: [batch[0] for batch in data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0)]
for phase, samples in samples_per_phase.items()
phase: list(data.get_iterator(batch_config, phase, consumed_samples=consumed_samples, num_workers=0))
Collaborator:

This makes the existing tests too complicated. Instead, please test spans with a small number of separate tests specifically targeting them. (I'm not sure we need full coverage for all cases; you could make one complicated test case that indirectly tests many classes.)

@@ -36,6 +36,9 @@ def get_document_sizes(self) -> np.ndarray:
        # TODO: This can be really big.
        return self._dataset.get_document_sizes()[self._begin : self._end]

    def get_span_sizes(self) -> np.ndarray:
Collaborator:

Is this only used in the tests? If so, I'm not sure it's worth making it a public method at this stage.
(It would also need to be added to GPTIndexedDataset.)


Successfully merging this pull request may close these issues.

[feat] Implement Loss Masking to Exclude Predefined Token Spans from LM Loss