Add an example of making predictions with a trained model, something like the following:
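```python
import pprint

import torch
from alignn.config import TrainingConfig
from alignn.models.alignn import ALIGNN
from jarvis.core.atoms import Atoms
from jarvis.core.graphs import Graph
from jarvis.db.jsonutils import loadjson

# Use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

filename = "checkpoint_100.pt"  # checkpoint written during training
cutoff = 8  # graph settings; these should match the values used in training
max_neighbors = 12

# Rebuild the model from the training config and load the trained weights
config = loadjson("config.json")
pprint.pprint(config)  # pprint prints directly; wrapping it in print() just adds "None"
config = TrainingConfig(**config)
model = ALIGNN(config.model)
model.load_state_dict(torch.load(filename, map_location=device)["model"])
model.to(device)
model.eval()

# Convert the structure into the atom graph and line graph that ALIGNN expects
atoms = Atoms.from_poscar("POSCAR")
g, lg = Graph.atom_dgl_multigraph(
    atoms,
    cutoff=float(cutoff),
    max_neighbors=max_neighbors,
)

# Forward pass without gradient tracking; argmax picks the predicted class
with torch.no_grad():
    out_data = (
        torch.argmax(model([g.to(device), lg.to(device)]))
        .cpu()
        .numpy()
        .flatten()
        .tolist()
    )[0]
print("out_data class ", out_data)
```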
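The snippet above keeps only the `argmax` of the model output, i.e. the class index. If you also want per-class probabilities, a small dependency-free sketch of that post-processing follows. It assumes the output row holds either raw logits or log-probabilities (as far as I know ALIGNN applies a `LogSoftmax` in classification mode, but check your version); softmax returns the same probabilities in both cases, because log-softmax only shifts every score by one constant. `scores_to_class` is a hypothetical helper, not part of ALIGNN.

```python
import math
from typing import List, Tuple


def scores_to_class(scores: List[float]) -> Tuple[int, List[float]]:
    """Return (predicted class index, class probabilities) for one output row.

    `scores` may be raw logits or log-probabilities; softmax is invariant to
    adding the same constant to every score, so both give the same result.
    """
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]
    # argmax: index of the largest probability (same as the largest score)
    predicted = max(range(len(probs)), key=probs.__getitem__)
    return predicted, probs


# Hypothetical two-class output row, e.g. from
# model([g.to(device), lg.to(device)]).cpu().flatten().tolist()
label, probs = scores_to_class([-2.3, -0.1])
print(label, [round(p, 3) for p in probs])
```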
Can I make predictions on a new dataset of CIFs with a model trained on an old dataset?
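In principle yes: the model only needs each new structure converted into a graph with the same `cutoff` and `max_neighbors` used in training, and it predicts the same property it was trained on. How trustworthy the predictions are depends on how similar the new structures are to the training data. Below is a minimal sketch of looping over a folder of CIF files; `predict_directory` and `predict_fn` are hypothetical names. In practice `predict_fn` would wrap the steps from the example above with `Atoms.from_poscar` swapped for a CIF reader (jarvis-tools provides `Atoms.from_cif`, but check your installed version). A dummy predictor is used here only so the loop runs without ALIGNN installed.

```python
import tempfile
from pathlib import Path
from typing import Callable, Dict


def predict_directory(
    folder: str,
    predict_fn: Callable[[str], float],
    pattern: str = "*.cif",
) -> Dict[str, float]:
    """Apply `predict_fn` (hypothetical) to every file matching `pattern`."""
    results: Dict[str, float] = {}
    for path in sorted(Path(folder).glob(pattern)):
        try:
            results[path.name] = predict_fn(str(path))
        except Exception as exc:  # keep going if one structure fails to load
            print(f"skipping {path.name}: {exc}")
    return results


# Usage with placeholder files and a stand-in predictor (file size in bytes)
with tempfile.TemporaryDirectory() as tmp:
    for name in ("a.cif", "b.cif", "notes.txt"):
        Path(tmp, name).write_text("data_dummy\n")
    preds = predict_directory(tmp, lambda p: float(Path(p).stat().st_size))
print(preds)
```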