Use time.tokens for speedmonitor instead of dataset length #2762

Merged
composer/callbacks/speed_monitor.py (24 changes: 11 additions & 13 deletions)
@@ -174,8 +174,7 @@ class SpeedMonitor(Callback):
+-------------------------------------+-----------------------------------------------------------+
| | Rolling average (over `window_size` most recent |
| `throughput/tokens_per_sec` | batches) of the number of tokens processed per second. |
- | | Only logged when dataloader.dataset has `max_seq_len`. |
- | | This may include padding depending on dataset |
+ | | Only logged if dataspec returns tokens per batch |
+-------------------------------------+-----------------------------------------------------------+
| | Estimates flops by `flops_per_batch * batches_per_sec` |
| `throughput/flops_per_sec` | if model has attribute `flops_per_batch` |
@@ -186,8 +185,8 @@ class SpeedMonitor(Callback):
| `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size |
+-------------------------------------+-----------------------------------------------------------+
| | `throughput/tokens_per_sec` divided by world size. Only |
- | `throughput/device/tokens_per_sec` | logged when dataloader.dataset has `max_seq_len`. This |
- | | may include pad tokens depending on dataset |
+ | `throughput/device/tokens_per_sec` | logged if dataspec returns tokens per batch |
+ | | |
+-------------------------------------+-----------------------------------------------------------+
| | `throughput/flops_per_sec` divided by world size. Only |
| `throughput/device/flops_per_sec` | logged when model has attribute `flops_per_batch` |
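
The table rows above note that token throughput is only logged when the dataspec reports a token count per batch. A minimal sketch of how that hookup might look on the user side, assuming Composer's DataSpec with its get_num_tokens_in_batch callable and a batch dict carrying an 'input_ids' tensor (the toy dataset and field name below are illustrative, not from this PR):

import torch
from torch.utils.data import DataLoader, Dataset
from composer.core import DataSpec

class ToyTextDataset(Dataset):
    # Tiny synthetic dataset of fixed-length token sequences, for illustration only.
    def __init__(self, num_samples=16, seq_len=8, vocab_size=100):
        self.input_ids = torch.randint(0, vocab_size, (num_samples, seq_len))
        self.labels = torch.randint(0, 2, (num_samples,))

    def __len__(self):
        return self.input_ids.shape[0]

    def __getitem__(self, idx):
        return {'input_ids': self.input_ids[idx], 'labels': self.labels[idx]}

def num_tokens_in_batch(batch):
    # Count every element of input_ids; padding could be subtracted here if desired.
    return int(batch['input_ids'].numel())

# Passing this DataSpec as train_dataloader lets the trainer advance
# state.timestamp.token, which SpeedMonitor reads in batch_end.
train_dataspec = DataSpec(
    dataloader=DataLoader(ToyTextDataset(), batch_size=4),
    get_num_tokens_in_batch=num_tokens_in_batch,
)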
@@ -222,6 +221,7 @@ def __init__(
):
# Track the batch num samples and wct to compute throughput over a window of batches
self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
self.history_tokens: Deque[int] = deque(maxlen=window_size + 1)
self.history_wct: Deque[float] = deque(maxlen=window_size + 1)
self.history_flops: Deque[float] = deque(maxlen=window_size + 1)

@@ -259,13 +259,15 @@ def init(self, state: State, logger: Logger) -> None:
def batch_end(self, state: State, logger: Logger):
# Add the new element
self.history_samples.append(state.timestamp.sample.value)
self.history_tokens.append(state.timestamp.token.value)
self.history_wct.append(state.timestamp.total_wct.total_seconds())

# Log the throughput
if len(self.history_wct) == self.history_wct.maxlen:
world_size = dist.get_world_size()
elapsed_batches = len(self.history_samples) - 1
elapsed_samples = int(self.history_samples[-1]) - int(self.history_samples[0])
elapsed_tokens = int(self.history_tokens[-1]) - int(self.history_tokens[0])
elapsed_wct = self.history_wct[-1] - self.history_wct[0]
batches_per_sec = elapsed_batches / elapsed_wct
samples_per_sec = elapsed_samples / elapsed_wct
@@ -277,17 +279,13 @@ def batch_end(self, state: State, logger: Logger):
'throughput/device/batches_per_sec': dev_batches_per_sec,
'throughput/device/samples_per_sec': dev_samples_per_sec,
})

- # Compute token stats if dataloader.dataset has max_seq_len. Assumes no padding.
- try:
- max_seq_len = state.dataloader.dataset.max_seq_len # type: ignore
+ # Only applicable to seq data / models
+ if elapsed_tokens > 0:
+ tokens_per_sec = elapsed_tokens / elapsed_wct
+ dev_tokens_per_sec = tokens_per_sec / world_size
logger.log_metrics({
- 'throughput/tokens_per_sec': samples_per_sec * max_seq_len,
- 'throughput/device/tokens_per_sec': dev_samples_per_sec * max_seq_len,
+ 'throughput/tokens_per_sec': tokens_per_sec,
+ 'throughput/device/tokens_per_sec': dev_tokens_per_sec,
})
- except AttributeError:
- pass

# Compute flops stats if model has flops_per_batch
composer_model = state.model
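
For reference, the callback now divides the change in the trainer's cumulative token count by the change in wall-clock time over the logging window, instead of multiplying samples per second by a dataset-level max_seq_len. A standalone sketch of that windowed arithmetic, with invented cumulative values:

from collections import deque

window_size = 3
history_tokens = deque(maxlen=window_size + 1)
history_wct = deque(maxlen=window_size + 1)

# (cumulative tokens, cumulative wall-clock seconds) after each batch; made-up numbers.
for step, (tokens, wct) in enumerate([(0, 0.0), (512, 1.0), (1024, 1.9), (1536, 3.1), (2048, 4.0)]):
    history_tokens.append(tokens)
    history_wct.append(wct)
    if len(history_wct) == history_wct.maxlen:
        elapsed_tokens = history_tokens[-1] - history_tokens[0]
        elapsed_wct = history_wct[-1] - history_wct[0]
        print(f'batch {step}: {elapsed_tokens / elapsed_wct:.1f} tokens/sec')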
tests/callbacks/test_speed_monitor.py (49 changes: 48 additions & 1 deletion)
@@ -11,7 +11,8 @@
from composer.core import Time
from composer.loggers import InMemoryLogger
from composer.trainer import Trainer
- from tests.common import RandomClassificationDataset, SimpleModel
+ from tests.common import RandomClassificationDataset, SimpleModel, SimpleTransformerClassifier
+ from tests.common.datasets import dummy_text_classification_dataloader


def _assert_no_negative_values(logged_values):
@@ -53,6 +54,8 @@ def test_speed_monitor(flops_per_batch: bool):
_assert_no_negative_values(in_memory_logger.data['throughput/samples_per_sec'])
_assert_no_negative_values(in_memory_logger.data['throughput/device/batches_per_sec'])
_assert_no_negative_values(in_memory_logger.data['throughput/device/samples_per_sec'])
assert 'throughput/tokens_per_sec' not in in_memory_logger.data
assert 'throughput/device/tokens_per_sec' not in in_memory_logger.data
if flops_per_batch:
_assert_no_negative_values(in_memory_logger.data['throughput/flops_per_sec'])
_assert_no_negative_values(in_memory_logger.data['throughput/device/flops_per_sec'])
@@ -73,3 +76,47 @@ def test_speed_monitor(flops_per_batch: bool):
assert len(in_memory_logger.data['time/total']) == num_batches
assert len(in_memory_logger.data['time/train']) == num_batches
assert len(in_memory_logger.data['time/val']) == num_batches


def test_speed_monitor_tokens():
model = SimpleTransformerClassifier()
dataloader = dummy_text_classification_dataloader()
dataloader.dataset.max_seq_len = dataloader.dataset.sequence_length # type: ignore
in_memory_logger = InMemoryLogger() # track the logged metrics in the in_memory_logger
speed_monitor = SpeedMonitor(window_size=1)
trainer = Trainer(
model=model,
train_dataloader=dataloader,
callbacks=speed_monitor,
loggers=in_memory_logger,
max_duration='1ep',
)
trainer.fit()

print(list(in_memory_logger.data.keys()))

_assert_no_negative_values(in_memory_logger.data['time/train'])
_assert_no_negative_values(in_memory_logger.data['time/val'])
_assert_no_negative_values(in_memory_logger.data['time/total'])
_assert_no_negative_values(in_memory_logger.data['throughput/batches_per_sec'])
_assert_no_negative_values(in_memory_logger.data['throughput/samples_per_sec'])
_assert_no_negative_values(in_memory_logger.data['throughput/tokens_per_sec'])
_assert_no_negative_values(in_memory_logger.data['throughput/device/batches_per_sec'])
_assert_no_negative_values(in_memory_logger.data['throughput/device/samples_per_sec'])
_assert_no_negative_values(in_memory_logger.data['throughput/device/tokens_per_sec'])

assert isinstance(trainer.state.dataloader, collections.abc.Sized)
assert trainer.state.dataloader_label is not None
assert trainer.state.dataloader_len is not None
expected_step_calls = (trainer.state.dataloader_len - len(speed_monitor.history_samples) + 1) * int(
trainer.state.timestamp.epoch)
assert len(in_memory_logger.data['throughput/batches_per_sec']) == expected_step_calls
assert len(in_memory_logger.data['throughput/samples_per_sec']) == expected_step_calls
assert len(in_memory_logger.data['throughput/tokens_per_sec']) == expected_step_calls
assert len(in_memory_logger.data['throughput/device/batches_per_sec']) == expected_step_calls
assert len(in_memory_logger.data['throughput/device/samples_per_sec']) == expected_step_calls
assert len(in_memory_logger.data['throughput/device/tokens_per_sec']) == expected_step_calls
num_batches = int(trainer.state.timestamp.batch)
assert len(in_memory_logger.data['time/total']) == num_batches
assert len(in_memory_logger.data['time/train']) == num_batches
assert len(in_memory_logger.data['time/val']) == num_batches
tests/datasets/test_in_context_learning_datasets.py (14 changes: 7 additions & 7 deletions)
@@ -1123,7 +1123,7 @@ def test_mc_task_evaluation(device, num_fewshot, dataset_uri, tiny_gpt2_tokenize
@device('gpu')
@world_size(1, 2)
@pytest.mark.parametrize('num_fewshot', [0, 5])
- @pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
+ @pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_qa_task_evaluation_opt_tokenizer(device, world_size, tiny_opt_tokenizer, tiny_opt_model, num_fewshot,
dataset_uri, tmp_path):
pytest.importorskip('datasets')
@@ -1167,7 +1167,7 @@ def test_qa_task_evaluation_opt_tokenizer(device, world_size, tiny_opt_tokenizer
@device('gpu')
@world_size(1, 2)
@pytest.mark.parametrize('num_fewshot', [5])
- @pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
+ @pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_qa_task_evaluation_with_cot_opt_tokenizer(device, world_size, tiny_opt_tokenizer, tiny_opt_model, num_fewshot,
dataset_uri, tmp_path):
pytest.importorskip('datasets')
@@ -1212,7 +1212,7 @@ def test_qa_task_evaluation_with_cot_opt_tokenizer(device, world_size, tiny_opt_
@device('gpu')
@world_size(1, 2)
@pytest.mark.parametrize('num_fewshot', [0, 5])
- @pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
+ @pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_qa_task_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_gpt2_tokenizer, tiny_gpt2_model,
tmp_path):
pytest.importorskip('datasets')
@@ -1256,7 +1256,7 @@ def test_qa_task_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_g
@device('gpu')
@world_size(1, 2)
@pytest.mark.parametrize('num_fewshot', [5])
- @pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
+ @pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_qa_task_with_cot_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_gpt2_tokenizer, tiny_gpt2_model,
tmp_path):
pytest.importorskip('datasets')
@@ -1314,7 +1314,7 @@ def test_code_eval_requires_valid_envvar(monkeypatch):
@world_size(1, 2)
@pytest.mark.parametrize('num_fewshot', [0])
@pytest.mark.parametrize('generations_per_sample', [1, 2])
- @pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
+ @pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_code_eval_microbatching(monkeypatch, device, world_size, tiny_opt_tokenizer, tiny_opt_model, num_fewshot,
dataset_uri, tmp_path, generations_per_sample):
pytest.importorskip('datasets')
@@ -1365,7 +1365,7 @@ def test_code_eval_microbatching(monkeypatch, device, world_size, tiny_opt_token
@world_size(1, 2)
@pytest.mark.parametrize('num_fewshot', [0])
@pytest.mark.parametrize('generations_per_sample', [1, 2])
- @pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
+ @pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_code_eval_sentpiece_evaluation(monkeypatch, device, world_size, num_fewshot, dataset_uri, tiny_t5_tokenizer,
tiny_t5_model, tmp_path, generations_per_sample):
pytest.importorskip('datasets')
@@ -1413,7 +1413,7 @@ def test_code_eval_sentpiece_evaluation(monkeypatch, device, world_size, num_few
@pytest.mark.parametrize('num_fewshot', [0, 2])
@pytest.mark.parametrize('generations_per_sample', [1])
@pytest.mark.filterwarnings(r'ignore: Input length of input_ids is')
- @pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
+ @pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_code_eval_task_evaluation(monkeypatch, device, world_size, num_fewshot, dataset_uri, tiny_gpt2_tokenizer,
tiny_gpt2_model, tmp_path, generations_per_sample):
pytest.importorskip('datasets')
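
The only change in this file is a leading .* on each ignored-warning pattern. Warning-message patterns are matched from the start of the message text, so the extra .* lets the filter keep matching if the rendered message carries a prefix before 'The dataloader_len ...'. A small illustration of that anchoring with the re module (the example message is invented):

import re

message = 'Some prefix: The dataloader_len (2) is greater than the length of the dataset.'
old_pattern = r'The dataloader_len \(2\) is greater than the length.*'
new_pattern = r'.*The dataloader_len \(2\) is greater than the length.*'

print(bool(re.match(old_pattern, message)))  # False: the match is anchored at the start
print(bool(re.match(new_pattern, message)))  # True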