Skip to content

Commit

Permalink
chore: refactor training loop (deepmodeling#4435)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
  - Enhanced training loop to support multi-task training, allowing for
    more flexible model selection.

- **Improvements**
  - Streamlined `step` function to accept only the step ID, simplifying
    its usage.
  - Adjusted logging and model saving mechanisms for consistency with the
    new training flow.
  - Improved random seed management for enhanced reproducibility in data
    processing.
  - Enhanced error handling in data retrieval to ensure seamless operation
    during data loading.
  - Added type hints for better clarity in data loader attributes.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Chun Cai <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 03c6e49)
  • Loading branch information
caic99 authored and njzjz committed Dec 22, 2024
1 parent 69a1628 commit 09bdfc5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 18 deletions.
26 changes: 8 additions & 18 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,12 @@ def run(self) -> None:
prof.start()

def step(_step_id, task_key="Default") -> None:
if self.multi_task:
model_index = dp_random.choice(
np.arange(self.num_model, dtype=np.int_),
p=self.model_prob,
)
task_key = self.model_keys[model_index]
# PyTorch Profiler
if self.enable_profiler or self.profiling:
prof.step()
Expand Down Expand Up @@ -929,24 +935,8 @@ def log_loss_valid(_task_key="Default"):
self.wrapper.train()
self.t0 = time.time()
self.total_train_time = 0.0
for step_id in range(self.num_steps):
if step_id < self.start_step:
continue
if self.multi_task:
chosen_index_list = dp_random.choice(
np.arange(
self.num_model, dtype=np.int32
), # int32 should be enough for # models...
p=np.array(self.model_prob),
size=self.world_size,
replace=True,
)
assert chosen_index_list.size == self.world_size
model_index = chosen_index_list[self.rank]
model_key = self.model_keys[model_index]
else:
model_key = "Default"
step(step_id, model_key)
for step_id in range(self.start_step, self.num_steps):
step(step_id)
if JIT:
break

Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)

from deepmd.pt.utils import (
dp_random,
env,
)
from deepmd.pt.utils.dataset import (
Expand All @@ -50,6 +51,7 @@ def setup_seed(seed) -> None:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
dp_random.seed(seed)


class DpLoaderSet(Dataset):
Expand Down

0 comments on commit 09bdfc5

Please sign in to comment.