From 8a462cabaa32d3be0c3a04adecfc134b10c43b13 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Mon, 13 Aug 2018 00:19:10 -0700 Subject: [PATCH] PyTorch parameter arguments (#133) --- src/pytorch-model.js | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/pytorch-model.js b/src/pytorch-model.js index eb061d2ca2..3c1c8f4198 100755 --- a/src/pytorch-model.js +++ b/src/pytorch-model.js @@ -68,7 +68,7 @@ class PyTorchModelFactory { functionTable['torchvision.models.resnet.ResNet'] = function () {}; functionTable['torchvision.models.vgg.VGG'] = function () {}; functionTable['torch.nn.backends.thnn._get_thnn_function_backend'] = function () {}; - functionTable['torch.nn.parameter.Parameter'] = function() {}; + functionTable['torch.nn.parameter.Parameter'] = function(data, requires_grad) { this.data = data; this.requires_grad = requires_grad; }; functionTable['torch.FloatStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'float32'; }; functionTable['torch.LongStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'int64'; }; @@ -303,10 +303,16 @@ class PyTorchNode { input.name = parameter.__id__; input.connections = []; this._inputs.push(input); - if (parameter && parameter.storage) { + if (parameter) { var connection = {}; - connection.initializer = new PyTorchTensor(parameter); - connection.type = connection.initializer.type.toString(); + if (parameter.data) { + connection.initializer = new PyTorchTensor(parameter.data); + connection.type = connection.initializer.type.toString(); + } + else if (parameter.storage) { + connection.initializer = new PyTorchTensor(parameter); + connection.type = connection.initializer.type.toString(); + } input.connections.push(connection); } });