-
Notifications
You must be signed in to change notification settings - Fork 526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
chore: refactor training loop #4435
Conversation
📝 Walkthrough📝 WalkthroughWalkthroughThe changes in the pull request focus on enhancing the Changes
Sequence Diagram(s)sequenceDiagram
participant Trainer
participant Model
participant Logger
participant Saver
Trainer->>Trainer: Start training loop
Trainer->>Model: Select model key based on probability distribution
Trainer->>Model: Call step with _step_id
Model-->>Trainer: Perform training step
Trainer->>Logger: Log training status
Trainer->>Saver: Save model state
Trainer->>Trainer: Repeat for next step
Possibly related PRs
Suggested reviewers
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
deepmd/pt/train/training.py
(2 hunks)
🔇 Additional comments (1)
deepmd/pt/train/training.py (1)
938-939
: Simplified training loop is appropriate
The changes to the training loop enhance readability by directly iterating over the step range and removing unnecessary checks. This simplification aligns with best practices and improves code maintainability.
Signed-off-by: Chun Cai <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (1)
deepmd/pt/train/training.py (1)
938-939
: Consider adding progress tracking for long training runs.While the simplified training loop is clean and correct, consider adding a progress tracking mechanism (e.g., tqdm) for better visibility during long training runs.
+from tqdm import tqdm - for step_id in range(self.start_step, self.num_steps): + for step_id in tqdm(range(self.start_step, self.num_steps), + desc="Training", + disable=self.rank != 0): step(step_id)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
deepmd/pt/train/training.py
(2 hunks)
🔇 Additional comments (1)
deepmd/pt/train/training.py (1)
656-661
: LGTM: Multi-task model selection is well implemented.
The implementation:
- Uses
dp_random
for reproducible randomization - Correctly specifies integer dtype for the array
- Properly uses pre-computed probability distribution
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## devel #4435 +/- ##
==========================================
- Coverage 84.64% 83.72% -0.93%
==========================================
Files 614 667 +53
Lines 57138 61441 +4303
Branches 3487 3486 -1
==========================================
+ Hits 48367 51442 +3075
- Misses 7646 8875 +1229
+ Partials 1125 1124 -1 ☔ View full report in Codecov by Sentry. |
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (1)
deepmd/pt/utils/dataloader.py (1)
53-53
: LGTM! Consider adding CUDNN benchmark disable for complete determinism.The addition of
np.random.seed(seed)
is a good improvement as it ensures reproducibility across both PyTorch and NumPy random operations, which is particularly important for multi-task training.For complete determinism, consider also adding:
def setup_seed(seed) -> None: torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False np.random.seed(seed)
This ensures CUDNN doesn't auto-tune algorithms, maintaining complete reproducibility at the cost of some performance.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (1)
deepmd/pt/utils/dataloader.py (1)
Line range hint
82-83
: Consider adding type hints to other attributes for consistencyGood addition of type hints for
systems
andsampler_list
. For consistency, consider adding type hints to other attributes in the class such as:
dataloaders
batch_sizes
index
iters
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
deepmd/pt/utils/dataloader.py
(2 hunks)
🔇 Additional comments (2)
deepmd/pt/utils/dataloader.py (2)
31-31
: LGTM: Import addition is appropriate
The addition of dp_random
import is well-placed and necessary for the seed initialization enhancement.
54-54
: Verify seed propagation in distributed training
The addition of dp_random.seed(seed)
improves reproducibility by ensuring consistent randomization across all components. However, let's verify that this seed is properly propagated in distributed training scenarios.
✅ Verification successful
Let me analyze the distributed training setup further to make a conclusive verification.
Seed setup is properly handled in distributed training
The seed initialization is correctly implemented for distributed training scenarios because:
- The seed is set before distributed training initialization (
dist.init_process_group
) - Each rank receives the same global seed value from the config via
data_seed = config["training"].get("seed", None)
- The
dp_random.seed()
call is made before any rank-specific operations, ensuring consistent randomization across all processes
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check if the seed setup is used in distributed training contexts
# Look for distributed training setup code that might need seed synchronization
# Search for distributed training initialization code
echo "Checking distributed training initialization patterns:"
rg -l "dist\.init_process_group|DistributedDataParallel"
# Search for other seed initialization points that might need dp_random
echo "Checking other seed initialization points:"
rg "seed\(" --type py
Length of output: 14026
Script:
#!/bin/bash
# Check how seed is handled in distributed training code
echo "Checking seed handling in distributed training files:"
rg -A 5 "setup_seed|dp_random\.seed" deepmd/pt/train/training.py deepmd/pt/entrypoints/main.py
# Check if there's any specific distributed seed handling
echo "Checking distributed seed synchronization:"
rg "setup_seed.*rank|dist.*seed" --type py
Length of output: 353
Script:
#!/bin/bash
# Let's try a broader search to understand the training initialization
echo "Checking training initialization and seed setup:"
rg -B 3 -A 7 "def train|def main" deepmd/pt/train/training.py deepmd/pt/entrypoints/main.py
# Also check for any rank-specific initialization
echo "Checking rank-specific initialization:"
rg -B 2 -A 5 "local_rank|world_rank|dist\.get_rank" deepmd/pt/train/training.py deepmd/pt/entrypoints/main.py
Length of output: 5477
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced training loop to support multi-task training, allowing for more flexible model selection. - **Improvements** - Streamlined `step` function to accept only the step ID, simplifying its usage. - Adjusted logging and model saving mechanisms for consistency with the new training flow. - Improved random seed management for enhanced reproducibility in data processing. - Enhanced error handling in data retrieval to ensure seamless operation during data loading. - Added type hints for better clarity in data loader attributes. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Chun Cai <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit 03c6e49)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced training loop to support multi-task training, allowing for more flexible model selection. - **Improvements** - Streamlined `step` function to accept only the step ID, simplifying its usage. - Adjusted logging and model saving mechanisms for consistency with the new training flow. - Improved random seed management for enhanced reproducibility in data processing. - Enhanced error handling in data retrieval to ensure seamless operation during data loading. - Added type hints for better clarity in data loader attributes. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Chun Cai <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit 03c6e49)
Summary by CodeRabbit
New Features
Improvements
step
function to accept only the step ID, simplifying its usage.