Skip to content

Commit

Permalink
Add united test for trainer.test and description in the example (#165)
Browse files Browse the repository at this point in the history
  • Loading branch information
kenko911 authored Sep 13, 2023
1 parent 9c79457 commit c55cf17
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,8 @@
"source": [
"# If you wish to disable GPU or MPS (M1 mac) training, use the accelerator=\"cpu\" kwarg.\n",
"logger = CSVLogger(\"logs\", name=\"M3GNet_training\")\n",
"trainer = pl.Trainer(max_epochs=1, accelerator=\"cpu\", logger=logger)\n",
"# Inference mode = False is required for calculating forces, stress in test mode and prediction mode\n",
"trainer = pl.Trainer(max_epochs=1, accelerator=\"cpu\", logger=logger, inference_mode=False)\n",
"trainer.fit(model=lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader)"
]
},
Expand Down Expand Up @@ -405,7 +406,7 @@
"source": [
"# If you wish to disable GPU or MPS (M1 mac) training, use the accelerator=\"cpu\" kwarg.\n",
"logger = CSVLogger(\"logs\", name=\"M3GNet_finetuning\")\n",
"trainer = pl.Trainer(max_epochs=1, accelerator=\"cpu\", logger=logger)\n",
"trainer = pl.Trainer(max_epochs=1, accelerator=\"cpu\", logger=logger, inference_mode=False)\n",
"trainer.fit(model=lit_module_finetune, train_dataloaders=train_loader, val_dataloaders=val_loader)"
]
},
Expand Down Expand Up @@ -467,7 +468,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
"version": "3.10.9"
}
},
"nbformat": 4,
Expand Down
3 changes: 2 additions & 1 deletion tests/utils/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,10 @@ def test_m3gnet_training(self, LiFePO4, BaNiO3):
model = M3GNet(element_types=element_types, is_intensive=False)
lit_model = PotentialLightningModule(model=model, stress_weight=0.0001)
# We will use CPU if MPS is available since there is a serious bug.
trainer = pl.Trainer(max_epochs=5, accelerator=device)
trainer = pl.Trainer(max_epochs=5, accelerator=device, inference_mode=False)

trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(lit_model, dataloaders=test_loader)

pred_LFP_energy = model.predict_structure(LiFePO4)
pred_BNO_energy = model.predict_structure(BaNiO3)
Expand Down

0 comments on commit c55cf17

Please sign in to comment.