From 5e38b73ef13d22fb8d8870d8d2b986f214f03e5c Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Wed, 31 Oct 2018 20:33:04 -0700 Subject: [PATCH] PyTorch FloatTensor serialization (#133) --- src/pickle.js | 20 +++++++++++++++++--- src/pytorch-model.js | 13 +++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/pickle.js b/src/pickle.js index 6846915a93f..343272341b7 100644 --- a/src/pickle.js +++ b/src/pickle.js @@ -140,13 +140,27 @@ pickle.Unpickler = class { var value = stack.pop(); var key = stack.pop(); var obj = stack[stack.length - 1]; - obj[key] = value; + if (Array.isArray(obj)) { + value.__id__ = key; + obj.push(value); + } + else { + obj[key] = value; + } break; case pickle.OpCode.SETITEMS: var index = marker.pop(); var obj = stack[index - 1]; - for (var position = index; position < stack.length; position += 2) { - obj[stack[position]] = stack[position + 1]; + if (Array.isArray(obj)) { + for (var position = index; position < stack.length; position += 2) { + stack[position + 1].__id__ = stack[position]; + obj.push(stack[position + 1]); + } + } + else { + for (var position = index; position < stack.length; position += 2) { + obj[stack[position]] = stack[position + 1]; + } } stack = stack.slice(0, index); break; diff --git a/src/pytorch-model.js b/src/pytorch-model.js index b7ee82a2a6f..960167af143 100755 --- a/src/pytorch-model.js +++ b/src/pytorch-model.js @@ -110,6 +110,14 @@ class PyTorchModelFactory { constructorTable['torch.LongStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'int64'; }; constructorTable['torch.FloatStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'float32'; }; constructorTable['torch.DoubleStorage'] = function (size) { this.size = size; this.dataTypeSize = 8; this.dataType = 'float64'; }; + constructorTable['torch.FloatTensor'] = function () { + this.__setstate__ = function(state) { + this.storage = state[0]; + this.storage_offset = state[1]; + this.size = state[2]; + this.stride = state[3]; + }; + }; functionTable['torch._utils._rebuild_tensor'] = function (storage, storage_offset, size, stride) { var obj = {}; @@ -204,6 +212,11 @@ class PyTorchModelFactory { } }); + if (Array.isArray(root) && root.every((item) => item.__type__ == 'torch.FloatTensor')) { + callback(new PyTorchError("File does not contain a model graph. Use 'torch.save()' to save both the graph and tensor data."), null); + return; + } + if (!root._modules) { callback(new PyTorchError('Root object does not contain modules.'), null); return;