Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a constructor to TorchForce that takes a torch::jit::Module #97

Merged
merged 34 commits into from
Feb 10, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
f344f82
Add version number as a member to TorchForceProxy
RaulPPelaez Jan 20, 2023
73b5222
Encode the model file contents when serializing TorchForce
RaulPPelaez Jan 20, 2023
66b2530
Add tests for new TorchForce serialization
RaulPPelaez Jan 20, 2023
74bf087
Fix test not finding Python executable
RaulPPelaez Jan 23, 2023
85a6acd
Format include directives correctly
RaulPPelaez Jan 25, 2023
68cc189
Hardcode TorchForceProxy version number
RaulPPelaez Jan 25, 2023
49d5d6e
Fix formatting issues
RaulPPelaez Jan 25, 2023
20f4a7e
Move Python serialization test to the correct place
RaulPPelaez Jan 25, 2023
58dbbba
Make function encodeFromFileName static
RaulPPelaez Jan 25, 2023
612d383
Update serialization python test to correctly remove temporary files …
RaulPPelaez Jan 26, 2023
8fe7906
Use the base64 encoding capabilities of openssl to serialize model file
RaulPPelaez Jan 26, 2023
f996da9
Update TorchForce serializer
RaulPPelaez Jan 26, 2023
05d7764
Add a constructor to TorchForce that takes a torch::jit::Module.
RaulPPelaez Jan 30, 2023
f4e2b76
Remove unnecessary include
RaulPPelaez Jan 30, 2023
60a1ebe
Change i_file to file in TorchForce constructor
RaulPPelaez Jan 31, 2023
7c7b068
Add swig typemaps to new TorchForce constructor
RaulPPelaez Feb 1, 2023
8fae5eb
Add setup.py as a dependency for the PythonInstall CMake rule
RaulPPelaez Feb 3, 2023
ab0ef40
Fix swig out typemap for torch::jit::Module
RaulPPelaez Feb 3, 2023
a23672a
Remove commented line in CMakeLists.txt
RaulPPelaez Feb 6, 2023
1e4cae6
Remove unnecessary dependency in setup.py
RaulPPelaez Feb 6, 2023
103be5c
Add more tests for new constructor
RaulPPelaez Feb 6, 2023
068779b
Add some comments for the new constructor
RaulPPelaez Feb 6, 2023
5d54cbb
Merge branch 'serialization' into module_constructor
RaulPPelaez Feb 6, 2023
0ed5cee
Updates to TorchForce serialization
RaulPPelaez Feb 6, 2023
fecb6d0
Use hex encoding instead of base64 for serialization.
RaulPPelaez Feb 6, 2023
57191be
Remove unnecessary header
RaulPPelaez Feb 6, 2023
25bdfac
Update Python serialization test
RaulPPelaez Feb 6, 2023
8597c34
Merge remote-tracking branch 'origin/master' into module_constructor
RaulPPelaez Feb 7, 2023
ba554e8
Minor changes
RaulPPelaez Feb 7, 2023
51baa27
Improve temporary path handling in python serialization tests
RaulPPelaez Feb 8, 2023
abf43ff
More informative exception when failing to serialize TorchForce
RaulPPelaez Feb 8, 2023
6210b2b
Remove unnecessary check in TorchForce serialization
RaulPPelaez Feb 8, 2023
7c4cf66
Changes to C++ serialization tests
RaulPPelaez Feb 8, 2023
19749f0
Changes to C++ serialization tests
RaulPPelaez Feb 8, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions openmmapi/include/TorchForce.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "openmm/Context.h"
#include "openmm/Force.h"
#include <string>
#include <torch/torch.h>
#include "internal/windowsExportTorch.h"

namespace TorchPlugin {
Expand All @@ -52,10 +53,22 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
* @param file the path to the file containing the network
*/
TorchForce(const std::string& file);
/**
* Create a TorchForce. The network is defined by a PyTorch ScriptModule
*
* @param module an instance of the torch module
*/
TorchForce(const torch::jit::Module &module);
/**
* Get the path to the file containing the network.
* If the TorchForce instance was constructed with a module, instead of a filename,
* this function returns an empty string.
*/
const std::string& getFile() const;
/**
* Get the torch module currently in use.
*/
const torch::jit::Module & getModule() const;
/**
* Set whether this force makes use of periodic boundary conditions. If this is set
* to true, the network must take a 3x3 tensor as its second input, which
Expand Down Expand Up @@ -128,6 +141,7 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
std::string file;
bool usePeriodic, outputsForces;
std::vector<GlobalParameterInfo> globalParameters;
torch::jit::Module module;
};

/**
Expand Down
1 change: 0 additions & 1 deletion openmmapi/include/internal/TorchForceImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ class OPENMM_EXPORT_NN TorchForceImpl : public OpenMM::ForceImpl {
private:
const TorchForce& owner;
OpenMM::Kernel kernel;
torch::jit::script::Module module;
};

} // namespace TorchPlugin
Expand Down
12 changes: 11 additions & 1 deletion openmmapi/src/TorchForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,28 @@
#include "openmm/OpenMMException.h"
#include "openmm/internal/AssertionUtilities.h"
#include <fstream>
#include <torch/torch.h>
#include <torch/csrc/jit/serialization/import.h>

using namespace TorchPlugin;
using namespace OpenMM;
using namespace std;

TorchForce::TorchForce(const std::string& file) : file(file), usePeriodic(false), outputsForces(false) {
TorchForce::TorchForce(const torch::jit::Module& module) : file(), usePeriodic(false), outputsForces(false), module(module) {
}

TorchForce::TorchForce(const std::string& i_file) : TorchForce(torch::jit::load(i_file)) {
RaulPPelaez marked this conversation as resolved.
Show resolved Hide resolved
this->file = i_file;
}

const string& TorchForce::getFile() const {
return file;
}

const torch::jit::Module& TorchForce::getModule() const {
return this->module;
}

ForceImpl* TorchForce::createImpl() const {
return new TorchForceImpl(*this);
}
Expand Down
6 changes: 1 addition & 5 deletions openmmapi/src/TorchForceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,8 @@ TorchForceImpl::~TorchForceImpl() {
}

void TorchForceImpl::initialize(ContextImpl& context) {
// Load the module from the file.

module = torch::jit::load(owner.getFile());

auto module = owner.getModule();
// Create the kernel.

kernel = context.getPlatform().createKernel(CalcTorchForceKernel::Name(), context);
kernel.getAs<CalcTorchForceKernel>().initialize(context.getSystem(), owner, module);
}
Expand Down
4 changes: 4 additions & 0 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@ set(WRAP_FILE TorchPluginWrapper.cpp)
set(MODULE_NAME openmmtorch)

# Execute SWIG to generate source code for the Python module.
foreach(dir ${TORCH_INCLUDE_DIRS})
set(torchincs "${torchincs}" "-I${dir}")
endforeach()

add_custom_command(
OUTPUT "${WRAP_FILE}"
COMMAND "${SWIG_EXECUTABLE}"
-python -c++
-o "${WRAP_FILE}"
"-I${OPENMM_DIR}/include"
${torchincs}
"${CMAKE_CURRENT_SOURCE_DIR}/openmmtorch.i"
DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/openmmtorch.i"
WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
Expand Down
6 changes: 3 additions & 3 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import platform

openmm_dir = '@OPENMM_DIR@'
torch_include_dirs = '@TORCH_INCLUDE_DIRS@'.split(';')
nn_plugin_header_dir = '@NN_PLUGIN_HEADER_DIR@'
nn_plugin_library_dir = '@NN_PLUGIN_LIBRARY_DIR@'
torch_dir, _ = os.path.split('@TORCH_LIBRARY@')

# setup extra compile and link arguments on Mac
extra_compile_args = ['-std=c++11']
extra_compile_args = ['-std=c++14']
extra_link_args = []

if platform.system() == 'Darwin':
Expand All @@ -20,7 +21,7 @@
extension = Extension(name='_openmmtorch',
sources=['TorchPluginWrapper.cpp'],
libraries=['OpenMM', 'OpenMMTorch'],
include_dirs=[os.path.join(openmm_dir, 'include'), nn_plugin_header_dir],
include_dirs=[os.path.join(openmm_dir, 'include'), nn_plugin_header_dir] + torch_include_dirs,
library_dirs=[os.path.join(openmm_dir, 'lib'), nn_plugin_library_dir],
runtime_library_dirs=[os.path.join(openmm_dir, 'lib'), torch_dir],
extra_compile_args=extra_compile_args,
Expand All @@ -32,4 +33,3 @@
py_modules=['openmmtorch'],
ext_modules=[extension],
)

26 changes: 22 additions & 4 deletions serialization/tests/TestSerializeTorchForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,30 @@
#include "openmm/internal/AssertionUtilities.h"
#include "openmm/serialization/XmlSerializer.h"
#include <iostream>
#include <cstdio>
#include <sstream>
#include <torch/torch.h>

using namespace TorchPlugin;
using namespace OpenMM;
using namespace std;

extern "C" void registerTorchSerializationProxies();

void testSerialization() {
// Create a Force.
struct ExampleModuleImpl : torch::nn::Module { };
TORCH_MODULE(ExampleModule);

string writeExampleModuleFile() {
auto fileName = string(tmpnam(nullptr)) + ".pt";
ExampleModule module;
torch::save(module, fileName);
return fileName;
}
RaulPPelaez marked this conversation as resolved.
Show resolved Hide resolved

TorchForce force("module.pt");
void testSerializationFromFile() {
// Create a Force.
string fileName = writeExampleModuleFile();
TorchForce force(fileName);
force.setForceGroup(3);
force.addGlobalParameter("x", 1.3);
force.addGlobalParameter("y", 2.221);
Expand All @@ -62,6 +74,11 @@ void testSerialization() {

TorchForce& force2 = *copy;
ASSERT_EQUAL(force.getFile(), force2.getFile());
ostringstream bufferModule;
force.getModule().save(bufferModule);
ostringstream bufferModule2;
force2.getModule().save(bufferModule2);
ASSERT_EQUAL(bufferModule.str(), bufferModule2.str());
ASSERT_EQUAL(force.getForceGroup(), force2.getForceGroup());
ASSERT_EQUAL(force.getNumGlobalParameters(), force2.getNumGlobalParameters());
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
Expand All @@ -70,12 +87,13 @@ void testSerialization() {
}
ASSERT_EQUAL(force.usesPeriodicBoundaryConditions(), force2.usesPeriodicBoundaryConditions());
ASSERT_EQUAL(force.getOutputsForces(), force2.getOutputsForces());
remove(fileName.c_str());
}

int main() {
try {
registerTorchSerializationProxies();
testSerialization();
testSerializationFromFile();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
Expand Down