From dcb18415b1e45b498ac7a4bf9bda9a3b2b0c4d48 Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Wed, 27 Nov 2024 13:21:26 +0800 Subject: [PATCH] perf: print summary on rank 0 --- deepmd/pt/utils/dataloader.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index 9920622792..2fea6b72d2 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -185,19 +185,21 @@ def print_summary( name: str, prob: list[float], ) -> None: - print_summary( - name, - len(self.systems), - [ss.system for ss in self.systems], - [ss._natoms for ss in self.systems], - self.batch_sizes, - [ - ss._data_system.get_sys_numb_batch(self.batch_sizes[ii]) - for ii, ss in enumerate(self.systems) - ], - prob, - [ss._data_system.pbc for ss in self.systems], - ) + rank = dist.get_rank() if dist.is_initialized() else 0 + if rank == 0: + print_summary( + name, + len(self.systems), + [ss.system for ss in self.systems], + [ss._natoms for ss in self.systems], + self.batch_sizes, + [ + ss._data_system.get_sys_numb_batch(self.batch_sizes[ii]) + for ii, ss in enumerate(self.systems) + ], + prob, + [ss._data_system.pbc for ss in self.systems], + ) _sentinel = object()