Skip to content

Commit

Permalink
PyTorch buffer support (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Aug 18, 2018
1 parent 4bb36e0 commit 7d69a55
Showing 1 changed file with 31 additions and 6 deletions.
37 changes: 31 additions & 6 deletions src/pytorch-model.js
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ class PyTorchModelFactory {
callback(new PyTorchError('Unsupported system information.'));
return;
}
if (sysInfo.protocol_version != 1001) {
callback(new PyTorchError("Unsupported protocol version '" + sysInfo.protocol_version + "'.", null));
return;
}
if (sysInfo.type_sizes) {
if ((sysInfo.type_sizes.int && sysInfo.type_sizes.int != 4) ||
(sysInfo.type_sizes.long && sysInfo.type_sizes.long != 4) ||
(sysInfo.type_sizes.short && sysInfo.type_sizes.short != 2))
{
callback(new PyTorchError('Unsupported type sizes.'));
return;
}
}

var functionTable = {};
functionTable['argparse.Namespace'] = function (args) { this.args = args; };
Expand Down Expand Up @@ -298,12 +311,24 @@ class PyTorchNode {
});
this._inputs = [ input ];

module._parameters.forEach((parameter) => {
var input = {};
input.name = parameter.__id__;
input.connections = [];
this._inputs.push(input);
if (parameter) {
var initializers = [];
if (module._parameters) {
module._parameters.forEach((parameter) => {
initializers.push(parameter);
});
}
if (module._buffers) {
module._buffers.forEach((buffer) => {
initializers.push(buffer);
});
}

initializers.forEach((parameter) => {
if (parameter && (parameter.data || parameter.storage)) {
var input = {};
input.name = parameter.__id__;
input.connections = [];
this._inputs.push(input);
var connection = {};
if (parameter.data) {
connection.initializer = new PyTorchTensor(parameter.data);
Expand Down

0 comments on commit 7d69a55

Please sign in to comment.