Skip to content

Commit

Permalink
PyTorch parameter arguments (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Aug 13, 2018
1 parent c9b879b commit 8a462ca
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/pytorch-model.js
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class PyTorchModelFactory {
functionTable['torchvision.models.resnet.ResNet'] = function () {};
functionTable['torchvision.models.vgg.VGG'] = function () {};
functionTable['torch.nn.backends.thnn._get_thnn_function_backend'] = function () {};
functionTable['torch.nn.parameter.Parameter'] = function() {};
functionTable['torch.nn.parameter.Parameter'] = function(data, requires_grad) { this.data = data; this.requires_grad = requires_grad; };
functionTable['torch.FloatStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'float32'; };
functionTable['torch.LongStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'int64'; };

Expand Down Expand Up @@ -303,10 +303,16 @@ class PyTorchNode {
input.name = parameter.__id__;
input.connections = [];
this._inputs.push(input);
if (parameter && parameter.storage) {
if (parameter) {
var connection = {};
connection.initializer = new PyTorchTensor(parameter);
connection.type = connection.initializer.type.toString();
if (parameter.data) {
connection.initializer = new PyTorchTensor(parameter.data);
connection.type = connection.initializer.type.toString();
}
else if (parameter.storage) {
connection.initializer = new PyTorchTensor(parameter);
connection.type = connection.initializer.type.toString();
}
input.connections.push(connection);
}
});
Expand Down

0 comments on commit 8a462ca

Please sign in to comment.