Skip to content

Commit

Permalink
PyTorch FloatTensor serialization (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Nov 1, 2018
1 parent b508243 commit 5e38b73
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
20 changes: 17 additions & 3 deletions src/pickle.js
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,27 @@ pickle.Unpickler = class {
var value = stack.pop();
var key = stack.pop();
var obj = stack[stack.length - 1];
obj[key] = value;
if (Array.isArray(obj)) {
value.__id__ = key;
obj.push(value);
}
else {
obj[key] = value;
}
break;
case pickle.OpCode.SETITEMS:
var index = marker.pop();
var obj = stack[index - 1];
for (var position = index; position < stack.length; position += 2) {
obj[stack[position]] = stack[position + 1];
if (Array.isArray(obj)) {
for (var position = index; position < stack.length; position += 2) {
stack[position + 1].__id__ = stack[position];
obj.push(stack[position + 1]);
}
}
else {
for (var position = index; position < stack.length; position += 2) {
obj[stack[position]] = stack[position + 1];
}
}
stack = stack.slice(0, index);
break;
Expand Down
13 changes: 13 additions & 0 deletions src/pytorch-model.js
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ class PyTorchModelFactory {
constructorTable['torch.LongStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'int64'; };
constructorTable['torch.FloatStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'float32'; };
constructorTable['torch.DoubleStorage'] = function (size) { this.size = size; this.dataTypeSize = 8; this.dataType = 'float64'; };
constructorTable['torch.FloatTensor'] = function () {
this.__setstate__ = function(state) {
this.storage = state[0];
this.storage_offset = state[1];
this.size = state[2];
this.stride = state[3];
};
};

functionTable['torch._utils._rebuild_tensor'] = function (storage, storage_offset, size, stride) {
var obj = {};
Expand Down Expand Up @@ -204,6 +212,11 @@ class PyTorchModelFactory {
}
});

if (Array.isArray(root) && root.every((item) => item.__type__ == 'torch.FloatTensor')) {
callback(new PyTorchError("File does not contain a model graph. Use 'torch.save()' to save both the graph and tensor data."), null);
return;
}

if (!root._modules) {
callback(new PyTorchError('Root object does not contain modules.'), null);
return;
Expand Down

0 comments on commit 5e38b73

Please sign in to comment.