From cf23f7a84b0a51b6eac949acd447d02f8edc7b76 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Tue, 5 Dec 2023 19:32:21 -0500
Subject: [PATCH 1/4] change token math

---
 composer/callbacks/speed_monitor.py | 24 +++++++++++-------------
 1 file changed, 11 insertions(+), 13 deletions(-)

diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py
index f66d62c31b..b0b4a596ac 100644
--- a/composer/callbacks/speed_monitor.py
+++ b/composer/callbacks/speed_monitor.py
@@ -174,8 +174,7 @@ class SpeedMonitor(Callback):
     +-------------------------------------+-----------------------------------------------------------+
     |                                     | Rolling average (over `window_size` most recent           |
     | `throughput/tokens_per_sec`         | batches) of the number of tokens processed per second.    |
-    |                                     | Only logged when dataloader.dataset has `max_seq_len`.    |
-    |                                     | This may include padding depending on dataset             |
+    |                                     | Only logged if dataspec returns tokens per batch          |
     +-------------------------------------+-----------------------------------------------------------+
     |                                     | Estimates flops by `flops_per_batch * batches_per_sec`    |
     | `throughput/flops_per_sec`          | if model has attribute `flops_per_batch`                  |
     +-------------------------------------+-----------------------------------------------------------+
@@ -186,8 +185,8 @@ class SpeedMonitor(Callback):
     | `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size        |
     +-------------------------------------+-----------------------------------------------------------+
     |                                     | `throughput/tokens_per_sec` divided by world size. Only   |
-    | `throughput/device/tokens_per_sec`  | logged when dataloader.dataset has `max_seq_len`. This    |
-    |                                     | may include pad tokens depending on dataset               |
+    | `throughput/device/tokens_per_sec`  | logged if dataspec returns tokens per batch               |
+    |                                     |                                                           |
     +-------------------------------------+-----------------------------------------------------------+
     |                                     | `throughput/flops_per_sec` divided by world size. Only    |
     | `throughput/device/flops_per_sec`   | logged when model has attribute `flops_per_batch`         |
     +-------------------------------------+-----------------------------------------------------------+
@@ -222,6 +221,7 @@ def __init__(
     ):
         # Track the batch num samples and wct to compute throughput over a window of batches
         self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
+        self.history_tokens: Deque[int] = deque(maxlen=window_size + 1)
         self.history_wct: Deque[float] = deque(maxlen=window_size + 1)
         self.history_flops: Deque[float] = deque(maxlen=window_size + 1)

@@ -259,6 +259,7 @@ def init(self, state: State, logger: Logger) -> None:
     def batch_end(self, state: State, logger: Logger):
         # Add the new element
         self.history_samples.append(state.timestamp.sample.value)
+        self.history_tokens.append(state.timestamp.sample.tokens)
         self.history_wct.append(state.timestamp.total_wct.total_seconds())

         # Log the throughput
@@ -266,6 +267,7 @@ def batch_end(self, state: State, logger: Logger):
             world_size = dist.get_world_size()
             elapsed_batches = len(self.history_samples) - 1
             elapsed_samples = int(self.history_samples[-1]) - int(self.history_samples[0])
+            elapsed_tokens = int(self.history_tokens[-1]) - int(self.history_tokens[0])
             elapsed_wct = self.history_wct[-1] - self.history_wct[0]
             batches_per_sec = elapsed_batches / elapsed_wct
             samples_per_sec = elapsed_samples / elapsed_wct
@@ -277,17 +279,13 @@ def batch_end(self, state: State, logger: Logger):
                 'throughput/device/batches_per_sec': dev_batches_per_sec,
                 'throughput/device/samples_per_sec': dev_samples_per_sec,
             })
-
-            # Compute token stats if dataloader.dataset has max_seq_len. Assumes no padding.
-            try:
-                max_seq_len = state.dataloader.dataset.max_seq_len  # type: ignore
-                # Only applicable to seq data / models
+            if elapsed_tokens > 0:
+                tokens_per_sec = elapsed_tokens / elapsed_wct
+                dev_tokens_per_sec = tokens_per_sec / world_size
                 logger.log_metrics({
-                    'throughput/tokens_per_sec': samples_per_sec * max_seq_len,
-                    'throughput/device/tokens_per_sec': dev_samples_per_sec * max_seq_len,
+                    'throughput/tokens_per_sec': tokens_per_sec,
+                    'throughput/device/tokens_per_sec': dev_tokens_per_sec,
                 })
-            except AttributeError:
-                pass

             # Compute flops stats if model has flops_per_batch
             composer_model = state.model

From b1b4ad99f0b64d7d808d03f10571e2cae5ce8a0f Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Tue, 5 Dec 2023 19:44:24 -0500
Subject: [PATCH 2/4] tokens

---
 composer/callbacks/speed_monitor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py
index b0b4a596ac..2b0eeedc80 100644
--- a/composer/callbacks/speed_monitor.py
+++ b/composer/callbacks/speed_monitor.py
@@ -259,7 +259,7 @@ def init(self, state: State, logger: Logger) -> None:
     def batch_end(self, state: State, logger: Logger):
         # Add the new element
         self.history_samples.append(state.timestamp.sample.value)
-        self.history_tokens.append(state.timestamp.sample.tokens)
+        self.history_tokens.append(state.timestamp.token.value)
         self.history_wct.append(state.timestamp.total_wct.total_seconds())

         # Log the throughput

From d8faaeb5931ae6d10dd3dd0bae98171115c6dc61 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Tue, 5 Dec 2023 20:14:54 -0500
Subject: [PATCH 3/4] add test

---
 tests/callbacks/test_speed_monitor.py | 49 ++++++++++++++++++++++++++-
 1 file changed, 48 insertions(+), 1 deletion(-)

diff --git a/tests/callbacks/test_speed_monitor.py b/tests/callbacks/test_speed_monitor.py
index 037e5a2598..f880a7c370 100644
--- a/tests/callbacks/test_speed_monitor.py
+++ b/tests/callbacks/test_speed_monitor.py
@@ -11,7 +11,8 @@
 from composer.core import Time
 from composer.loggers import InMemoryLogger
 from composer.trainer import Trainer
-from tests.common import RandomClassificationDataset, SimpleModel
+from tests.common import RandomClassificationDataset, SimpleModel, SimpleTransformerClassifier
+from tests.common.datasets import dummy_text_classification_dataloader


 def _assert_no_negative_values(logged_values):
@@ -53,6 +54,8 @@ def test_speed_monitor(flops_per_batch: bool):
     _assert_no_negative_values(in_memory_logger.data['throughput/samples_per_sec'])
     _assert_no_negative_values(in_memory_logger.data['throughput/device/batches_per_sec'])
     _assert_no_negative_values(in_memory_logger.data['throughput/device/samples_per_sec'])
+    assert 'throughput/tokens_per_sec' not in in_memory_logger.data
+    assert 'throughput/device/tokens_per_sec' not in in_memory_logger.data
     if flops_per_batch:
         _assert_no_negative_values(in_memory_logger.data['throughput/flops_per_sec'])
         _assert_no_negative_values(in_memory_logger.data['throughput/device/flops_per_sec'])
@@ -73,3 +76,47 @@ def test_speed_monitor(flops_per_batch: bool):
     assert len(in_memory_logger.data['time/total']) == num_batches
     assert len(in_memory_logger.data['time/train']) == num_batches
     assert len(in_memory_logger.data['time/val']) == num_batches
+
+
+def test_speed_monitor_tokens():
+    model = SimpleTransformerClassifier()
+    dataloader = dummy_text_classification_dataloader()
+    dataloader.dataset.max_seq_len = dataloader.dataset.sequence_length  # type: ignore
+    in_memory_logger = InMemoryLogger()  # track the logged metrics in the in_memory_logger
+    speed_monitor = SpeedMonitor(window_size=1)
+    trainer = Trainer(
+        model=model,
+        train_dataloader=dataloader,
+        callbacks=speed_monitor,
+        loggers=in_memory_logger,
+        max_duration='1ep',
+    )
+    trainer.fit()
+
+    print(list(in_memory_logger.data.keys()))
+
+    _assert_no_negative_values(in_memory_logger.data['time/train'])
+    _assert_no_negative_values(in_memory_logger.data['time/val'])
+    _assert_no_negative_values(in_memory_logger.data['time/total'])
+    _assert_no_negative_values(in_memory_logger.data['throughput/batches_per_sec'])
+    _assert_no_negative_values(in_memory_logger.data['throughput/samples_per_sec'])
+    _assert_no_negative_values(in_memory_logger.data['throughput/tokens_per_sec'])
+    _assert_no_negative_values(in_memory_logger.data['throughput/device/batches_per_sec'])
+    _assert_no_negative_values(in_memory_logger.data['throughput/device/samples_per_sec'])
+    _assert_no_negative_values(in_memory_logger.data['throughput/device/tokens_per_sec'])
+
+    assert isinstance(trainer.state.dataloader, collections.abc.Sized)
+    assert trainer.state.dataloader_label is not None
+    assert trainer.state.dataloader_len is not None
+    expected_step_calls = (trainer.state.dataloader_len - len(speed_monitor.history_samples) + 1) * int(
+        trainer.state.timestamp.epoch)
+    assert len(in_memory_logger.data['throughput/batches_per_sec']) == expected_step_calls
+    assert len(in_memory_logger.data['throughput/samples_per_sec']) == expected_step_calls
+    assert len(in_memory_logger.data['throughput/tokens_per_sec']) == expected_step_calls
+    assert len(in_memory_logger.data['throughput/device/batches_per_sec']) == expected_step_calls
+    assert len(in_memory_logger.data['throughput/device/samples_per_sec']) == expected_step_calls
+    assert len(in_memory_logger.data['throughput/device/tokens_per_sec']) == expected_step_calls
+    num_batches = int(trainer.state.timestamp.batch)
+    assert len(in_memory_logger.data['time/total']) == num_batches
+    assert len(in_memory_logger.data['time/train']) == num_batches
+    assert len(in_memory_logger.data['time/val']) == num_batches

From 7c382f1655be6c7b5e4d29e0c8eb454a114c63f9 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Thu, 7 Dec 2023 16:53:19 -0500
Subject: [PATCH 4/4] fix tests

---
 .../datasets/test_in_context_learning_datasets.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/tests/datasets/test_in_context_learning_datasets.py b/tests/datasets/test_in_context_learning_datasets.py
index 804bdb67e0..807a0d84e6 100644
--- a/tests/datasets/test_in_context_learning_datasets.py
+++ b/tests/datasets/test_in_context_learning_datasets.py
@@ -1123,7 +1123,7 @@ def test_mc_task_evaluation(device, num_fewshot, dataset_uri, tiny_gpt2_tokenize
 @device('gpu')
 @world_size(1, 2)
 @pytest.mark.parametrize('num_fewshot', [0, 5])
-@pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
+@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
 def test_qa_task_evaluation_opt_tokenizer(device, world_size, tiny_opt_tokenizer, tiny_opt_model, num_fewshot,
                                           dataset_uri, tmp_path):
     pytest.importorskip('datasets')
@@ -1167,7 +1167,7 @@ def test_qa_task_evaluation_opt_tokenizer(device, world_size, tiny_opt_tokenizer
 @device('gpu')
 @world_size(1, 2)
 @pytest.mark.parametrize('num_fewshot', [5])
-@pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
+@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
 def test_qa_task_evaluation_with_cot_opt_tokenizer(device, world_size, tiny_opt_tokenizer, tiny_opt_model, num_fewshot,
                                                    dataset_uri, tmp_path):
     pytest.importorskip('datasets')
@@ -1212,7 +1212,7 @@ def test_qa_task_evaluation_with_cot_opt_tokenizer(device, world_size, tiny_opt_
 @device('gpu')
 @world_size(1, 2)
 @pytest.mark.parametrize('num_fewshot', [0, 5])
-@pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
+@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
 def test_qa_task_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_gpt2_tokenizer, tiny_gpt2_model,
                             tmp_path):
     pytest.importorskip('datasets')
@@ -1256,7 +1256,7 @@ def test_qa_task_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_g
 @device('gpu')
 @world_size(1, 2)
 @pytest.mark.parametrize('num_fewshot', [5])
-@pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
+@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
 def test_qa_task_with_cot_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_gpt2_tokenizer,
                                      tiny_gpt2_model, tmp_path):
     pytest.importorskip('datasets')
@@ -1314,7 +1314,7 @@ def test_code_eval_requires_valid_envvar(monkeypatch):
 @world_size(1, 2)
 @pytest.mark.parametrize('num_fewshot', [0])
 @pytest.mark.parametrize('generations_per_sample', [1, 2])
-@pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
+@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
 def test_code_eval_microbatching(monkeypatch, device, world_size, tiny_opt_tokenizer, tiny_opt_model, num_fewshot,
                                  dataset_uri, tmp_path, generations_per_sample):
     pytest.importorskip('datasets')
@@ -1365,7 +1365,7 @@ def test_code_eval_microbatching(monkeypatch, device, world_size, tiny_opt_token
 @world_size(1, 2)
 @pytest.mark.parametrize('num_fewshot', [0])
 @pytest.mark.parametrize('generations_per_sample', [1, 2])
-@pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
+@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
 def test_code_eval_sentpiece_evaluation(monkeypatch, device, world_size, num_fewshot, dataset_uri, tiny_t5_tokenizer,
                                         tiny_t5_model, tmp_path, generations_per_sample):
     pytest.importorskip('datasets')
@@ -1413,7 +1413,7 @@ def test_code_eval_sentpiece_evaluation(monkeypatch, device, world_size, num_few
 @pytest.mark.parametrize('num_fewshot', [0, 2])
 @pytest.mark.parametrize('generations_per_sample', [1])
 @pytest.mark.filterwarnings(r'ignore: Input length of input_ids is')
-@pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
+@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
 def test_code_eval_task_evaluation(monkeypatch, device, world_size, num_fewshot, dataset_uri, tiny_gpt2_tokenizer,
                                    tiny_gpt2_model, tmp_path, generations_per_sample):
     pytest.importorskip('datasets')
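
Usage note: with this series applied, SpeedMonitor derives token throughput from
state.timestamp.token rather than from dataloader.dataset.max_seq_len, so the
tokens_per_sec metrics are emitted only when the train dataloader reports a
per-batch token count (elapsed_tokens > 0). Below is a minimal sketch of how a
training script could opt in, assuming Composer's DataSpec accepts a
get_num_tokens_in_batch callable; the train_dataloader, model, and 'input_ids'
names are illustrative placeholders, not part of this patch:

    from composer import Trainer
    from composer.callbacks import SpeedMonitor
    from composer.core import DataSpec

    def count_tokens(batch):
        # Placeholder batch structure: count every element of 'input_ids'.
        # This includes pad tokens; subtract them here (e.g. from an
        # attention mask) if the metric should reflect real tokens only.
        return batch['input_ids'].numel()

    train_dataspec = DataSpec(
        dataloader=train_dataloader,  # placeholder: any tokenized dataloader
        get_num_tokens_in_batch=count_tokens,
    )

    trainer = Trainer(
        model=model,  # placeholder: a ComposerModel
        train_dataloader=train_dataspec,
        callbacks=[SpeedMonitor(window_size=100)],
        max_duration='1ep',
    )
    trainer.fit()  # throughput/tokens_per_sec appears once the window fills

Counting tokens from the batch itself, instead of multiplying samples by
max_seq_len, is what lets the metric exclude padding when the callable does.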