Skip to content

Commit

Permalink
fix for webgl lrn (microsoft#15236)
Browse files Browse the repository at this point in the history
fix issue that resulted in wrong results for lrn on webgpu
  • Loading branch information
guschmue authored Mar 30, 2023
1 parent 9f942e1 commit 4645726
Showing 1 changed file with 4 additions and 15 deletions.
19 changes: 4 additions & 15 deletions js/web/lib/onnxjs/backends/webgl/ops/lrn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,14 @@ const lrnProgramMetadata = {
inputTypes: [TextureType.unpacked]
};

function getOutputExpression(attributes: LrnAttributes): string {
let expression = `float(${attributes.bias}) + float(${attributes.alpha}) * square_sum`;
if (attributes.beta === 0.5) {
expression = `inversesqrt(${expression})`;
} else if (attributes.beta === 1.0) {
expression = `1.0/(${expression})`;
} else {
expression = `exp(log(${expression})) * float(-${attributes.beta})`;
}
return `x * ${expression}`;
}

function createLrnProgramInfo(inputs: Tensor[], attributes: LrnAttributes): ProgramInfo {
const C = inputs[0].dims[1];

const rank = inputs[0].dims.length;
const from = -Math.floor((attributes.size - 1) / 2);
const to = Math.ceil((attributes.size - 1) / 2);
const alpha = `float(${attributes.alpha}) / float(${attributes.size})`;
const bias = `float(${attributes.bias})`;
const beta = `float(${attributes.beta})`;

const shaderSource = `
float process(int indices[${rank}]) {
Expand All @@ -75,8 +65,7 @@ function createLrnProgramInfo(inputs: Tensor[], attributes: LrnAttributes): Prog
square_sum += j * j;
}
}
return ${getOutputExpression(attributes)};
return x / pow(${bias} + ${alpha} * square_sum, ${beta});
}`;
return {
...lrnProgramMetadata,
Expand Down

0 comments on commit 4645726

Please sign in to comment.