Skip to content

Commit

Permalink
TorchScript 1.4 prototype (#281)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jan 27, 2020
1 parent cce1991 commit 8990235
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 127 deletions.
31 changes: 6 additions & 25 deletions src/pytorch-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -1807,27 +1807,9 @@
"outputs": [
{
"name": "output"
}
]
}
},
{
"name": "torch.size",
"schema": {
"attributes": [
{
"name": "dim",
"type": "int64"
}
],
"inputs": [
{
"name": "input"
}
],
"outputs": [
},
{
"name": "output"
"name": "output2"
}
]
}
Expand Down Expand Up @@ -2001,16 +1983,15 @@
{
"name": "bidirectional",
"type": "boolean"
},
{
"name": "batch_first",
"type": "boolean"
}
],
"category": "Layer",
"inputs": [
{
"name": "input"
"name": "data"
},
{
"name": "batch_sizes"
},
{
"name": "hx"
Expand Down
173 changes: 112 additions & 61 deletions src/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ pytorch.Attribute = class {
}
}

if (Array.isArray(value) && value.length > 0 && value.every((obj) => obj.__module__ && obj.__module__.startsWith('torch.nn'))) {
if (Array.isArray(value) && value.length > 0 && value.every((obj) => obj && obj.__module__ && obj.__module__.startsWith('torch.nn'))) {
this._value = '?';
}
}
Expand Down Expand Up @@ -1023,6 +1023,7 @@ pytorch.Execution = class {
this._context.scope.builtins.type = { __module__: 'builtins', __name__: 'type' };
this._context.scope.builtins.module = { __module__: 'builtins', __name__: 'module', __class__: this._context.scope.builtins.type };
this._context.scope.builtins.function = { __module__: 'builtins', __name__: 'function', __class__:this._context.scope.builtins.type };
this._context.scope.builtins.method = { __module__: 'builtins', __name__: 'method', __class__:this._context.scope.builtins.type };
this._registerConstructor('argparse.Namespace', function (args) {
this.args = args;
});
Expand Down Expand Up @@ -1453,6 +1454,9 @@ pytorch.Execution = class {
this._registerFunction('unchecked_cast', function(type, value) {
return value;
});
this._registerFunction('ops.prim.data', function(tensor) {
return tensor;
});
this._registerFunction('ops.prim.unchecked_unwrap_optional', function(value) {
return value;
});
Expand Down Expand Up @@ -1502,6 +1506,10 @@ pytorch.Execution = class {
this._registerFunction('torch._unwrap_optional', function(value) {
return value; // TODO
});
this._registerFunction('torch.append', function(tensors, tensor) {
tensors.push(tensor);
return tensor;
});
this._registerFunction('torch.dim', function(tensor) {
if (tensor && tensor.size) {
return tensor.size.length;
Expand All @@ -1517,6 +1525,9 @@ pytorch.Execution = class {
}
throw new pytorch.Error('Unknown expression type.');
});
this._registerFunction('torch.floordiv', function(/* left, right */) {
return undefined;
});
this._registerFunction('torch.gt', function(left, right) {
if (typeof left === 'number' && typeof right === 'number') {
return left > right;
Expand All @@ -1538,6 +1549,9 @@ pytorch.Execution = class {
this._registerFunction('torch.len', function(/* value */) {
return undefined;
});
this._registerFunction('torch.list_with_default', function(size /*, defaults */) {
return size;
});
this._registerFunction('torch.lt', function(left, right) {
if (typeof left === 'number' && typeof right === 'number') {
return left < right;
Expand All @@ -1549,7 +1563,7 @@ pytorch.Execution = class {
return left * right;
}
if (pytorch.Utility.isTensor(left) && pytorch.Utility.isTensor(right)) {
return { __module__: 'torch', __name__: 'Tensor' };
return { __module__: 'torch', __name__: 'Tensor', __origin__: 'torch.mul' };
}
throw new pytorch.Error('Unknown expression type.');
});
Expand All @@ -1565,8 +1579,14 @@ pytorch.Execution = class {
this._registerFunction('torch.t', function(tensor) {
return tensor;
});
this._registerFunction('torch.size', function(tensor) {
if (tensor && tensor.size) {
return tensor.size;
}
return undefined;
});
this._registerFunction('uninitialized', function(type) {
return ({ __module__: 'torch', __name__: type });
return ({ __module__: 'torch', __name__: type, __origin__: 'uninitialized' });
});
/*
this._registerOperator('torch._convolution', 1);
Expand Down Expand Up @@ -1682,8 +1702,8 @@ pytorch.Execution = class {

_call(target, name, args, context) {
const callTarget = this._target(target, context);
const callArguments = args.map((argument) => this.expression(argument, context));
if (!callTarget || !callTarget[name]) {
let callArguments = args.map((argument) => this.expression(argument, context));
if (!callTarget || (name !== null && !callTarget[name])) {
const targetName = pytorch.Utility.target(target) + '.' + name;
if (this.type(targetName)) {
return this.invoke(targetName, callArguments);
Expand All @@ -1692,7 +1712,7 @@ pytorch.Execution = class {
return this._invokeCallback(targetName, args, context);
}
}
const func = callTarget[name];
const func = name ? callTarget[name] : callTarget;
if (func.__class__ === this._context.scope.builtins.type) {
let obj = {};
obj.__proto__ = func;
Expand All @@ -1702,6 +1722,11 @@ pytorch.Execution = class {
return obj;
}
if (func.__class__ === this._context.scope.builtins.function) {
if (func.__call__) {
return func.__call__(callArguments);
}
}
if (func.__class__ === this._context.scope.builtins.method) {
if (func.__call__) {
return func.__call__([ callTarget ].concat(callArguments));
}
Expand All @@ -1713,10 +1738,10 @@ pytorch.Execution = class {
}

apply(method, args, context) {
args = Array.prototype.slice.call(args);
let locals = Array.prototype.slice.call(args);
context = context.push();
for (const parameter of method.parameters) {
context.set(parameter.name, args.shift());
context.set(parameter.name, locals.shift());
}
return this._block(method.body.statements, context)
}
Expand All @@ -1735,8 +1760,19 @@ pytorch.Execution = class {
case 'def': {
const module = context.get('__name__');
const self = this;
const parent = context.get('__class__');
let type = null;
if (parent === this._context.scope.builtins.type) {
type = this._context.scope.builtins.method;
}
else if (parent === this._context.scope.builtins.module) {
type = this._context.scope.builtins.function;
}
else {
throw new pytorch.Error('Invalid function scope.');
}
const func = {
__class__: this._context.scope.builtins.function,
__class__: type,
__globals__: context,
__module__: module,
__name__: statement.name,
Expand Down Expand Up @@ -1874,17 +1910,21 @@ pytorch.Execution = class {
case 'call': {
if (expression.target.type === 'id' && expression.target.value === 'uninitialized' &&
expression.arguments.length === 1 && expression.arguments[0].type === 'id' && expression.arguments[0].value === 'Tensor') {
return { __module__: 'torch', __name__: 'Tensor' };
return { __module__: 'torch', __name__: 'Tensor', __origin__: 'uninitialized' };
}
if (expression.target.type === 'id' && expression.target.value === 'annotate' && expression.arguments.length === 2) {
return this.expression(expression.arguments[1], context);
}
if (expression.target.type === 'id' && expression.target.value === 'unchecked_cast' && expression.arguments.length === 2) {
return this.expression(expression.arguments[1], context);
}
if (expression.target.type === '.') {
return this._call(expression.target.target, expression.target.member.value, expression.arguments, context);
}
const target = this.expression(expression.target, context);
const args = expression.arguments.map((argument) => this.expression(argument, context));
return target.apply(self, args);
return this._call(expression.target, null, expression.arguments, context);
// const target = this.expression(expression.target, context);
// const args = expression.arguments.map((argument) => this.expression(argument, context));
// return target.apply(self, args);
}
case 'id': {
switch (expression.value) {
Expand All @@ -1893,19 +1933,7 @@ pytorch.Execution = class {
case 'True': return true;
case 'False': return false;
}
const value = context.get(expression.value);
if (value !== undefined) {
return value;
}
if (expression.value === 'Tensor') {
throw new Error("Unsupported '" + expression.value + "'.");
// return { __typeref__: expression.value };
}
if (expression.value === 'int') {
throw new Error("Unsupported '" + expression.value + "'.");
// return { __typeref__: expression.value };
}
break;
return context.get(expression.value);
}
case 'tuple': {
return expression.value.map((expression) => this.expression(expression, context));
Expand Down Expand Up @@ -2749,27 +2777,28 @@ pytorch.Container.Zip = class {
const schema = this._metadata.type(name);
if (schema) {
args = Array.prototype.slice.call(args);

let node = {};
node.type = name;
node.inputs = [];
node.outputs = [];
node.attributes = [];

let node = {
type: name,
inputs: [],
attributes: [],
outputs: []
};
for (let i = 0; i < schema.inputs.length; i++) {
let arg = args.shift();
const arg = args.shift();
const p = this._execution.expression(arg, context);
const parameters = Array.isArray(p) ? p : [ p ];
let inputs = [];

for (const parameter of parameters) {
if (parameter) {
if (!pytorch.Utility.isTensor(parameter)) {
throw new pytorch.Error('Tensor expected.');
}
if (parameter.__variable__) {
inputs.push({ id: parameter.__variable__ });
node.inputs.push([ ]);
}
else {
const id = this._variable().value;
parameter.__variable__ = id;
parameter.__outputs__ = parameter.__outputs__ || [];
parameter.__outputs__.push(id);
inputs.push({ id: id });
Expand All @@ -2778,34 +2807,38 @@ pytorch.Container.Zip = class {
}
node.inputs.push(inputs);
}
let outputs = []
for (let i = 0; i < schema.outputs.length; i++) {
let parameter = { __module__: 'torch', __name__: 'Tensor' };
parameter.__variable__ = this._variable().value;
outputs.push(parameter)
node.outputs.push(parameter.__variable__);
while (args.length > 0 && args[0].type !== '=') {
const arg = args.shift();
const value = this._execution.expression(arg, context);
node.attributes.push(value);
}
/*
while (args.length > 0) {
let argument = args[0]
if (pytorch.Utility.isTensor(argument)) {
node.inputs.push([ argument ]);
args.shift();
continue;
const arg = args.shift();
if (arg.type === '=' && arg.target && arg.target.type === 'id') {
const value = this._execution.expression(arg.expression, context);
node.attributes.push({ type: '=', target: arg.target, expression: value });
}
if (Array.isArray(argument) && argument.every((tensor) => pytorch.Utility.isTensor(tensor))) {
node.inputs.push([ argument ]);
args.shift();
continue;
else {
throw new pytorch.Attribute('Expected named argument.');
}
break;
}
while (args.length > 0) {
let argument = args[0]
node.attributes.push(argument);
args.shift();
let outputs = []
for (let i = 0; i < schema.outputs.length; i++) {
let parameter = { __module__: 'torch', __name__: 'Tensor', __origin__: 'invoke-output-' + name };
switch (name) {
case 'torch.cat':
case 'torch.conv2d':
case 'torch.flatten':
case 'torch.relu_':
case 'torch.dropout': {
parameter.size = [ undefined, undefined, undefined, undefined ];
break;
}
}
parameter.__variable__ = this._variable().value;
outputs.push(parameter)
node.outputs.push(parameter.__variable__);
}
*/
this._nodes.push(node);
if (outputs.length > 1) {
return outputs;
Expand Down Expand Up @@ -2890,8 +2923,8 @@ pytorch.Container.Zip = class {
}

trace() {
// this._trace = true;
if (this._trace) {
this._trace = false;
if (this._trace || this.format == 'TorchScript v1.4') {
this._inputs = [];
this._outputs = [];
this._nodes = [];
Expand All @@ -2914,7 +2947,25 @@ pytorch.Container.Zip = class {
}
}
if (this.data.forward) {
this.data.forward.__call__([ this.data, { __module__: 'torch', __name__: 'Tensor' } ]);
let args = [ this.data ]; // self
if (this.data.forward.__code__ && this.data.forward.__code__.parameters) {
for (const parameter of this.data.forward.__code__.parameters) {
if (parameter.name !== 'self' &&
parameter.parameterType.type === 'type' &&
parameter.parameterType.name.type === 'id' &&
parameter.parameterType.name.value === 'Tensor') {
this._inputs.push(parameter.name);
args.push({ __module__: 'torch', __name__: 'Tensor', __variable__: parameter.name, __origin__: 'trace-input' });
}
}
}
const result = this.data.forward.__call__(args);
const outputs = !Array.isArray(result) ? [ result ] : result;
for (const output of outputs) {
if (pytorch.Utility.isTensor(output)) {
this._outputs.push(output.__variable__);
}
}
return true;
}
else {
Expand Down
Loading

0 comments on commit 8990235

Please sign in to comment.