Commit

Replace prefetch with val iterator check in megatron models (NVIDIA#7318)

* Add counter for num_microbatches

Signed-off-by: Abhishree <[email protected]>

* Reset self.total_val_micro_batches

Signed-off-by: Abhishree <[email protected]>

* Replace _prefetch() with _val_iterator_done()

Signed-off-by: Abhishree <[email protected]>

* Override limit_val_batches for pretraining models

Signed-off-by: Abhishree <[email protected]>

* Return iterator in _val_iterator_done when the iterator is not exhausted

Signed-off-by: Abhishree <[email protected]>

* Temporarily comment out the BioMegatron Bert CI test

Signed-off-by: Abhishree <[email protected]>

* Move _reconfigure_val_batches() to MegatronGPTModel

Signed-off-by: Abhishree <[email protected]>

* Move self._reconfigure_val_batches() to build_train_valid_test_datasets

Signed-off-by: Abhishree <[email protected]>

* Avoid fetching batches and reinserting them into the iterator

Signed-off-by: Abhishree <[email protected]>

* Increase limit_val_batches in CI tests

Signed-off-by: Abhishree <[email protected]>

* Use _val_iterator_done to check for iterator end in all megatron models

1) Remove the if condition on self._val_micro_batches_consumed in _val_iterator_done and check with just a try (fetch and reinsert) / except StopIteration
2) Use _val_iterator_done in all megatron models that use dataloader_iter, for uniformity

Signed-off-by: Abhishree <[email protected]>

* Minor edit to return outside of try block

Signed-off-by: Abhishree <[email protected]>

* Add _val_iterator_done for megatron_nmt_model

Signed-off-by: Abhishree <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Abhishree <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <[email protected]>
3 people authored Aug 29, 2023
1 parent 22e61ca commit 2d830b5
Showing 11 changed files with 181 additions and 172 deletions.
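
The core of the change is that trainer.limit_val_batches is rescaled so validation always covers whole global batches, where each global batch is get_num_microbatches() microbatch-sized steps. A back-of-the-envelope sketch of that arithmetic with hypothetical numbers, not NeMo's actual config plumbing:

    # Hypothetical values, for illustration only.
    limit_val_batches = 5   # global batches the user asked to validate on
    num_microbatches = 4    # microbatches that make up one global batch

    # _reconfigure_val_batches scales the limit so a trainer that counts
    # microbatch-sized steps still covers exactly 5 global batches:
    limit_val_batches *= num_microbatches        # -> 20
    num_sanity_val_steps = num_microbatches      # sanity check runs one full global batch

    print(limit_val_batches, num_sanity_val_steps)  # 20 4

This is presumably also why the Jenkinsfile edits below pass explicit integer trainer.limit_val_batches values to the CI configs.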
3 changes: 3 additions & 0 deletions Jenkinsfile
@@ -4668,6 +4668,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
trainer.accelerator=gpu \
trainer.log_every_n_steps=1 \
trainer.val_check_interval=2 \
+ trainer.limit_val_batches=5 \
trainer.accumulate_grad_batches=1 \
trainer.max_steps=6 \
trainer.precision=16 \
@@ -4853,6 +4854,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
steps {
sh "python examples/nlp/language_modeling/megatron_gpt_pretraining.py \
trainer.max_steps=10 \
+ trainer.limit_val_batches=7 \
trainer.val_check_interval=10 \
exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \
model.data.data_impl=mock \
@@ -4865,6 +4867,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
steps {
sh "python examples/nlp/language_modeling/megatron_t5_pretraining.py \
trainer.max_steps=10 \
+ trainer.limit_val_batches=3 \
trainer.val_check_interval=10 \
exp_manager.exp_dir=examples/nlp/language_modeling/t5_pretrain_results \
model.data.data_impl=mock \
@@ -190,6 +190,15 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
gc.disable()
self.validation_global_step = 1

+ def _reconfigure_val_batches(self):
+     """
+     Reconfigure trainer.limit_val_batches for pretraining
+     """
+     # Override limit_val_batches to be a multiple of num microbatches and so there are limit_val_batches//num_micro_batches num of global batches
+     self.trainer.limit_val_batches *= get_num_microbatches()
+     # Override num sanity steps equal to num of microbatches and perform one val_step
+     self.trainer.num_sanity_val_steps = get_num_microbatches()

def _enable_nvidia_optimizations(self):
"These optimizations are present in NVIDIA NGC PyTorch Containers"

@@ -805,19 +814,13 @@ def build_model_parallel_config(self) -> ModelParallelConfig:

return model_parallel_config

- def _prefetch(self, iterator):
-     """Checks if the iterator still has elements to return.
-     Used in models using dataloader_iter to prefetch the next batch before fwd_bwd func
-     is called to avoid PP rank 2 from waiting indefinitely to get outputs from PP 1
-     """
-     elements = []
-     num_microbatches = get_num_microbatches()
-     for _ in range(num_microbatches):
-         try:
-             element = next(iterator)
-             elements.append(element)
-         except StopIteration:
-             return iterator, True
-
-     # return a new iterator with the prefetched element reinserted at the front
-     return itertools.chain(elements, iterator), False
+ def _val_iterator_done(self, iterator):
+     """
+     Check if the iterator is exhausted, if so raise a StopIteration and exit validation_step
+     """
+     try:
+         element = next(iterator)
+     except StopIteration:
+         return iterator, True
+     # reinsert the element back to the iterator
+     return itertools.chain([element], iterator), False
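
The replacement helper uses a peek-and-reinsert idiom: pull one element to see whether the iterator is exhausted, then chain it back onto the front so the downstream fwd/bwd functions still see every batch. A standalone sketch of the same pattern in plain Python, outside of NeMo:

    import itertools

    def val_iterator_done(iterator):
        """Report whether an iterator is exhausted, without losing the next element."""
        try:
            element = next(iterator)
        except StopIteration:
            return iterator, True
        # Reinsert the peeked element at the front so nothing is skipped.
        return itertools.chain([element], iterator), False

    batches = iter([{"step": 0}, {"step": 1}])
    batches, done = val_iterator_done(batches)
    print(done)           # False
    print(list(batches))  # [{'step': 0}, {'step': 1}] -- the peeked batch is not lost

Compared with the removed _prefetch, this touches at most one element per call instead of pulling a full set of num_microbatches elements and re-chaining all of them.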
@@ -417,8 +417,8 @@ def allreduce_first_last_embeddings(self):
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())

def validation_step(self, dataloader_iter, batch_idx):
- # Prefetch the dataloader_iter before fwd_bwd func to avoid PP rank 2 from waiting indefinitely when PP rank 1 reaches the end of dataloader_iter
- dataloader_iter, done = self._prefetch(dataloader_iter)
+ # Check if iterator is exhausted
+ dataloader_iter, done = self._val_iterator_done(dataloader_iter)
if done:
return
prefix = "test" if self.trainer.testing else "val"
@@ -588,6 +588,8 @@ def build_LDDL_data(self, cfg):
logging.info(f'Finished building LDDL Dataloaders')

def build_train_valid_test_datasets(self):
+ # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
+ self._reconfigure_val_batches()
logging.info('Building Bert datasets.')
if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float):
raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.")
@@ -298,77 +298,77 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
)

def inference_step(self, dataloader_iter, batch_idx: int, mode: str, dataloader_idx=0):
-     # Add try except since dataloader_iter in PTL 2.0 doesn't catch the end of the iterator
-     try:
-         # Regular finetuning datasets will return a list of dicts for each microbatch.
-         # But T0 datasets will return a single dict for the global batch.
-         batch = next(dataloader_iter)
-         batch_has_lang_information = isinstance(batch, list) and len(batch[0]) == 7
-         data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds
-
-         self._reconfigure_and_process_inference_batch(batch, data_cfg)
-
-         # NOTE: There could be extra keys in the processed_batch dictionary such as "langs" for XNLI,
-         # this will be ignored.
-         loss = self.fwd_bwd_step(itertools.chain([batch]), batch_idx, forward_only=True)
-
-         predicted_token_ids, _ = self.decode(
-             tokens_enc=batch['text_enc'],
-             enc_mask=batch['enc_mask'],
-             num_tokens_to_generate=30,
-             bos_id=self.tokenizer.pad_id if data_cfg.get('replace_bos_with_pad', False) else self.tokenizer.bos_id,
-         )
-
-         # Special ids to text function to handle stripping <eos> and special tokens with sentencepiece tokenizers.
-         preds_text = MegatronT5FinetuneModel.ids_to_text(predicted_token_ids, self.tokenizer)
-         labels_text = MegatronT5FinetuneModel.ids_to_text(batch['labels'], self.tokenizer)
-         input_text = MegatronT5FinetuneModel.ids_to_text(batch['text_enc'], self.tokenizer)
-
-         if not batch_has_lang_information:
-             categories = [None] * len(preds_text)
-         else:
-             categories = batch['lang']
-
-         metric = self.val_metric[dataloader_idx] if mode == 'validation' else self.test_metric[dataloader_idx]
-         assert len(categories) == len(preds_text) == len(labels_text)
-         for _, (pred, label, category) in enumerate(zip(preds_text, labels_text, categories)):
-             # To compute metrics like pearson or spearman correlation, we need to cast the predicted string and labels to floats.
-             pred, label = self.cast_for_metric(
-                 pred=pred,
-                 label=label,
-                 metric_name=self.val_metric_name if mode == 'validation' else self.test_metric_name,
-                 class_labels=data_cfg.metric.get('class_labels', None),
-                 labels_are_strings=data_cfg.metric.get('labels_are_strings', False),
-             )
-             if batch_has_lang_information:
-                 _ = metric(pred, label, category)
-             else:
-                 _ = metric(pred, label)
-
-         outputs = {
-             'preds': preds_text,
-             'labels': labels_text,
-             'categories': categories,
-             'inputs': input_text,
-         }
-
-         if isinstance(loss, dict):
-             outputs.update(loss)
-         else:
-             outputs['loss'] = loss
-         if mode == 'validation':
-             if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1:
-                 self.validation_step_outputs[dataloader_idx].append(outputs)
-             else:
-                 self.validation_step_outputs.append(outputs)
-         else:
-             if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1:
-                 self.test_step_outputs[dataloader_idx].append(outputs)
-             else:
-                 self.test_step_outputs.append(outputs)
-         return outputs
-     except StopIteration:
-         return
+     # Check if iterator is exhausted
+     dataloader_iter, done = self._val_iterator_done(dataloader_iter)
+     if done:
+         return
+     # Regular finetuning datasets will return a list of dicts for each microbatch.
+     # But T0 datasets will return a single dict for the global batch.
+     batch = next(dataloader_iter)
+     batch_has_lang_information = isinstance(batch, list) and len(batch[0]) == 7
+     data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds
+
+     self._reconfigure_and_process_inference_batch(batch, data_cfg)
+
+     # NOTE: There could be extra keys in the processed_batch dictionary such as "langs" for XNLI,
+     # this will be ignored.
+     loss = self.fwd_bwd_step(itertools.chain([batch]), batch_idx, forward_only=True)
+
+     predicted_token_ids, _ = self.decode(
+         tokens_enc=batch['text_enc'],
+         enc_mask=batch['enc_mask'],
+         num_tokens_to_generate=30,
+         bos_id=self.tokenizer.pad_id if data_cfg.get('replace_bos_with_pad', False) else self.tokenizer.bos_id,
+     )
+
+     # Special ids to text function to handle stripping <eos> and special tokens with sentencepiece tokenizers.
+     preds_text = MegatronT5FinetuneModel.ids_to_text(predicted_token_ids, self.tokenizer)
+     labels_text = MegatronT5FinetuneModel.ids_to_text(batch['labels'], self.tokenizer)
+     input_text = MegatronT5FinetuneModel.ids_to_text(batch['text_enc'], self.tokenizer)
+
+     if not batch_has_lang_information:
+         categories = [None] * len(preds_text)
+     else:
+         categories = batch['lang']
+
+     metric = self.val_metric[dataloader_idx] if mode == 'validation' else self.test_metric[dataloader_idx]
+     assert len(categories) == len(preds_text) == len(labels_text)
+     for _, (pred, label, category) in enumerate(zip(preds_text, labels_text, categories)):
+         # To compute metrics like pearson or spearman correlation, we need to cast the predicted string and labels to floats.
+         pred, label = self.cast_for_metric(
+             pred=pred,
+             label=label,
+             metric_name=self.val_metric_name if mode == 'validation' else self.test_metric_name,
+             class_labels=data_cfg.metric.get('class_labels', None),
+             labels_are_strings=data_cfg.metric.get('labels_are_strings', False),
+         )
+         if batch_has_lang_information:
+             _ = metric(pred, label, category)
+         else:
+             _ = metric(pred, label)
+
+     outputs = {
+         'preds': preds_text,
+         'labels': labels_text,
+         'categories': categories,
+         'inputs': input_text,
+     }
+
+     if isinstance(loss, dict):
+         outputs.update(loss)
+     else:
+         outputs['loss'] = loss
+     if mode == 'validation':
+         if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1:
+             self.validation_step_outputs[dataloader_idx].append(outputs)
+         else:
+             self.validation_step_outputs.append(outputs)
+     else:
+         if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1:
+             self.test_step_outputs[dataloader_idx].append(outputs)
+         else:
+             self.test_step_outputs.append(outputs)
+     return outputs

@classmethod
def ids_to_text(cls, batch_ids, tokenizer):
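
One detail worth noting in the rewritten inference_step: the batch is fetched explicitly with next(dataloader_iter) and then handed to fwd_bwd_step wrapped in itertools.chain([batch]), because the fwd/bwd helpers consume an iterator of batches rather than a single batch. A minimal illustration of that wrapping (the consume helper is a hypothetical stand-in, not a NeMo API):

    import itertools

    def consume(batch_iterator, num_items):
        # Stand-in for a fwd/bwd helper that pulls batches from an iterator.
        return [next(batch_iterator) for _ in range(num_items)]

    batch = {"tokens": [1, 2, 3]}        # already fetched with next(dataloader_iter)
    wrapped = itertools.chain([batch])   # one-element iterator, as the helper expects
    print(consume(wrapped, 1))           # [{'tokens': [1, 2, 3]}]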
@@ -911,8 +911,8 @@ def validation_step(self, dataloader_iter, batch_idx):
from the dataloader to produce a list of microbatches.
The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions.
"""
- # Prefetch the dataloader_iter before fwd_bwd func to avoid PP rank 2 from waiting indefinitely when PP rank 1 reaches the end of dataloader_iter
- dataloader_iter, done = self._prefetch(dataloader_iter)
+ # Check if iterator is exhausted
+ dataloader_iter, done = self._val_iterator_done(dataloader_iter)
if done:
return
mode = 'test' if self.trainer.testing else 'val'
@@ -969,6 +969,8 @@ def loss_func(self, loss_mask, output_tensor):
return loss

def build_train_valid_test_datasets(self):
+ # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
+ self._reconfigure_val_batches()
logging.info('Building GPT datasets.')
if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float):
raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.")
@@ -369,8 +369,8 @@ def optimizer_zero_grad(self, *args, **kwargs):
return

def validation_step(self, dataloader_iter, batch_idx):
- # Prefetch the dataloader_iter before fwd_bwd func to avoid PP rank 2 from waiting indefinitely when PP rank 1 reaches the end of dataloader_iter
- dataloader_iter, done = self._prefetch(dataloader_iter)
+ # Check if iterator is exhausted
+ dataloader_iter, done = self._val_iterator_done(dataloader_iter)
if done:
return
mode = 'test' if self.trainer.testing else 'val'
@@ -385,20 +385,17 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
return loss_mean

def validation_step(self, dataloader_iter, batch_idx, dataloader_idx=0):
-     # Add try except since dataloader_iter in PTL 2.0 doesn't catch the end of iterables
-     try:
-         return self.inference_step(dataloader_iter, batch_idx, 'validation', dataloader_idx)
-     except StopIteration:
-         return
+     return self.inference_step(dataloader_iter, batch_idx, 'validation', dataloader_idx)

def test_step(self, dataloader_iter, batch_idx, dataloader_idx=0):
-     # Add try except since dataloader_iter in PTL 2.0 doesn't catch the end of iterables
-     try:
-         return self.inference_step(dataloader_iter, batch_idx, 'test', dataloader_idx)
-     except StopIteration:
-         return
+     return self.inference_step(dataloader_iter, batch_idx, 'test', dataloader_idx)

def inference_step(self, dataloader_iter, batch_idx, mode, dataloader_idx=0):
+     # Check if iterator is exhausted
+     dataloader_iter, done = self._val_iterator_done(dataloader_iter)
+     if done:
+         return
batch = next(dataloader_iter)
data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds
self._reconfigure_and_process_inference_batch(batch, data_cfg)
@@ -714,8 +714,8 @@ def _test_validation_step(self, step_outputs, dataloader_iter, batch_idx, dataloader_idx=0):
"""
Shared code for validation and test step
"""
- # Prefetch the dataloader_iter before fwd_bwd func to avoid PP rank 2 from waiting indefinitely when PP rank 1 reaches the end of dataloader_iter
- dataloader_iter, done = self._prefetch(dataloader_iter)
+ # Check if iterator is exhausted
+ dataloader_iter, done = self._val_iterator_done(dataloader_iter)
if done:
return

@@ -147,28 +147,28 @@ def compute_accuracy(self, enc_input, enc_mask, encoder_input, labels):
}

def validation_step(self, dataloader_iter, batch_idx, inference=False):
-     # Add try except since dataloader_iter in PTL 2.0 doesn't catch the end of iterables
-     try:
-         batch = next(dataloader_iter)
-         enc_input, dec_input, labels, loss_mask, enc_mask, dec_mask, position_ids, taskname_ids = batch
-
-         mode = self.training
-         self.eval()
-         gbs = self.cfg.get('validation_global_batch_size', self.cfg.global_batch_size)
-         self._reconfigure_and_process_inference_batch(enc_input.size(0), gbs)
-         loss_mean = self.fwd_bwd_step(itertools.chain([batch]), batch_idx, forward_only=True)
-
-         if self.cfg.get('report_validation_metric', False):
-             metrics = self.compute_accuracy(enc_input, enc_mask, labels)
-             metrics['loss'] = loss_mean
-         else:
-             metrics = {'loss': loss_mean}
-
-         self.validation_step_outputs.append(metrics)
-         self.train(mode=mode)
-         return metrics
-     except StopIteration:
-         return
+     # Check if iterator is exhausted
+     dataloader_iter, done = self._val_iterator_done(dataloader_iter)
+     if done:
+         return
+     batch = next(dataloader_iter)
+     enc_input, dec_input, labels, loss_mask, enc_mask, dec_mask, position_ids, taskname_ids = batch
+
+     mode = self.training
+     self.eval()
+     gbs = self.cfg.get('validation_global_batch_size', self.cfg.global_batch_size)
+     self._reconfigure_and_process_inference_batch(enc_input.size(0), gbs)
+     loss_mean = self.fwd_bwd_step(itertools.chain([batch]), batch_idx, forward_only=True)
+
+     if self.cfg.get('report_validation_metric', False):
+         metrics = self.compute_accuracy(enc_input, enc_mask, labels)
+         metrics['loss'] = loss_mean
+     else:
+         metrics = {'loss': loss_mean}
+
+     self.validation_step_outputs.append(metrics)
+     self.train(mode=mode)
+     return metrics

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:

@@ -197,6 +197,8 @@ def add_special_tokens_to_tokenizer(
tokenizer.add_special_tokens([f'<extra_id_{mask_type}>'])

def build_train_valid_test_datasets(self):
+ # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
+ self._reconfigure_val_batches()
logging.info(f'Building {self.model_name} datasets.')
if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float):
raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.")