diff --git a/src/pickle.js b/src/pickle.js index 6846915a93f..343272341b7 100644 --- a/src/pickle.js +++ b/src/pickle.js @@ -140,13 +140,27 @@ pickle.Unpickler = class { var value = stack.pop(); var key = stack.pop(); var obj = stack[stack.length - 1]; - obj[key] = value; + if (Array.isArray(obj)) { + value.__id__ = key; + obj.push(value); + } + else { + obj[key] = value; + } break; case pickle.OpCode.SETITEMS: var index = marker.pop(); var obj = stack[index - 1]; - for (var position = index; position < stack.length; position += 2) { - obj[stack[position]] = stack[position + 1]; + if (Array.isArray(obj)) { + for (var position = index; position < stack.length; position += 2) { + stack[position + 1].__id__ = stack[position]; + obj.push(stack[position + 1]); + } + } + else { + for (var position = index; position < stack.length; position += 2) { + obj[stack[position]] = stack[position + 1]; + } } stack = stack.slice(0, index); break; diff --git a/src/pytorch-model.js b/src/pytorch-model.js index b7ee82a2a6f..960167af143 100755 --- a/src/pytorch-model.js +++ b/src/pytorch-model.js @@ -110,6 +110,14 @@ class PyTorchModelFactory { constructorTable['torch.LongStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'int64'; }; constructorTable['torch.FloatStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'float32'; }; constructorTable['torch.DoubleStorage'] = function (size) { this.size = size; this.dataTypeSize = 8; this.dataType = 'float64'; }; + constructorTable['torch.FloatTensor'] = function () { + this.__setstate__ = function(state) { + this.storage = state[0]; + this.storage_offset = state[1]; + this.size = state[2]; + this.stride = state[3]; + }; + }; functionTable['torch._utils._rebuild_tensor'] = function (storage, storage_offset, size, stride) { var obj = {}; @@ -204,6 +212,11 @@ class PyTorchModelFactory { } }); + if (Array.isArray(root) && root.every((item) => item.__type__ == 'torch.FloatTensor')) { + callback(new PyTorchError("File does not contain a model graph. Use 'torch.save()' to save both the graph and tensor data."), null); + return; + } + if (!root._modules) { callback(new PyTorchError('Root object does not contain modules.'), null); return;