Skip to content

Commit

Permalink
Add PyTorch null check
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Nov 3, 2019
1 parent 13ef12f commit 7376a62
Showing 1 changed file with 30 additions and 28 deletions.
58 changes: 30 additions & 28 deletions src/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -891,36 +891,38 @@ pytorch.Graph = class {
}

for (let module of parent._modules) {
let node;
switch (module.value.__type__) {
case 'torch.nn.modules.container.Sequential':
groups.push(module.key);
inputs = this._loadModule(module.value, module_source_map, groups, inputs);
groups.pop(module.key);
break;
case 'torchvision.models.densenet._Transition':
case 'torchvision.models.resnet.Bottleneck':
case 'torchvision.models.densenet._DenseBlock':
case 'torchvision.models.densenet._DenseLayer':
case 'torchvision.models.inception.BasicConv2d':
case 'torchvision.models.inception.InceptionAux':
case 'torchvision.models.inception.InceptionA':
case 'torchvision.models.inception.InceptionB':
case 'torchvision.models.inception.InceptionC':
case 'torchvision.models.inception.InceptionD':
case 'torchvision.models.inception.InceptionE':
groups.push(module.key);
node = this._createNode(groups, module.key, module.value, inputs, this._littleEndian);
inputs = [ node.name ];
groups.pop(module.key);
break;
default:
node = this._createNode(groups, module.key, module.value, inputs);
inputs = [ node.name ];
break;
if (module && module.value) {
switch (module.value.__type__) {
case 'torch.nn.modules.container.Sequential':
groups.push(module.key);
inputs = this._loadModule(module.value, module_source_map, groups, inputs);
groups.pop(module.key);
break;
case 'torchvision.models.densenet._Transition':
case 'torchvision.models.resnet.Bottleneck':
case 'torchvision.models.densenet._DenseBlock':
case 'torchvision.models.densenet._DenseLayer':
case 'torchvision.models.inception.BasicConv2d':
case 'torchvision.models.inception.InceptionAux':
case 'torchvision.models.inception.InceptionA':
case 'torchvision.models.inception.InceptionB':
case 'torchvision.models.inception.InceptionC':
case 'torchvision.models.inception.InceptionD':
case 'torchvision.models.inception.InceptionE': {
groups.push(module.key);
const node = this._createNode(groups, module.key, module.value, inputs, this._littleEndian);
inputs = [ node.name ];
groups.pop(module.key);
break;
}
default: {
const node = this._createNode(groups, module.key, module.value, inputs);
inputs = [ node.name ];
break;
}
}
}
}

return inputs;
}

Expand Down

0 comments on commit 7376a62

Please sign in to comment.