Skip to content

Commit

Permalink
TorchScript 1.3 prototype (#281)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Sep 6, 2019
1 parent a0bb598 commit 38b625a
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 68 deletions.
2 changes: 1 addition & 1 deletion src/pickle.js
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ pickle.Unpickler = class {
break;
case pickle.OpCode.TUPLE:
items = stack;
stack = marker .pop();
stack = marker.pop();
stack.push(items);
break;
case pickle.OpCode.SETITEM:
Expand Down
147 changes: 83 additions & 64 deletions src/torchscript.js
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,9 @@ torchscript.ModelFactory = class {
container.prefix = version.name.substring(0, version.name.length - 7);
let find = (name) => {
let entry = container.entries.find((entry) => entry.name == container.prefix + name);
if (entry) {
return entry.data;
}
return null;
return entry ? entry.data : null;
}
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md
container.version = version.data;
container.attributes = find('attribtues.pkl');
container.constants = find('constants.pkl');
Expand Down Expand Up @@ -285,8 +283,8 @@ torchscript.Graph = class {
if (parameter.tensorId) {
let tensorId = parseInt(parameter.tensorId, 10);
parameter.initializer = container.tensors[tensorId];
if (parameter.outputs && parameter.outputs.length == 1) {
container.parameters[parameter.outputs[0]] = parameter;
if (parameter.__outputs__ && parameter.__outputs__.length == 1) {
container.parameters[parameter.__outputs__[0]] = parameter;
}
}
}
Expand All @@ -299,34 +297,32 @@ torchscript.Graph = class {
}
}
}
/*
if (container.data) {
let queue = [ container.data ];
while (queue.length > 0) {
let module = queue.shift();
if (module.parameters) {
for (let parameter of module.parameters) {
if (parameter.tensorId) {
let tensorId = parseInt(parameter.tensorId, 10);
parameter.initializer = container.tensors[tensorId];
if (parameter.outputs && parameter.outputs.length == 1) {
container.parameters[parameter.outputs[0]] = parameter;
}
}
}
}
for (let key of Object.keys(module)) {
if (key !== '__type__' && key !== '__parent__') {
let submodule = module[key];
if (submodule === Object(submodule)) {
submodule.__parent__ = module;
queue.push(submodule);
let obj = module[key];
if (!Array.isArray(obj) && obj === Object(obj)) {
if (obj && obj.__type__ && obj.__type__.endsWith('Tensor')) {
let parameter = obj;
if (!parameter.initializer) {
parameter.initializer = new torchscript.Tensor('pickle', parameter);
}
if (parameter.__outputs__ && parameter.__outputs__.length == 1) {
container.parameters[parameter.__outputs__[0]] = parameter;
}
}
else {
obj.__parent__ = module;
queue.push(obj);
}
}
}
}
}
}
*/

if (context) {
for (let input of context.inputs) {
Expand All @@ -350,7 +346,7 @@ torchscript.Graph = class {
}

_loadModule(metadata, container, module) {
if (module.parameters && module.parameters.length > 0 && !module.hide) {
if (module.parameters && module.parameters.length > 0 && !module.__hide__) {
let node = new torchscript.Node(metadata, container, module, null);
this._nodes.push(node);
}
Expand Down Expand Up @@ -446,9 +442,9 @@ torchscript.Node = class {
this._inputs.push(new torchscript.Parameter(parameter.name, true, [
new torchscript.Argument('', null, parameter.initializer || null)
]));
if (parameter.outputs) {
if (parameter.__outputs__) {
this._outputs.push(new torchscript.Parameter(parameter.name, true,
parameter.outputs.map((id) => new torchscript.Argument(id, null, null))
parameter.__outputs__.map((id) => new torchscript.Argument(id, null, null))
));
}
}
Expand All @@ -468,8 +464,8 @@ torchscript.Node = class {
for (let argument of input) {
let parameter = container.parameters[argument.id];
if (parameter) {
if (parameter.module && (module == null || module == parameter.module)) {
module = parameter.module;
if (parameter.__module__ && (module == null || module == parameter.__module__)) {
module = parameter.__module__;
count++;
}
else {
Expand All @@ -482,8 +478,16 @@ torchscript.Node = class {
break;
}
}
if (module && module.parameters.length == count && match) {
module.hide = true;
let parametersLength = 0;
if (module && module.parameters) {
parametersLength = module.parameters.length;
}
else if (module) {
parametersLength = Object.keys(module).filter((k) => module[k] && module[k].__type__ && module[k].__type__.endsWith('Tensor')).length;
}

if (module && parametersLength == count && match) {
module.__hide__ = true;
for (let input of node.inputs) {
for (let argument of input) {
let parameter = container.parameters[argument.id];
Expand Down Expand Up @@ -1321,25 +1325,34 @@ torchscript.GraphContext = class {
return false;
}

_submodule(module, name) {
var obj = module[name];
if (obj && (!obj.__type__ || !obj.__type__.endsWith('Tensor'))) {
return obj;
}
if (module.submodules) {
for (let submodule of module.submodules) {
if (submodule.name === name) {
return submodule;
}
}
}
return null;
}

_module(expression) {
let module;
let submodule;
if (expression.type === '.') {
module = this._module(expression.target);
if (module && module.submodules) {
for (submodule of module.submodules) {
if (submodule.name === expression.member.value) {
return submodule;
}
let module = this._module(expression.target);
if (module) {
let submodule = this._submodule(module, expression.member.value);
if (submodule) {
return submodule;
}
}
if (module[expression.member.value]) {
return module[expression.member.value];
}
}
if (expression.type == 'call' &&
expression.target.type == 'identifier' && expression.target.value == 'getattr' && expression.arguments.length == 2) {
module = this._module(expression.arguments[0]);
let module = this._module(expression.arguments[0]);
if (!module) {
return null;
}
Expand All @@ -1348,21 +1361,17 @@ torchscript.GraphContext = class {
name = expression.arguments[1].value.substring(1, expression.arguments[1].value.length - 1);
}
if (module) {
if (module[name]) {
return module[name];
}
for (submodule of module.submodules) {
if (submodule.name === name) {
return submodule;
}
let submodule = this._submodule(module, name);
if (submodule) {
return submodule;
}
}
}
if (expression.type == 'identifier') {
if (expression.value == 'self') {
return this._mainModule;
}
module = this._moduleMap[expression.value];
let module = this._moduleMap[expression.value];
if (module) {
return module;
}
Expand All @@ -1387,30 +1396,40 @@ torchscript.GraphContext = class {
expression = this._moduleTensor(expression);
if (expression.type === '.' && expression.member.type == 'identifier') {
let targetModule = this._module(expression.target);
if (targetModule && targetModule.parameters) {
for (let parameter of targetModule.parameters) {
parameter.module = targetModule;
if (parameter.name === expression.member.value) {
parameter.outputs = parameter.outputs || [];
parameter.outputs.push(target.value);
return true;
if (targetModule) {
if (targetModule.parameters) {
for (let parameter of targetModule.parameters) {
parameter.__module__ = targetModule;
if (parameter.name === expression.member.value) {
parameter.__outputs__ = parameter.__outputs__ || [];
parameter.__outputs__.push(target.value);
return true;
}
}
}
targetModule.unresolvedParameters = targetModule.unresolvedParameters || [];
for (let unresolvedParameter of targetModule.unresolvedParameters) {
unresolvedParameter.module = targetModule;
let obj = targetModule[expression.member.value];
if (obj && obj.__type__ && obj.__type__.endsWith('Tensor')) {
obj.__module__ = targetModule;
obj.__outputs__ = obj.__outputs__ || [];
obj.__outputs__.push(target.value);
return true;
}
/*
targetModule.__unresolvedParameters__ = targetModule.__unresolvedParameters__ || [];
for (let unresolvedParameter of targetModule.__unresolvedParameters__) {
unresolvedParameter.__module__ = targetModule;
if (unresolvedParameter.name === expression.member.value) {
unresolvedParameter.outputs = unresolvedParameter.outputs || [];
unresolvedParameter.outputs.push(target.value);
unresolvedParameter.__outputs__ = unresolvedParameter.__outputs__ || [];
unresolvedParameter.__outputs__.push(target.value);
return true;
}
}
targetModule.unresolvedParameters.push({
targetModule.__unresolvedParameters__.push({
module: targetModule,
name: expression.member.value,
outputs: [ target.value ]
});
return true;
*/
}
}
return false;
Expand Down
13 changes: 10 additions & 3 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -3703,7 +3703,7 @@
{
"type": "pytorch",
"target": "mnist_linear.ckpt",
"source": "https://github.com/lutzroeder/netron/files/3269075/mnist_linear_torchscript.zip[mnist_linear.ckpt]",
"source": "https://github.com/lutzroeder/netron/files/3585288/mnist_linear_torchscript.zip[mnist_linear.ckpt]",
"format": "PyTorch",
"link": "https://github.com/lutzroeder/netron/issues/281"
},
Expand Down Expand Up @@ -5006,8 +5006,15 @@
},
{
"type": "torchscript",
"target": "mnist_linear_torchscript.pt",
"source": "https://github.com/lutzroeder/netron/files/3269075/mnist_linear_torchscript.zip[mnist_linear_torchscript.pt]",
"target": "mnist_linear_torchscript_1.pt",
"source": "https://github.com/lutzroeder/netron/files/3585288/mnist_linear_torchscript.zip[mnist_linear_torchscript_1.pt]",
"format": "TorchScript v1",
"link": "https://github.com/lutzroeder/netron/issues/281"
},
{
"type": "torchscript",
"target": "mnist_linear_torchscript_2.pt",
"source": "https://github.com/lutzroeder/netron/files/3585288/mnist_linear_torchscript.zip[mnist_linear_torchscript_2.pt]",
"format": "TorchScript v1",
"link": "https://github.com/lutzroeder/netron/issues/281"
},
Expand Down

0 comments on commit 38b625a

Please sign in to comment.