Skip to content

Commit

Permalink
pt: add necessary jit.export (#3337)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
CaRoLZhangxy and pre-commit-ci[bot] authored Feb 25, 2024
1 parent 91049df commit 261c802
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 0 deletions.
1 change: 1 addition & 0 deletions deepmd/pt/model/model/dipole_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def forward(
model_predict["updated_coord"] += coord
return model_predict

@torch.jit.export
def forward_lower(
self,
extended_coord,
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def forward(
model_predict["force"] = model_ret["dforce"]
return model_predict

@torch.jit.export
def forward_lower(
self,
extended_coord,
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/model/ener_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def forward(
model_predict["updated_coord"] += coord
return model_predict

@torch.jit.export
def forward_lower(
self,
extended_coord,
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def forward_common_lower(
)
return model_predict

@torch.jit.export
def format_nlist(
self,
extended_coord: torch.Tensor,
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/model/polar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def forward(
model_predict["updated_coord"] += coord
return model_predict

@torch.jit.export
def forward_lower(
self,
extended_coord,
Expand Down

0 comments on commit 261c802

Please sign in to comment.