Skip to content

Commit

Permalink
Fix serialization and clean up API
Browse files Browse the repository at this point in the history
  • Loading branch information
Raimondas Galvelis committed Apr 10, 2020
1 parent 96c92f1 commit 6494b3a
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 19 deletions.
8 changes: 3 additions & 5 deletions openmmapi/include/TensorRTForce.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ class OPENMM_EXPORT_NN TensorRTForce : public Force {
* @param file the path to the file containing the graph
*/
TensorRTForce(const std::string& file);
/**
* Get the path to the file containing the graph.
*/
const std::string& getFile() const { return file; }
/**
* Get the content of the protocol buffer defining the graph.
*/
Expand All @@ -42,7 +38,9 @@ class OPENMM_EXPORT_NN TensorRTForce : public Force {
protected:
ForceImpl* createImpl() const;
private:
std::string file;
friend class TensorRTForceProxy;
TensorRTForce(const std::string& serializedGraph, bool usePeriodic) :
serializedGraph(serializedGraph), usePeriodic(usePeriodic) {};
std::string serializedGraph;
bool usePeriodic;
};
Expand Down
2 changes: 1 addition & 1 deletion openmmapi/src/TensorRTForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

using namespace OpenMM;

TensorRTForce::TensorRTForce(const std::string& file): file(file), usePeriodic(false) {
TensorRTForce::TensorRTForce(const std::string& file): usePeriodic(false) {

// Read the serialized graph from a file
std::stringstream stream;
Expand Down
2 changes: 1 addition & 1 deletion serialization/include/TensorRTForceProxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace OpenMM {

class OPENMM_EXPORT_NN TensorRTForceProxy : public SerializationProxy {
public:
TensorRTForceProxy();
TensorRTForceProxy() : SerializationProxy("TensorRTForce") {};
void serialize(const void* object, SerializationNode& node) const;
void* deserialize(const SerializationNode& node) const;
};
Expand Down
11 changes: 5 additions & 6 deletions serialization/src/TensorRTForceProxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,17 @@

using namespace OpenMM;

TensorRTForceProxy::TensorRTForceProxy() : SerializationProxy("TensorRTForce") {
}

void TensorRTForceProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1);
const auto& force = *reinterpret_cast<const TensorRTForce*>(object);
node.setStringProperty("file", force.getFile());
node.setStringProperty("serializedGraph", force.serializedGraph);
node.setBoolProperty("usePeriodic", force.usePeriodic);
}

void* TensorRTForceProxy::deserialize(const SerializationNode& node) const {
if (node.getIntProperty("version") != 1)
throw OpenMMException("Unsupported version number");
auto force = new TensorRTForce(node.getStringProperty("file"));
return force;
const auto& serializedGraph = node.getStringProperty("serializedGraph");
bool usePeriodic = node.getBoolProperty("usePeriodic");
return new TensorRTForce(serializedGraph, usePeriodic);
}
11 changes: 5 additions & 6 deletions serialization/tests/TestSerializeTensorRTForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include <sstream>

using namespace OpenMM;
using namespace std;

extern "C" void registerTensorRTSerializationProxies();

Expand All @@ -17,25 +16,25 @@ void testSerialization() {

// Serialize and then deserialize it.

stringstream buffer;
std::stringstream buffer;
XmlSerializer::serialize<TensorRTForce>(&force, "Force", buffer);
TensorRTForce* copy = XmlSerializer::deserialize<TensorRTForce>(buffer);

// Compare the two forces to see if they are identical.

TensorRTForce& force2 = *copy;
ASSERT_EQUAL(force.getFile(), force2.getFile());
ASSERT_EQUAL(force.getSerializedGraph(), force2.getSerializedGraph());
}

int main() {
try {
registerTensorRTSerializationProxies();
testSerialization();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
catch(const std::exception& e) {
std::cout << "exception: " << e.what() << std::endl;
return 1;
}
cout << "Done" << endl;
std::cout << "Done" << std::endl;
return 0;
}

0 comments on commit 6494b3a

Please sign in to comment.