PyTorch qint8 support (#133)

lutzroeder committed Nov 9, 2019
1 parent eb392ec commit d13489c
Showing 3 changed files with 32 additions and 2 deletions.
23 changes: 22 additions & 1 deletion src/pytorch.js
@@ -213,6 +213,9 @@ pytorch.ModelFactory = class {
constructorTable['torch.DoubleStorage'] = function (size) {
this.size = size; this.dataTypeSize = 8; this.dataType = 'float64';
};
constructorTable['torch.QInt8Storage'] = function (size) {
this.size = size; this.dataTypeSize = 1; this.dataType = 'qint8';
};
constructorTable['torch.FloatTensor'] = function () {
this.__setstate__ = function(state) {
this.storage = state[0];
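The new torch.QInt8Storage entry follows the same pattern as the surrounding storage constructors: it records the element count, a per-element size of one byte, and the 'qint8' data type string that the tensor reader switches on later. As a rough sketch (constructorTable and the property names come from the diff; makeStorage and its arguments are hypothetical), such a table entry could be applied to an unpickled storage object like this:

// Illustrative sketch, not part of the commit.
const constructorTable = {};
constructorTable['torch.QInt8Storage'] = function (size) {
    this.size = size; this.dataTypeSize = 1; this.dataType = 'qint8';
};

function makeStorage(typeName, args) {
    const obj = { __type__: typeName };
    const constructor = constructorTable[typeName];
    if (constructor) {
        constructor.apply(obj, args); // runs the table entry with obj as 'this'
    }
    return obj;
}

// makeStorage('torch.QInt8Storage', [ 64 ]) ->
// { __type__: 'torch.QInt8Storage', size: 64, dataTypeSize: 1, dataType: 'qint8' }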
@@ -364,6 +367,18 @@ pytorch.ModelFactory = class {
obj.backward_hooks = backward_hooks;
return obj;
};
functionTable['torch._utils._rebuild_qtensor'] = function(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks) {
return {
__type__: storage.__type__.replace('Storage', 'Tensor'),
storage: storage,
storage_offset: storage_offset,
size: size,
stride: stride,
quantizer_params: quantizer_params,
requires_grad: requires_grad,
backward_hooks: backward_hooks
};
};
functionTable['numpy.core.multiarray.scalar'] = function(dtype, rawData) {
let data = rawData;
if (rawData.constructor !== Uint8Array) {
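The new torch._utils._rebuild_qtensor handler takes one more argument, quantizer_params, than the existing tensor rebuild handler whose tail is visible just above, and packages everything into a tensor object whose __type__ is derived from the storage type. Purely as an illustration (the field names follow the diff; the sample values and the assumed (scheme, scale, zero_point) layout of quantizer_params are not taken from the commit), a per-tensor affine qint8 tensor would come out roughly as:

// Illustrative only; the sample values are invented.
const qtensor = {
    __type__: 'torch.QInt8Tensor', // 'torch.QInt8Storage' with 'Storage' replaced by 'Tensor'
    storage: { __type__: 'torch.QInt8Storage', size: 9, dataTypeSize: 1, dataType: 'qint8' },
    storage_offset: 0,
    size: [ 3, 3 ],
    stride: [ 3, 1 ],
    quantizer_params: [ 'torch.per_tensor_affine', 0.05, 0 ], // assumed: scheme, scale, zero point
    requires_grad: false,
    backward_hooks: null
};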
@@ -409,7 +424,7 @@ pytorch.ModelFactory = class {
if (constructor) {
constructor.apply(obj, args);
}
else if (name && unknownNameMap.has(name)) {
else if (name && !unknownNameMap.has(name)) {
unknownNameMap.add(name);
if (knownPackageMap.has(name.split('.').shift())) {
host.exception(new pytorch.Error("Unknown function '" + name + "' in '" + identifier + "'."), false);
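The guard change flips unknownNameMap.has(name) to !unknownNameMap.has(name), so that an unknown constructor name is reported the first time it is seen and then remembered, instead of only being reported once it is already in the set. A small sketch of the intended once-per-name behaviour (report stands in for the host.exception call):

// Illustrative only: the corrected guard reports each unknown name once.
const unknownNameMap = new Set();
function report(name) {
    if (name && !unknownNameMap.has(name)) {
        unknownNameMap.add(name);
        console.log("Unknown function '" + name + "'.");
    }
}
report('some.unknown.Type'); // logged
report('some.unknown.Type'); // already seen, skipped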
@@ -696,6 +711,11 @@ pytorch.ModelFactory = class {
state.id = key;
state.name = split.pop();
state.value = obj[key];
if (state.value && state.value.__type__ === 'torch.nn.parameter.Parameter') {
if (pytorch.ModelFactory._isTensor(state.value.data)) {
state.value = state.value.data;
}
}
if (!pytorch.ModelFactory._isTensor(state.value)) {
return null;
}
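The added block unwraps torch.nn.parameter.Parameter objects, replacing the state value with the tensor held in the parameter's data field so that it passes the _isTensor check that follows. A sketch of the shape of such an entry before and after unwrapping (the sample type and size are invented):

// Illustrative only. The real code also verifies the payload with
// pytorch.ModelFactory._isTensor before unwrapping.
let value = {
    __type__: 'torch.nn.parameter.Parameter',
    data: { __type__: 'torch.FloatTensor', size: [ 64, 3, 7, 7 ] } // the wrapped tensor
};
if (value && value.__type__ === 'torch.nn.parameter.Parameter') {
    value = value.data; // keep the tensor, drop the Parameter wrapper
}
// value now refers to the tensor object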
@@ -1258,6 +1278,7 @@ pytorch.Tensor = class {
context.index++;
context.count++;
break;
case 'qint8':
case 'int8':
results.push(context.dataView.getInt8(context.index, this._littleEndian));
context.index++;
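In the tensor value decoder, 'qint8' falls through to the existing 'int8' case, so quantized values surface as raw signed bytes. A hedged sketch of how one such byte could be read and, given a scale and zero point (for example from quantizer_params), dequantized; the function and its arguments are hypothetical and not part of the commit:

// Illustrative only: read a qint8 value and apply affine dequantization.
function readQInt8(dataView, index, scale, zeroPoint) {
    const q = dataView.getInt8(index); // qint8 is stored as one signed byte
    return (q - zeroPoint) * scale;    // real value = (q - zero_point) * scale
}

const view = new DataView(new Int8Array([ -13 ]).buffer);
readQInt8(view, 0, 0.5, 0); // -6.5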
4 changes: 3 additions & 1 deletion src/view.js
@@ -948,9 +948,10 @@ view.View = class {
this._host.save('NumPy Array', 'npy', defaultPath, (file) => {
try {
const dataTypeMap = new Map([
[ 'float16', 'f2' ], [ 'float32', 'f4' ], [ 'float64', 'f8' ],
[ 'int8', 'i1' ], [ 'int16', 'i2'], [ 'int32', 'i4' ], [ 'int64', 'i8' ],
[ 'uint8', 'u1' ], [ 'uint16', 'u2' ], [ 'uint32', 'u4' ], [ 'uint64', 'u8' ],
[ 'float16', 'f2' ], [ 'float32', 'f4' ], [ 'float64', 'f8' ]
[ 'qint8', 'i1' ]
]);
let array = new numpy.Array();
array.shape = tensor.type.shape.dimensions;
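For the NumPy export, 'qint8' is mapped to the type code 'i1', so the raw quantized bytes are written out and the scale and zero point do not travel into the .npy file. A sketch of how a dtype descriptor could be derived from this map (the map entries mirror the diff, abbreviated; npyDescriptor and the byte-order prefix handling are illustrative assumptions):

// Illustrative only; abbreviated map.
const dataTypeMap = new Map([
    [ 'float32', 'f4' ], [ 'int8', 'i1' ], [ 'qint8', 'i1' ]
]);

function npyDescriptor(dataType, littleEndian) {
    if (!dataTypeMap.has(dataType)) {
        throw new Error("Unsupported data type '" + dataType + "'.");
    }
    return (littleEndian ? '<' : '>') + dataTypeMap.get(dataType);
}

// npyDescriptor('qint8', true) -> '<i1' (raw quantized bytes; quantization parameters are lost)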
@@ -1132,6 +1133,7 @@ class ArchiveContext {
}

class ArchiveError extends Error {

constructor(message) {
super(message);
this.name = 'Error loading archive.';
7 changes: 7 additions & 0 deletions test/models.json
@@ -3854,6 +3854,13 @@
"format": "PyTorch",
"link": "https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py"
},
{
"type": "pytorch",
"target": "resnet18_fbgemm_16fa66dd.pth",
"source": "https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
"format": "PyTorch",
"link": "https://github.com/pytorch/vision/blob/master/torchvision/models/quantization/resnet.py"
},
{
"type": "pytorch",
"target": "resnet18_large_blocks.pth",
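The test manifest gains an entry for the quantized ResNet-18 weights published for the FBGEMM backend, presumably so the new qint8 path is covered by the model tests. A hedged sketch of how such an entry might be consumed (the field names mirror the diff; the driver code itself is hypothetical):

// Illustrative only: list the PyTorch targets declared in test/models.json.
const fs = require('fs');
const models = JSON.parse(fs.readFileSync('test/models.json', 'utf-8'));
for (const entry of models.filter((item) => item.type === 'pytorch')) {
    console.log(entry.target + ' <- ' + entry.source);
}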
