From 7a0458eb8845dc6c7d6afce49ce3f790da2eba0c Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sun, 26 Jan 2020 16:59:21 -0800 Subject: [PATCH] TorchScript 1.4 prototype (#281) --- src/pytorch-metadata.json | 31 ++----- src/pytorch.js | 173 ++++++++++++++++++++++++-------------- test/models.json | 65 ++++++-------- 3 files changed, 143 insertions(+), 126 deletions(-) diff --git a/src/pytorch-metadata.json b/src/pytorch-metadata.json index c7b825aeeb7..c6f47d8ad46 100755 --- a/src/pytorch-metadata.json +++ b/src/pytorch-metadata.json @@ -1807,27 +1807,9 @@ "outputs": [ { "name": "output" - } - ] - } - }, - { - "name": "torch.size", - "schema": { - "attributes": [ - { - "name": "dim", - "type": "int64" - } - ], - "inputs": [ - { - "name": "input" - } - ], - "outputs": [ + }, { - "name": "output" + "name": "output2" } ] } @@ -2001,16 +1983,15 @@ { "name": "bidirectional", "type": "boolean" - }, - { - "name": "batch_first", - "type": "boolean" } ], "category": "Layer", "inputs": [ { - "name": "input" + "name": "data" + }, + { + "name": "batch_sizes" }, { "name": "hx" diff --git a/src/pytorch.js b/src/pytorch.js index ec9cc343745..fd9022ec228 100644 --- a/src/pytorch.js +++ b/src/pytorch.js @@ -702,7 +702,7 @@ pytorch.Attribute = class { } } - if (Array.isArray(value) && value.length > 0 && value.every((obj) => obj.__module__ && obj.__module__.startsWith('torch.nn'))) { + if (Array.isArray(value) && value.length > 0 && value.every((obj) => obj && obj.__module__ && obj.__module__.startsWith('torch.nn'))) { this._value = '?'; } } @@ -1023,6 +1023,7 @@ pytorch.Execution = class { this._context.scope.builtins.type = { __module__: 'builtins', __name__: 'type' }; this._context.scope.builtins.module = { __module__: 'builtins', __name__: 'module', __class__: this._context.scope.builtins.type }; this._context.scope.builtins.function = { __module__: 'builtins', __name__: 'function', __class__:this._context.scope.builtins.type }; + this._context.scope.builtins.method = { __module__: 'builtins', __name__: 'method', __class__:this._context.scope.builtins.type }; this._registerConstructor('argparse.Namespace', function (args) { this.args = args; }); @@ -1453,6 +1454,9 @@ pytorch.Execution = class { this._registerFunction('unchecked_cast', function(type, value) { return value; }); + this._registerFunction('ops.prim.data', function(tensor) { + return tensor; + }); this._registerFunction('ops.prim.unchecked_unwrap_optional', function(value) { return value; }); @@ -1502,6 +1506,10 @@ pytorch.Execution = class { this._registerFunction('torch._unwrap_optional', function(value) { return value; // TODO }); + this._registerFunction('torch.append', function(tensors, tensor) { + tensors.push(tensor); + return tensor; + }); this._registerFunction('torch.dim', function(tensor) { if (tensor && tensor.size) { return tensor.size.length; @@ -1517,6 +1525,9 @@ pytorch.Execution = class { } throw new pytorch.Error('Unknown expression type.'); }); + this._registerFunction('torch.floordiv', function(/* left, right */) { + return undefined; + }); this._registerFunction('torch.gt', function(left, right) { if (typeof left === 'number' && typeof right === 'number') { return left > right; @@ -1538,6 +1549,9 @@ pytorch.Execution = class { this._registerFunction('torch.len', function(/* value */) { return undefined; }); + this._registerFunction('torch.list_with_default', function(size /*, defaults */) { + return size; + }); this._registerFunction('torch.lt', function(left, right) { if (typeof left === 'number' && typeof right === 'number') { return left < right; @@ -1549,7 +1563,7 @@ pytorch.Execution = class { return left * right; } if (pytorch.Utility.isTensor(left) && pytorch.Utility.isTensor(right)) { - return { __module__: 'torch', __name__: 'Tensor' }; + return { __module__: 'torch', __name__: 'Tensor', __origin__: 'torch.mul' }; } throw new pytorch.Error('Unknown expression type.'); }); @@ -1565,8 +1579,14 @@ pytorch.Execution = class { this._registerFunction('torch.t', function(tensor) { return tensor; }); + this._registerFunction('torch.size', function(tensor) { + if (tensor && tensor.size) { + return tensor.size; + } + return undefined; + }); this._registerFunction('uninitialized', function(type) { - return ({ __module__: 'torch', __name__: type }); + return ({ __module__: 'torch', __name__: type, __origin__: 'uninitialized' }); }); /* this._registerOperator('torch._convolution', 1); @@ -1682,8 +1702,8 @@ pytorch.Execution = class { _call(target, name, args, context) { const callTarget = this._target(target, context); - const callArguments = args.map((argument) => this.expression(argument, context)); - if (!callTarget || !callTarget[name]) { + let callArguments = args.map((argument) => this.expression(argument, context)); + if (!callTarget || (name !== null && !callTarget[name])) { const targetName = pytorch.Utility.target(target) + '.' + name; if (this.type(targetName)) { return this.invoke(targetName, callArguments); @@ -1692,7 +1712,7 @@ pytorch.Execution = class { return this._invokeCallback(targetName, args, context); } } - const func = callTarget[name]; + const func = name ? callTarget[name] : callTarget; if (func.__class__ === this._context.scope.builtins.type) { let obj = {}; obj.__proto__ = func; @@ -1702,6 +1722,11 @@ pytorch.Execution = class { return obj; } if (func.__class__ === this._context.scope.builtins.function) { + if (func.__call__) { + return func.__call__(callArguments); + } + } + if (func.__class__ === this._context.scope.builtins.method) { if (func.__call__) { return func.__call__([ callTarget ].concat(callArguments)); } @@ -1713,10 +1738,10 @@ pytorch.Execution = class { } apply(method, args, context) { - args = Array.prototype.slice.call(args); + let locals = Array.prototype.slice.call(args); context = context.push(); for (const parameter of method.parameters) { - context.set(parameter.name, args.shift()); + context.set(parameter.name, locals.shift()); } return this._block(method.body.statements, context) } @@ -1735,8 +1760,19 @@ pytorch.Execution = class { case 'def': { const module = context.get('__name__'); const self = this; + const parent = context.get('__class__'); + let type = null; + if (parent === this._context.scope.builtins.type) { + type = this._context.scope.builtins.method; + } + else if (parent === this._context.scope.builtins.module) { + type = this._context.scope.builtins.function; + } + else { + throw new pytorch.Error('Invalid function scope.'); + } const func = { - __class__: this._context.scope.builtins.function, + __class__: type, __globals__: context, __module__: module, __name__: statement.name, @@ -1874,17 +1910,21 @@ pytorch.Execution = class { case 'call': { if (expression.target.type === 'id' && expression.target.value === 'uninitialized' && expression.arguments.length === 1 && expression.arguments[0].type === 'id' && expression.arguments[0].value === 'Tensor') { - return { __module__: 'torch', __name__: 'Tensor' }; + return { __module__: 'torch', __name__: 'Tensor', __origin__: 'uninitialized' }; } if (expression.target.type === 'id' && expression.target.value === 'annotate' && expression.arguments.length === 2) { return this.expression(expression.arguments[1], context); } + if (expression.target.type === 'id' && expression.target.value === 'unchecked_cast' && expression.arguments.length === 2) { + return this.expression(expression.arguments[1], context); + } if (expression.target.type === '.') { return this._call(expression.target.target, expression.target.member.value, expression.arguments, context); } - const target = this.expression(expression.target, context); - const args = expression.arguments.map((argument) => this.expression(argument, context)); - return target.apply(self, args); + return this._call(expression.target, null, expression.arguments, context); + // const target = this.expression(expression.target, context); + // const args = expression.arguments.map((argument) => this.expression(argument, context)); + // return target.apply(self, args); } case 'id': { switch (expression.value) { @@ -1893,19 +1933,7 @@ pytorch.Execution = class { case 'True': return true; case 'False': return false; } - const value = context.get(expression.value); - if (value !== undefined) { - return value; - } - if (expression.value === 'Tensor') { - throw new Error("Unsupported '" + expression.value + "'."); - // return { __typeref__: expression.value }; - } - if (expression.value === 'int') { - throw new Error("Unsupported '" + expression.value + "'."); - // return { __typeref__: expression.value }; - } - break; + return context.get(expression.value); } case 'tuple': { return expression.value.map((expression) => this.expression(expression, context)); @@ -2749,27 +2777,28 @@ pytorch.Container.Zip = class { const schema = this._metadata.type(name); if (schema) { args = Array.prototype.slice.call(args); - - let node = {}; - node.type = name; - node.inputs = []; - node.outputs = []; - node.attributes = []; - + let node = { + type: name, + inputs: [], + attributes: [], + outputs: [] + }; for (let i = 0; i < schema.inputs.length; i++) { - let arg = args.shift(); + const arg = args.shift(); const p = this._execution.expression(arg, context); const parameters = Array.isArray(p) ? p : [ p ]; let inputs = []; - for (const parameter of parameters) { if (parameter) { + if (!pytorch.Utility.isTensor(parameter)) { + throw new pytorch.Error('Tensor expected.'); + } if (parameter.__variable__) { inputs.push({ id: parameter.__variable__ }); - node.inputs.push([ ]); } else { const id = this._variable().value; + parameter.__variable__ = id; parameter.__outputs__ = parameter.__outputs__ || []; parameter.__outputs__.push(id); inputs.push({ id: id }); @@ -2778,34 +2807,38 @@ pytorch.Container.Zip = class { } node.inputs.push(inputs); } - let outputs = [] - for (let i = 0; i < schema.outputs.length; i++) { - let parameter = { __module__: 'torch', __name__: 'Tensor' }; - parameter.__variable__ = this._variable().value; - outputs.push(parameter) - node.outputs.push(parameter.__variable__); + while (args.length > 0 && args[0].type !== '=') { + const arg = args.shift(); + const value = this._execution.expression(arg, context); + node.attributes.push(value); } - /* while (args.length > 0) { - let argument = args[0] - if (pytorch.Utility.isTensor(argument)) { - node.inputs.push([ argument ]); - args.shift(); - continue; + const arg = args.shift(); + if (arg.type === '=' && arg.target && arg.target.type === 'id') { + const value = this._execution.expression(arg.expression, context); + node.attributes.push({ type: '=', target: arg.target, expression: value }); } - if (Array.isArray(argument) && argument.every((tensor) => pytorch.Utility.isTensor(tensor))) { - node.inputs.push([ argument ]); - args.shift(); - continue; + else { + throw new pytorch.Attribute('Expected named argument.'); } - break; } - while (args.length > 0) { - let argument = args[0] - node.attributes.push(argument); - args.shift(); + let outputs = [] + for (let i = 0; i < schema.outputs.length; i++) { + let parameter = { __module__: 'torch', __name__: 'Tensor', __origin__: 'invoke-output-' + name }; + switch (name) { + case 'torch.cat': + case 'torch.conv2d': + case 'torch.flatten': + case 'torch.relu_': + case 'torch.dropout': { + parameter.size = [ undefined, undefined, undefined, undefined ]; + break; + } + } + parameter.__variable__ = this._variable().value; + outputs.push(parameter) + node.outputs.push(parameter.__variable__); } - */ this._nodes.push(node); if (outputs.length > 1) { return outputs; @@ -2890,8 +2923,8 @@ pytorch.Container.Zip = class { } trace() { - // this._trace = true; - if (this._trace) { + this._trace = false; + if (this._trace || this.format == 'TorchScript v1.4') { this._inputs = []; this._outputs = []; this._nodes = []; @@ -2914,7 +2947,25 @@ pytorch.Container.Zip = class { } } if (this.data.forward) { - this.data.forward.__call__([ this.data, { __module__: 'torch', __name__: 'Tensor' } ]); + let args = [ this.data ]; // self + if (this.data.forward.__code__ && this.data.forward.__code__.parameters) { + for (const parameter of this.data.forward.__code__.parameters) { + if (parameter.name !== 'self' && + parameter.parameterType.type === 'type' && + parameter.parameterType.name.type === 'id' && + parameter.parameterType.name.value === 'Tensor') { + this._inputs.push(parameter.name); + args.push({ __module__: 'torch', __name__: 'Tensor', __variable__: parameter.name, __origin__: 'trace-input' }); + } + } + } + const result = this.data.forward.__call__(args); + const outputs = !Array.isArray(result) ? [ result ] : result; + for (const output of outputs) { + if (pytorch.Utility.isTensor(output)) { + this._outputs.push(output.__variable__); + } + } return true; } else { diff --git a/test/models.json b/test/models.json index 2f1488fbea0..fa42a863e96 100644 --- a/test/models.json +++ b/test/models.json @@ -3838,7 +3838,6 @@ "target": "alexnet.pt", "script": "./tools/pytorch sync install zoo", "format": "TorchScript v1.4", - "error": "Unknown statement in 'alexnet.pt'.", "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { @@ -3846,7 +3845,6 @@ "target": "alexnet_traced.pt", "script": "./tools/pytorch sync install zoo", "format": "TorchScript v1.4", - "error": "Unknown statement in 'alexnet_traced.pt'.", "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { @@ -3936,7 +3934,6 @@ { "type": "pytorch", "target": "densenet161.pt", - "error": "Unknown statement in 'densenet161.pt'.", "render": "skip", "script": "./tools/pytorch sync install zoo", "format": "TorchScript v1.4", @@ -3946,7 +3943,6 @@ "type": "pytorch", "target": "densenet161_traced.pt", "render": "skip", - "error": "Unknown function argument in 'densenet161_traced.pt'.", "script": "./tools/pytorch sync install zoo", "format": "TorchScript v1.4", "link": "https://pytorch.org/docs/stable/torchvision/models.html" @@ -3968,7 +3964,7 @@ "type": "pytorch", "target": "inception_v3.pt", "script": "./tools/pytorch sync install zoo", - "error": "Unknown statement in 'inception_v3.pt'.", + "error": "Unknown expression type in 'inception_v3.pt'.", "format": "TorchScript v1.4", "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, @@ -3976,7 +3972,6 @@ "type": "pytorch", "target": "inception_v3_traced.pt", "script": "./tools/pytorch sync install zoo", - "error": "Unknown statement in 'inception_v3_traced.pt'.", "format": "TorchScript v1.4", "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, @@ -4068,7 +4063,6 @@ "type": "pytorch", "target": "mobilenet_v2.pt", "script": "./tools/pytorch sync install zoo", - "error": "Unknown statement in 'mobilenet_v2.pt'.", "format": "TorchScript v1.4", "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, @@ -4076,7 +4070,6 @@ "type": "pytorch", "target": "mobilenet_v2_traced.pt", "script": "./tools/pytorch sync install zoo", - "error": "Unknown function argument in 'mobilenet_v2_traced.pt'.", "format": "TorchScript v1.4", "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, @@ -4162,7 +4155,6 @@ "type": "pytorch", "target": "resnet18.pt", "script": "./tools/pytorch sync install zoo", - "error": "Unknown statement in 'resnet18.pt'.", "format": "TorchScript v1.4", "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, @@ -4170,7 +4162,6 @@ "type": "pytorch", "target": "resnet18_traced.pt", "script": "./tools/pytorch sync install zoo", - "error": "Unknown statement in 'resnet18_traced.pt'.", "format": "TorchScript v1.4", "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, @@ -4227,7 +4218,6 @@ "type": "pytorch", "target": "resnet101.pt", "script": "./tools/pytorch sync install zoo", - "error": "Unknown statement in 'resnet101.pt'.", "format": "TorchScript v1.4", "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, @@ -4235,7 +4225,6 @@ "type": "pytorch", "target": "resnet101_traced.pt", "script": "./tools/pytorch sync install zoo", - "error": "Unknown statement in 'resnet101_traced.pt'.", "format": "TorchScript v1.4", "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, @@ -4257,7 +4246,8 @@ "type": "pytorch", "target": "resnet101-5d3b4d8f.pth", "source": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", - "format": "PyTorch v0.1.1" + "format": "PyTorch v0.1.1", + "link": "https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py" }, { "type": "pytorch", @@ -4276,25 +4266,24 @@ { "type": "pytorch", "target": "shufflenet_v2_x1_0.pkl.pth", - "link": "https://pytorch.org/docs/stable/torchvision/models.html", "script": "./tools/pytorch sync install zoo", - "format": "PyTorch v0.1.10" + "format": "PyTorch v0.1.10", + "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { "type": "pytorch", "target": "shufflenet_v2_x1_0.zip.pth", - "link": "https://pytorch.org/docs/stable/torchvision/models.html", "script": "./tools/pytorch sync install zoo", - "format": "PyTorch v1.4" + "format": "PyTorch v1.4", + "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { "type": "pytorch", "target": "shufflenet_v2_x1_0.pt", - "link": "https://pytorch.org/docs/stable/torchvision/models.html", - "error": "Unknown statement in 'shufflenet_v2_x1_0.pt'.", "render": "skip", "script": "./tools/pytorch sync install zoo", - "format": "TorchScript v1.4" + "format": "TorchScript v1.4", + "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { "type": "pytorch", @@ -4305,32 +4294,30 @@ { "type": "pytorch", "target": "squeezenet1_1.pkl.pth", - "link": "https://pytorch.org/docs/stable/torchvision/models.html", "script": "./tools/pytorch sync install zoo", - "format": "PyTorch v0.1.10" + "format": "PyTorch v0.1.10", + "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { "type": "pytorch", "target": "squeezenet1_1.zip.pth", - "link": "https://pytorch.org/docs/stable/torchvision/models.html", "script": "./tools/pytorch sync install zoo", - "format": "PyTorch v1.4" + "format": "PyTorch v1.4", + "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { "type": "pytorch", "target": "squeezenet1_1.pt", - "link": "https://pytorch.org/docs/stable/torchvision/models.html", - "error": "Unknown statement in 'squeezenet1_1.pt'.", "script": "./tools/pytorch sync install zoo", - "format": "TorchScript v1.4" + "format": "TorchScript v1.4", + "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { "type": "pytorch", "target": "squeezenet1_1_traced.pt", - "link": "https://pytorch.org/docs/stable/torchvision/models.html", - "error": "Unknown statement in 'squeezenet1_1_traced.pt'.", "script": "./tools/pytorch sync install zoo", - "format": "TorchScript v1.4" + "format": "TorchScript v1.4", + "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { "type": "pytorch", @@ -4431,35 +4418,34 @@ { "type": "pytorch", "target": "vgg11_bn.pkl.pth", - "link": "https://pytorch.org/docs/stable/torchvision/models.html", "script": "./tools/pytorch sync install zoo", - "format": "PyTorch v0.1.10" + "format": "PyTorch v0.1.10", + "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { "type": "pytorch", "target": "vgg11_bn.zip.pth", - "link": "https://pytorch.org/docs/stable/torchvision/models.html", "script": "./tools/pytorch sync install zoo", - "format": "PyTorch v1.4" + "format": "PyTorch v1.4", + "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { "type": "pytorch", "target": "vgg16.pkl.pth", - "link": "https://pytorch.org/docs/stable/torchvision/models.html", "script": "./tools/pytorch sync install zoo", - "format": "PyTorch v0.1.10" + "format": "PyTorch v0.1.10", + "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { "type": "pytorch", "target": "vgg16.zip.pth", - "link": "https://pytorch.org/docs/stable/torchvision/models.html", "script": "./tools/pytorch sync install zoo", - "format": "PyTorch v1.4" + "format": "PyTorch v1.4", + "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { "type": "pytorch", "target": "vgg16.pt", - "error": "Unknown statement in 'vgg16.pt'.", "script": "./tools/pytorch sync install zoo", "format": "TorchScript v1.4", "link": "https://pytorch.org/docs/stable/torchvision/models.html" @@ -4467,7 +4453,6 @@ { "type": "pytorch", "target": "vgg16_traced.pt", - "error": "Unknown statement in 'vgg16_traced.pt'.", "script": "./tools/pytorch sync install zoo", "format": "TorchScript v1.4", "link": "https://pytorch.org/docs/stable/torchvision/models.html"