diff --git a/docs/source-pytorch/advanced/transfer_learning.rst b/docs/source-pytorch/advanced/transfer_learning.rst index 935bcd4bb8773..7f6af6ad5a56d 100644 --- a/docs/source-pytorch/advanced/transfer_learning.rst +++ b/docs/source-pytorch/advanced/transfer_learning.rst @@ -116,6 +116,7 @@ Here's a model that uses `Huggingface transformers Union[DeepSpeedSummary, Summary]: from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy @@ -83,12 +91,14 @@ def summarize( total_parameters: int, trainable_parameters: int, model_size: float, + total_training_modes: Dict[str, int], **summarize_kwargs: Any, ) -> None: summary_table = _format_summary_table( total_parameters, trainable_parameters, model_size, + total_training_modes, *summary_data, ) log.info("\n" + summary_table) diff --git a/src/lightning/pytorch/callbacks/rich_model_summary.py b/src/lightning/pytorch/callbacks/rich_model_summary.py index f551a9c397531..c6c429b4bd2f5 100644 --- a/src/lightning/pytorch/callbacks/rich_model_summary.py +++ b/src/lightning/pytorch/callbacks/rich_model_summary.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Tuple +from typing import Any, Dict, List, Tuple from typing_extensions import override @@ -71,6 +71,7 @@ def summarize( total_parameters: int, trainable_parameters: int, model_size: float, + total_training_modes: Dict[str, int], **summarize_kwargs: Any, ) -> None: from rich import get_console @@ -110,5 +111,7 @@ def summarize( grid.add_row(f"[bold]Non-trainable params[/]: {parameters[1]}") grid.add_row(f"[bold]Total params[/]: {parameters[2]}") grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}") + grid.add_row(f"[bold]Modules in train mode[/]: {total_training_modes['train']}") + grid.add_row(f"[bold]Modules in eval mode[/]: {total_training_modes['eval']}") console.print(grid) diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary.py b/src/lightning/pytorch/utilities/model_summary/model_summary.py index 806724e1c434a..0f48bee191c7b 100644 --- a/src/lightning/pytorch/utilities/model_summary/model_summary.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary.py @@ -187,6 +187,8 @@ class ModelSummary: 0 Non-trainable params 132 K Total params 0.530 Total estimated model params size (MB) + 3 Modules in train mode + 0 Modules in eval mode >>> ModelSummary(model, max_depth=-1) # doctest: +NORMALIZE_WHITESPACE | Name | Type | Params | Mode | In sizes | Out sizes ---------------------------------------------------------------------- @@ -198,6 +200,8 @@ class ModelSummary: 0 Non-trainable params 132 K Total params 0.530 Total estimated model params size (MB) + 3 Modules in train mode + 0 Modules in eval mode """ @@ -252,6 +256,12 @@ def param_nums(self) -> List[int]: def training_modes(self) -> List[bool]: return [layer.training for layer in self._layer_summary.values()] + @property + def total_training_modes(self) -> Dict[str, int]: + modes = [layer.training for layer in self._model.modules()] + modes = modes[1:] # exclude the root module + return {"train": modes.count(True), "eval": modes.count(False)} + @property def total_parameters(self) -> int: return sum(p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters()) @@ -351,8 +361,9 @@ def __str__(self) -> str: total_parameters = self.total_parameters trainable_parameters = self.trainable_parameters model_size = self.model_size + total_training_modes = self.total_training_modes - return _format_summary_table(total_parameters, trainable_parameters, model_size, *arrays) + return _format_summary_table(total_parameters, trainable_parameters, model_size, total_training_modes, *arrays) def __repr__(self) -> str: return str(self) @@ -372,6 +383,7 @@ def _format_summary_table( total_parameters: int, trainable_parameters: int, model_size: float, + total_training_modes: Dict[str, int], *cols: Tuple[str, List[str]], ) -> str: """Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big @@ -408,6 +420,10 @@ def _format_summary_table( summary += "Total params" summary += "\n" + s.format(get_formatted_model_size(model_size), 10) summary += "Total estimated model params size (MB)" + summary += "\n" + s.format(total_training_modes["train"], 10) + summary += "Modules in train mode" + summary += "\n" + s.format(total_training_modes["eval"], 10) + summary += "Modules in eval mode" return summary diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index c065cdcdbf022..633c1dc0853e0 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -58,7 +58,7 @@ def on_train_epoch_end(self, trainer, pl_module): self.saved_states.append(self.state_dict().copy()) -@RunIf(sklearn=True) +@RunIf(sklearn=True, skip_windows=True) # Flaky test on Windows for unknown reasons @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_resume_early_stopping_from_checkpoint(tmp_path): """Prevent regressions to bugs: diff --git a/tests/tests_pytorch/callbacks/test_model_summary.py b/tests/tests_pytorch/callbacks/test_model_summary.py index 0f255367f1a10..b42907dc9a38d 100644 --- a/tests/tests_pytorch/callbacks/test_model_summary.py +++ b/tests/tests_pytorch/callbacks/test_model_summary.py @@ -49,6 +49,7 @@ def summarize( total_parameters: int, trainable_parameters: int, model_size: float, + total_training_modes, **summarize_kwargs: Any, ) -> None: assert summary_data[1][0] == "Name" @@ -64,6 +65,8 @@ def summarize( assert summary_data[4][0] == "Mode" assert summary_data[4][1][0] == "train" + assert total_training_modes == {"train": 1, "eval": 0} + model = BoringModel() trainer = Trainer(default_root_dir=tmp_path, callbacks=CustomModelSummary(), max_steps=1) diff --git a/tests/tests_pytorch/callbacks/test_rich_model_summary.py b/tests/tests_pytorch/callbacks/test_rich_model_summary.py index f8ede0eb0239e..73709fd80a833 100644 --- a/tests/tests_pytorch/callbacks/test_rich_model_summary.py +++ b/tests/tests_pytorch/callbacks/test_rich_model_summary.py @@ -56,7 +56,13 @@ def example_input_array(self) -> Any: summary = summarize(model) summary_data = summary._get_summary_data() - model_summary.summarize(summary_data=summary_data, total_parameters=1, trainable_parameters=1, model_size=1) + model_summary.summarize( + summary_data=summary_data, + total_parameters=1, + trainable_parameters=1, + model_size=1, + total_training_modes=summary.total_training_modes, + ) # ensure that summary was logged + the breakdown of model parameters assert mock_console.call_count == 2 diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index 14692b359dad3..65fccb691a33d 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -218,7 +218,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: assert dm.my_state_dict == {"my": "state_dict"} -@RunIf(sklearn=True) +@RunIf(sklearn=True, skip_windows=True) # Flaky test on Windows for unknown reasons def test_full_loop(tmp_path): seed_everything(7) diff --git a/tests/tests_pytorch/utilities/test_model_summary.py b/tests/tests_pytorch/utilities/test_model_summary.py index a50ec425fc894..00fdf77d4cdfd 100644 --- a/tests/tests_pytorch/utilities/test_model_summary.py +++ b/tests/tests_pytorch/utilities/test_model_summary.py @@ -423,6 +423,29 @@ def forward(self, x): assert not model.layer2.training +def test_total_training_modes(): + """Test that the `total_training_modes` counts the modules in 'train' and 'eval' mode, excluding the root + module.""" + + class ModelWithoutChildren(LightningModule): + pass + + summary = ModelSummary(ModelWithoutChildren()) + assert summary.total_training_modes == {"train": 0, "eval": 0} + + model = DeepNestedModel() + summary = ModelSummary(model) + assert summary.total_training_modes == {"train": 19, "eval": 0} + assert sum(summary.total_training_modes.values()) == len(list(model.modules())) - 1 + + model = DeepNestedModel() + summary = ModelSummary(model) + model.branch1[1][0].eval() + model.branch2.eval() + assert summary.total_training_modes == {"train": 17, "eval": 2} + assert sum(summary.total_training_modes.values()) == len(list(model.modules())) - 1 + + def test_summary_training_mode(): """Test that the model summary captures the training mode on all submodules.""" model = DeepNestedModel() @@ -436,6 +459,7 @@ def test_summary_training_mode(): "eval", # branch2 "train", # head ] + assert summary.total_training_modes == {"train": 17, "eval": 2} summary = summarize(model, max_depth=-1) expected_eval = {"branch1.1.0", "branch2"} @@ -445,5 +469,7 @@ def test_summary_training_mode(): # A model with params not belonging to a layer model = NonLayerParamsModel() model.layer.eval() - summary_data = OrderedDict(summarize(model)._get_summary_data()) + summary = summarize(model) + summary_data = OrderedDict(summary._get_summary_data()) assert summary_data["Mode"] == ["eval", "n/a"] + assert summary.total_training_modes == {"train": 0, "eval": 1}