Skip to content

Commit

Permalink
Update to TorchScript 1.3 prototype (#281)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 21, 2019
1 parent 080bca8 commit 37dabf1
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 69 deletions.
225 changes: 156 additions & 69 deletions src/torchscript.js
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,48 @@ torchscript.Model = class {
torchscript.Graph = class {

constructor(metadata, host, container) {
if (container.model && container.model.mainModule) {
this._name = container.model.mainModule.name;
}
this._inputs = [];
this._outputs = [];
this._nodes = [];

let tensors = [];
let constants = [];

if (container.model && container.model.tensors) {
tensors = container.model.tensors.map((tensor) => {
return new torchscript.Tensor('json', tensor);
});
if (container.model) {
if (container.model.mainModule) {
this._name = container.model.mainModule.name;
/*
let queue = [ container.model.mainModule ];
while (queue.length > 0) {
let module = queue.shift();
module.__type__ = module.__type__ || 'torch.Module';
if (module.submodules) {
for (let submodule of module.submodules) {
module[submodule.name] = submodule;
submodule.__parent__ = module;
queue.push(submodule);
}
delete module.submodules;
}
if (module.parameters) {
for (let parameter of module.parameters) {
module[parameter.name] = parameter;
parameter.__type__ = parameter.__type__ || 'torch.Tensor';
}
delete module.parameters;
}
if (module.arguments) {
debugger;
}
}
debugger;
*/
}
if (container.model.tensors) {
tensors = container.model.tensors.map((tensor) => {
return new torchscript.Tensor('json', tensor);
});
}
}

if (container.constants) {
Expand Down Expand Up @@ -1035,43 +1063,75 @@ torchscript.Container = class {
this._version = JSON.parse(this._utf8Decoder.decode(versionEntry.data));

this._functionTable = new Map();

this._functionTable.set('annotate', function(type, value) {
return value;
});
this._functionTable.set('collections.OrderedDict', function(args) {
let obj = [];
obj.__setitem__ = function(key, value) {
obj.push({ key: key, value: value });
};
if (args) {
for (let arg of args) {
obj.__setitem__(arg[0], arg[1]);
}
}
return obj;
});
this._functionTable.set('int', function(/* tensor */) {
return 0; // TODO
});
this._functionTable.set('getattr', function(obj, name, defaultValue) {
if (Object.prototype.hasOwnProperty.call(obj, name)) {
return obj[name];
}
return defaultValue;
});
this._functionTable.set('annotate', function(type, value) {
// debugger;
return value;
});
this._functionTable.set('uninitialized', function(type) {
return ({ __type__: type.__typeref__ });
});
this._functionTable.set('unchecked_cast', function(type, value) {
return value;
});
this._functionTable.set('ops.prim.unchecked_unwrap_optional', function(value) {
return value;
});
this._functionTable.set('ops.prim.NumToTensor', function(value) {
return { __type__: 'torch.Tensor', value: value }; // TODO
});
this._functionTable.set('ops.quantized.conv_prepack', function(/* weight, bias, stride, padding, dilation, groups */) {
return { __type__: '__conv_prepack__' }; // TODO
});
this._functionTable.set('ops.quantized.linear_prepack', function(/* weight, bias */) {
return { __type__: '__linear_prepack__' }; // TODO
});
this._functionTable.set('collections.OrderedDict', function(args) {
let obj = [];
obj.__setitem__ = function(key, value) {
obj.push({ key: key, value: value });
};
if (args) {
for (let arg of args) {
obj.__setitem__(arg[0], arg[1]);
}

this._functionTable.set('ops.prim.RaiseException', function(message) {
throw new torchscript.Error(message);
});
this._functionTable.set('torch.__is__', function(left, right) {
if (left === null && right === null) {
return true;
}
return obj;
if ((left !== null && right === null) || (left === null && right !== null)) {
return false;
}
throw new torchscript.Error('Unknown expression type.');
});
this._functionTable.set('torch.__isnot__', function(left, right) {
if (left === null && right === null) {
return false;
}
if ((left !== null && right === null) || (left === null && right !== null)) {
return true;
}
throw new torchscript.Error('Unknown expression type.');
});
this._functionTable.set('torch.__not__', function(value) {
if (typeof value === 'boolean') {
return !value;
}
throw new torchscript.Error('Unknown expression type.');
});
this._functionTable.set('torch._unwrap_optional', function(value) {
return value; // TODO
});
this._functionTable.set('torch._utils._rebuild_tensor_v2', function(storage, storage_offset, size, stride, requires_grad, backward_hooks) {
return {
Expand All @@ -1096,18 +1156,34 @@ torchscript.Container = class {
backward_hooks: backward_hooks
};
});
this._functionTable.set('torch.jit._pickle.build_intlist', function(data) {
return data;
});
this._functionTable.set('torch.jit._pickle.build_tensorlist', function(data) {
return data;

this._functionTable.set('torch.dim', function(tensor) {
if (tensor && tensor.size) {
return tensor.size.length;
}
return 0; // TODO
});
this._functionTable.set('torch.eq', function(left, right) {
if (typeof left === 'string' && typeof right === 'string') {
return left === right;
}
if (typeof left === 'number' && typeof right === 'number') {
return left === right;
}
throw new torchscript.Error('Unknown expression type.');
});
this._functionTable.set('torch.gt', function(left, right) {
if (typeof left === 'number' && typeof right === 'number') {
return left > right;
}
throw new torchscript.Error('Unknown expression type.');
});
this._functionTable.set('torch.jit._pickle.build_intlist', function(data) {
return data;
});
this._functionTable.set('torch.jit._pickle.build_tensorlist', function(data) {
return data;
});
this._functionTable.set('torch.lt', function(left, right) {
if (typeof left === 'number' && typeof right === 'number') {
return left < right;
Expand All @@ -1120,25 +1196,21 @@ torchscript.Container = class {
}
throw new torchscript.Error('Unknown expression type.');
});
this._functionTable.set('torch.gt', function(left, right) {
this._functionTable.set('torch.ne', function(left, right) {
if (typeof left === 'number' && typeof right === 'number') {
return left > right;
return left !== right;
}
throw new torchscript.Error('Unknown expression type.');
});
this._functionTable.set('torch.__is__', function(left, right) {
if (left === null && right === null) {
return true;
}
if ((left !== null && right === null) || (left === null && right !== null)) {
return false;
}
// debugger;
throw new torchscript.Error('Unknown expression type.');
});
this._functionTable.set('torch.q_scale', function(/* tensor */) {
return -1; // TODO
});
this._functionTable.set('torch.t', function(tensor) {
return tensor;
});
this._functionTable.set('uninitialized', function(type) {
return ({ __type__: type.__typeref__ });
});
this._constructorTable = new Map();
this._constructorTable.set('torch.ByteStorage', function (size) {
this.size = size; this.dataTypeSize = 1; this.dataType = 'uint8';
Expand Down Expand Up @@ -1325,28 +1397,37 @@ torchscript.Container = class {
this._tensors = null;
}

_trace(statement /*, obj, locals */) {
if (this._tensors) {
if (statement.type === '=') {
const target = statement.target;
const expression = statement.expression;
if (target.type === 'id' && expression.type === 'call') {
let name = torchscript.Utility.target(expression.target);
let namespace = 'torch.';
if (name.startsWith(namespace)) {
switch (name) {
case 'torch.conv2d':
case 'torch.relu_':
case 'torch.relu':
case 'torch.max_pool2d':
case 'torch.view':
return true;
}
}
}
_trace(name /*, args */) {
let namespace = 'torch.';
if (name.startsWith(namespace)) {
switch (name) {
case 'torch.conv2d':
return { __type__: 'Tensor', size: [ 0, 0, 0, 0 ] }; // TODO
case 'torch._convolution':
case 'torch.addmm':
case 'torch.relu_':
case 'torch.relu':
case 'torch.max_pool2d':
case 'torch.view':
case 'torch.matmul':
case 'torch.flatten':
case 'torch.add_':
case 'torch.slice':
case 'torch.log_softmax':
case 'torch.dropout':
case 'torch.adaptive_avg_pool2d':
case 'torch.batch_norm':
case 'torch.cat':
return { __type__: 'Tensor' }; // TODO
case 'torch.max_pool2d_with_indices':
return [ { __type__: 'Tensor' }, { __type__: 'Tensor' } ]; // TODO
case 'torch.list_with_default':
return [0]; // TODO
case 'torch.size':
return 0; // TODO
}
}
return false;
throw new torchscript.Error("Unknown symbol '" + name + "'.");
}

_invoke(name, args) {
Expand All @@ -1365,7 +1446,7 @@ torchscript.Container = class {
this._construct(type, obj, args);
return obj;
}
throw new torchscript.Error("Unknown symbol '" + name + "'.");
return this._trace(name, args);
}

_construct(type, obj, args) {
Expand Down Expand Up @@ -1393,12 +1474,12 @@ torchscript.Container = class {
let statements = Array.prototype.slice.call(block.statements);
while (statements.length > 0) {
const statement = statements.shift();
if (this._trace(statement, obj, locals)) {
continue;
}
switch (statement.type) {
case 'pass': {
break;
}
case 'return': {
return this._expression(statement.expression);
return this._expression(statement.expression, obj, locals);
}
case 'def': {
const method = statement;
Expand Down Expand Up @@ -1427,6 +1508,10 @@ torchscript.Container = class {
}
throw new torchscript.Error("Unknown condition '" + condition + "'.");
}
case 'call': {
this._expression(statement, obj, locals);
break;
}
default: {
throw new torchscript.Error("Unknown statement '" + statement + "'.");
}
Expand Down Expand Up @@ -1487,8 +1572,10 @@ torchscript.Container = class {
const index = this._expression(expression.arguments.value[0], obj, locals);
return locals[expression.target.value][index];
}
if (expression.target.value === 'List' && expression.arguments.value.every((item) => item.type === 'id')) {
return { __typeref__: expression.target.value + '[' + expression.arguments.value.map((item) => item.value).join(',') + ']' };
if (expression.target.value === 'List' || expression.target.value === 'Optional') {
if (expression.arguments.value.every((item) => item.type === 'id')) {
return { __typeref__: expression.target.value + '[' + expression.arguments.value.map((item) => item.value).join(',') + ']' };
}
}
}
break;
Expand Down
7 changes: 7 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -5032,6 +5032,13 @@
"format": "TorchScript v1",
"link": "https://pytorch.org/docs/stable/torchvision/models.html"
},
{
"type": "torchscript",
"target": "blitz_cifar10_tutorial.pt",
"source": "https://github.com/lutzroeder/netron/files/3748500/blitz_cifar10_tutorial.zip[blitz_cifar10_tutorial.pt]",
"format": "TorchScript v1",
"link": "https://github.com/lutzroeder/netron/issues/281"
},
{
"type": "torchscript",
"target": "cruise_cutin_vehicle_model.pt",
Expand Down

0 comments on commit 37dabf1

Please sign in to comment.