diff --git a/src/ops/linalg_ops.ts b/src/ops/linalg_ops.ts index 7a15b94e16..c401a2dd14 100644 --- a/src/ops/linalg_ops.ts +++ b/src/ops/linalg_ops.ts @@ -20,15 +20,21 @@ */ import {ENV} from '../environment'; -import {dispose} from '../globals'; +import {range, scalar} from './tensor_ops'; import {Tensor, Tensor1D, Tensor2D} from '../tensor'; +import {TensorLike, TypedArray} from '../types'; +import {add, mul, sub} from './binary_ops'; +import {logicalAnd} from './logical_ops'; +import {complex, real, imag} from './complex_ops'; import {assert} from '../util'; -import {eye, squeeze, stack, unstack} from './array_ops'; +import {convertToTensor} from '../tensor_util_env'; +import {squeeze, stack} from './array_ops'; import {split} from './concat_split'; +import {matMul} from './matmul'; import {norm} from './norm'; import {op} from './operation'; import {sum} from './reduction_ops'; -import {tensor2d} from './tensor_ops'; +import {upcastType} from '../types'; /** * Gram-Schmidt orthogonalization. @@ -106,12 +112,770 @@ function gramSchmidt_(xs: Tensor1D[]|Tensor2D): Tensor1D[]|Tensor2D { } } +/** + * Conjugates a tensor of matrices and then transposes the last two dimensions. + * The adjoint is also commonly known as the Hermitian Transpose. + * + * ```js + * const a = tf.tensor3d([[[1, 2], + * [3, 4]], + * [[5, 6], + * [7, 8]]]); + * const aT = tf.linalg.adjoint(a); + * aT.print(); + * // Output: + * // [[[1, 3], + * // [2, 4]], + * // [[5, 7], + * // [6, 8]]] + * ``` + * + * @param a Tensor of shape [...,M,N]. The tensor of matrices that is to be + * tranposed. + * + * @returns Tensor of shape [...,N,M]. The transpose of `a`. + */ +/** + * @doc {heading:'Operations', + * subheading:'Linear Algebra', + * namespace:'linalg'} + */ +function adjoint_( a: T|TensorLike ): T +{ + let $a = convertToTensor(a,'a','bandPart'); + + if( $a.rank < 2 ) { + throw new Error(`adjoint(): a.rank = ${$a.rank} < 2.`); + } + + const axes = Array.from( $a.shape, (_,i) => i ); + axes[axes.length-2] = axes.length-1; + axes[axes.length-1] = axes.length-2; + + if( $a.dtype.startsWith('complex') ) { + $a = complex( real($a), imag($a).neg() ); // <- TODO: implement tf.conj + } + + return $a.transpose(axes); +} + +/** + * Copies a tensor of matrices, setting everything outside a central band + * in each matrix to zero. Does not yet support Infinity or NaN entries. + * + * ```js + * const a = tf.tensor2d([[11, 12, 13, 14], + * [21, 22, 23, 24], + * [31, 32, 33, 34], + * [41, 42, 43, 44]]); + * tf.linalg.bandPart(a,0,2); + * // Output: + * // [[11, 12, 13, 0], + * // [ 0, 22, 23, 24], + * // [ 0, 0, 33, 34], + * // [ 0, 0, 0, 44]] + * + * tf.linalg.bandPart(a,1,-1); + * // Output: + * // [[11, 12, 13, 14], + * // [21, 22, 23, 24], + * // [ 0, 32, 33, 34], + * // [ 0, 0, 43, 44]] + * ``` + * + * @param a Tensor of matrices from which the band part is extracted. + * @param numLower The number of subdiagonal lines to be copied. + * If set to `-1`, all entries below the diagonal are + * copied. + * @param numUpper The number of superdiagonal lines to be copied. + * If set to `-1`, all entries above the diagonal are + * copied. + */ +/** + * @doc {heading:'Operations', + * subheading:'Linear Algebra', + * namespace:'linalg'} + */ +function bandPart_( + a: T|TensorLike, numLower: number, numUpper: number +): T +{ + if( numLower%1 !== 0 ){ + throw new Error(`bandPart(): numLower=${numLower} not an integer.`); + } + if( numUpper%1 !== 0 ){ + throw new Error(`bandPart(): numUpper=${numUpper} not an integer.`); + } + + return ENV.engine.tidy( () => { + const $a = convertToTensor(a,'a','bandPart'); + + if( $a.rank < 2 ) { + throw new Error(`bandPart(): a.rank = ${$a.rank} < 2.`); + } + + if( ! isFinite($a.abs().max().dataSync()[0]) ) { + throw new Error(`bandPart(): NaN and Infinity not yet supported.`); + } + + const [M,N] = $a.shape.slice(-2); + + if( !(numLower <= M) ) { + throw new Error(`bandPart() check failed: numLower <= #rows.` ); + } + if( !(numUpper <= N) ) { + throw new Error(`bandPart() check failed: numUpper <= #columns.`); + } + + if( numLower < 0 ) { numLower = M; } + if( numUpper < 0 ) { numUpper = N; } + + const i = range(0,M, 1, 'int32').reshape([-1,1]), + j = range(0,N, 1, 'int32'); + + const inBand = logicalAnd( + sub(i,j).lessEqual( scalar(numLower,'int32') ), + sub(j,i).lessEqual( scalar(numUpper,'int32') ) + ).cast($a.dtype); + + return mul($a,inBand); + }); +} + +function triangularSolveKernel( + l: Tensor, y: Tensor, lower: boolean, adjoint: boolean +): Tensor +{ + if( ! l.dtype.startsWith('float') ) { + throw new Error(`triangularSolve(): l.dtype=${l.dtype} not supported.`); + } + if( ! y.dtype.startsWith('float') ) { + throw new Error(`triangularSolve(): y.dtype=${y.dtype} not supported.`); + } + if( l.rank < 2 ) { + throw new Error('triangularSolve(): l must be at least 2D.'); + } + if( y.rank < 2 ) { + throw new Error('triangularSolve(): y must be at least 2D.'); + } + if( l.rank !== y.rank ) { + throw new Error('triangularSolve(): l and y must have same rank.'); + } + for( let i=l.rank-2; i-- > 0; ) { + if( l.shape[i] !== y.shape[i] ) { + throw new Error('triangularSolve(): leading dimensions do not match.'); + } + } + + const [N,M] = l.shape.slice(-2), + [I,J] = y.shape.slice(-2); + if( N !== M ) { + throw new Error('triangularSolve(): Last two axes of L not square.'); + } + if( I !== M ) { + throw new Error('triangularSolve(): L and y do not match.'); + } + + const + rank = l.rank, + xShape = Array.from(l.shape); + xShape[rank-2] = I; + xShape[rank-1] = J; + + // GENERATE RESULT DATA + const + dtype = 'float32', +// dtype = ( l.dtype === 'float64' || +// y.dtype === 'float64' ) ? 'float64' : 'float32', + // tslint:disable + DTypeArray = Float32Array, + // tslint:enable +// DTypeArray = dtype === 'float32' ? Float32Array +// : Float64Array, + L = l.dataSync(), + X = DTypeArray.from( y.dataSync() ) as TypedArray; + l = undefined; + y = undefined; + + for( let lOff = 0, + xOff = 0; xOff < X.length; xOff += N*J, + lOff += N*N ) + { + if( ! adjoint ) + { + if(lower) + { // FORWARD SUBSTITUTION + for( let i=0; i < I; i++ ) { + for( let k=0; k < i; k++ ) { + for( let j=0; j < J; j++ ) { + X[xOff + J*i+j] -= L[lOff + N*i+k] * X[xOff + J*k+j]; + }} + + for( let j=0; j < J; j++ ) { + X[xOff + J*i+j] /= L[lOff + N*i+i]; + } + } + } + else + { // BACKWARD SUBSTITUTION + for( let i=I; i-- > 0; ) { + for( let j=J; j-- > 0; ) { + X[xOff + J*i+j] /= L[lOff + N*i+i]; + } + + for( let k=i; k-- > 0; ) { + for( let j=J; j-- > 0; ) { + X[xOff + J*k+j] -= L[lOff + N*k+i] * X[xOff + J*i+j]; + }} + } + } + } + else + { + if(lower) + { // BACKWARD SUBSTITUTION (TRANSPOSED) + for( let i=I; i-- > 0; ) { + for( let j=J; j-- > 0; ) { + X[xOff + J*i+j] /= L[lOff + N*i+i]; + } + + for( let k=i; k-- > 0; ) { + for( let j=J; j-- > 0; ) { + X[xOff + J*k+j] -= L[lOff + N*i+k] * X[xOff + J*i+j]; + }} + } + } + else + { // FORWARD SUBSTITUTION (TRANSPOSED) + for( let i=0; i < I; i++ ) { + for( let k=0; k < i; k++ ) { + for( let j=0; j < J; j++ ) { + X[xOff + J*i+j] -= L[lOff + N*k+i] * X[xOff + J*k+j]; + }} + + for( let j=0; j < J; j++ ) { + X[xOff + J*i+j] /= L[lOff + N*i+i]; + } + } + } + } + } + + return Tensor.make(xShape,{values: X},dtype); +} + +/** + * Solves a triangular linear equation system (LES). + * + * @param l The triangular matrix of the LES. + * @param y The right-hand-side of the LES. + * @param lower If set to `true`, `l` is interpreted as lower triangular + * matrix. The strict upper triangular entries are ignore. + * If set to `false`, `l` is interpreted as upper triangular + * matrix and the strict lower triangular entries are ignored. + * @param adjoint If set to `true`, the hermitian transpose of `l` is used in + * the LES. + * + * @returns The solution of one of the following LES: + *
+ *
lower=false, adjoint=false
tril(l) ∙x == y + *
lower=true, adjoint=false
triu(l) ∙x == y + *
lower=false, adjoint=true
tril(l)ᴴ∙x == y + *
lower=true, adjoint=true
triu(l)ᴴ∙x == y + *
+ */ +/** + * @doc {heading:'Operations', + * subheading:'Linear Algebra', + * namespace:'linalg'} + */ +function triangularSolve_( + l: Tensor|TensorLike, y: Tensor|TensorLike, lower=true, adjoint=false +): Tensor +{ + // FIXME: if `l` is singular the right hand side could be + // checked for 0 and then some/any solution could be used + +// let [$l,$y] = broadcastMatrices( +// convertToTensor(l,'l','triangularSolve'), +// convertToTensor(y,'y','triangularSolve') +// ); + let $l = convertToTensor(l,'l','triangularSolve'), + $y = convertToTensor(y,'y','triangularSolve'); + l=undefined; + y=undefined; + if( $l.rank < 2 ){ + throw new Error(`triangularSolve(): l.rank must be at least 2.`); + } + if( $y.rank < 2 ){ + throw new Error(`triangularSolve(): y.rank must be at least 2.`); + } + + const dtype = upcastType($l.dtype, $y.dtype); + if( $l.dtype !== dtype ) { $l = $l.cast(dtype); } + if( $y.dtype !== dtype ) { $y = $y.cast(dtype); } + + // WHERE THE BACKPROP COMES FROM: + // x = L⁻¹∙y + // => dx = d(L⁻¹)∙y + L⁻¹∙dy = L⁻¹∙dy - L⁻¹∙dL∙L⁻¹∙y = L⁻¹∙dy - L⁻¹∙dL∙x + // => df = tr( (∂f/∂x)∙dxᵀ ) + // = tr( (∂f/∂x)∙dyᵀ∙L⁻ᵀ ) - tr( (∂f/∂x)∙yᵀ∙L⁻ᵀ∙dLᵀ∙L⁻ᵀ ) + // = tr( (∂f/∂x)ᵀ∙L⁻¹∙dy ) - tr( (∂f/∂x)∙yᵀ∙L⁻ᵀ∙(L⁻¹∙dL)ᵀ ) + // = tr( L⁻ᵀ∙(∂f/∂x) ∙dyᵀ) - tr( L⁻¹∙y∙(∂f/∂x)ᵀ∙ L⁻¹∙dL ) + // = tr( L⁻ᵀ∙(∂f/∂x) ∙dyᵀ) - tr( x∙(∂f/∂x)ᵀ∙ L⁻¹∙dL ) + // = tr( L⁻ᵀ∙(∂f/∂x) ∙dyᵀ) - tr( L⁻ᵀ ∙(∂f/∂x) ∙ xᵀ ∙dLᵀ ) + // => ∂f/∂y = L⁻ᵀ∙(∂f/∂x) + // ∂f/∂L = -L⁻ᵀ∙(∂f/∂x)∙xᵀ = ∂f/∂L = -(∂f/∂y)∙xᵀ + + // tslint:disable + // SEE: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_grad.py#L218 + // tslint:enable + return ENV.engine.runKernel( + (backend,saveFn) => { + const x = triangularSolveKernel($l,$y,lower,adjoint); + saveFn(x); + return x; + }, + {$l,$y}, + (dx,[x]) => { + const dy = triangularSolve($l, dx, lower, !adjoint); + return { + $l: () => { + let dl = adjoint ? matMul( x, dy, false, true) + : matMul(dy, x, false, true); + dl = dl.neg(); + dl = lower ? bandPart(dl,-1, 0) + : bandPart(dl, 0,-1); + return dl; + }, + $y: () => dy + }; + } + ); +} + +/** Computes the economic QR Decomposition. + */ +function qrEcoDecompKernel( a: Tensor ): [Tensor,Tensor] +{ + if( a.rank < 2 ) { + throw new Error(`qrEco(): input must have rank >= 2, got rank ${a.rank}.`); + } + if( a.dtype !== 'float32' ) { + throw new Error(`qrEco(): only float32 currently supported as dtype.`); + } + if( a.shape[a.rank-2] < a.shape[a.rank-1] ) { + throw new Error(`qrEco(): a must have at least as many rows as columns`); + } + + const dtype = 'float32', + // tslint:disable + DTypeArray = Float32Array, + // tslint:enable + qShape = Array.from( a.shape ), + rShape = Array.from( qShape ), + [M,N] = qShape.slice(-2); + rShape[rShape.length-2] = N; + Object.freeze(qShape); + Object.freeze(rShape); + + const Q = DTypeArray.from( a.dataSync() ); a = undefined; + const R = new DTypeArray( Q.length/M*N ), + cs = new DTypeArray( 2*M*N - N*(N+1) );// <- MEMOIZE ROTATIONS + + for( + let rOff=0, + qOff=0; qOff < Q.length; qOff += M*N, + rOff += N*N + ) + { + let csi = 0; + + for( let i=1; i < M; i++ ) { const J = Math.min(i,N); + for( let j=0; j < J; j++ ) + { // DETERMINE GIVENS ROTATION cos AND sin + const rIJ = Q[qOff + N*i+j]; if( 0.0 === rIJ ) {cs[csi++]=1.0; + cs[csi++]=0.0; continue;} + const rJJ = Q[qOff + N*j+j], + norm = Math.hypot(rJJ,rIJ), + c = rJJ / norm, + s = rIJ / norm; + cs[csi++] = c; + cs[csi++] = s; + Q[qOff + N*j+j] = norm; + Q[qOff + N*i+j] = 0; + // ROTATE ROWS IN R (WHICH IS CURRENTLY STORED IN Q) + for( let k=j; ++k < N; ) + { const rJK = Q[qOff + N*j+k], + rIK = Q[qOff + N*i+k]; + Q[qOff + N*j+k] = s*rIK + c*rJK; + Q[qOff + N*i+k] = c*rIK - s*rJK; + } + }} + + assert( csi === cs.length, `WTF: ${csi} !== ${cs.length}` ); + + // COPY R FROM Q -> R + for( let i=0; i < N; i++ ) { + for( let j=i; j < N; j++ ) { + R[rOff + N*i+j] = Q[qOff + N*i+j]; + Q[qOff + N*i+j] = i !== j ? 0.0 : 1.0; + }} + + // COMPUTE Q + for( let i=M; --i > 0; ) { const J = Math.min(i,N); + for( let j=J; j-- > 0; ) + { const s = cs[--csi], + c = cs[--csi]; + // ROTATE ROWS IN Q + for( let k=N; k-- > 0; ) + { const qJK = Q[qOff + N*j+k], + qIK = Q[qOff + N*i+k]; + Q[qOff + N*j+k] = c*qJK - s*qIK; + Q[qOff + N*i+k] = s*qJK + c*qIK; + } + }} + + assert( csi === 0, `WTF: ${csi} !== 0` ); + } + + const q = Tensor.make(qShape, { values: Q }, dtype); + const r = Tensor.make(rShape, { values: R }, dtype); + + return [q,r]; +} + +/** Computes the full QR Decomposition an memoizes the + * Givens rotation angles in the process. + */ +function qrFullDecompKernel( a: Tensor ): [Tensor,Tensor,Tensor] +{ + if( a.rank < 2 ) { + throw new Error(`qrEco(): input must have rank >= 2, got rank ${a.rank}.`); + } + if( a.dtype !== 'float32' ) { + throw new Error(`qrEco(): only float32 currently supported as dtype.`); + } + + const dtype = 'float32', + // tslint:disable + DTypeArray = Float32Array, + // tslint:enable + rShape = Array.from( a.shape ), + qShape = Array.from( a.shape ), + [M,N] = a.shape.slice(-2), + R = DTypeArray.from( a.dataSync() ); + a = undefined; + const L = Math.min(M,N), + Q = new DTypeArray( R.length/N*M ), + CS = new DTypeArray( R.length/N/M * 2 * ( + (L*(L-1) >>> 1) + Math.max(0,M-N)*N + )); + qShape[qShape.length-1] = M; + Object.freeze(qShape); + Object.freeze(rShape); + + let l = 0; + for( let qOff=0, + rOff=0; qOff < Q.length; qOff += M*M, + rOff += M*N ) + { + // INIT Q TO IDENTITY + for( let i=0; i < M; i++ ) { Q[qOff + M*i+i] = 1; } + + // BEGIN QR DECOMPOSITION + for( let i=1; i < M; i++ ) { const J = Math.min(i,N); + for( let j=0; j < J; j++ ) + { + // DETERMINE GIVENS ROTATION cos AND sin + const rIJ = R[rOff + N*i+j]; if( 0.0 === rIJ ) { CS[l++]=1.0; + CS[l++]=0.0; continue; } + const rJJ = R[rOff + N*j+j], + norm = Math.hypot(rJJ,rIJ), + c = rJJ / norm, + s = rIJ / norm; + CS[l++] = c; + CS[l++] = s; + R[rOff + N*j+j] = norm; + R[rOff + N*i+j] = 0; + // ROTATE ROWS IN R + for( let k=j; ++k < N; ) + { const rJK = R[rOff + N*j+k], + rIK = R[rOff + N*i+k]; + R[rOff + N*j+k] = s*rIK + c*rJK; + R[rOff + N*i+k] = c*rIK - s*rJK; + } + // ROTATE ROWS IN Qᵀ + for( let k=0; k <= i; k++ ) + { const qJK = Q[qOff + M*j+k], + qIK = Q[qOff + M*i+k]; + Q[qOff + M*j+k] = s*qIK + c*qJK; + Q[qOff + M*i+k] = c*qIK - s*qJK; + } + }} // END QR DECOMPOSITION + + // TRANSPOSE Q (was transposed for cache locality) + for( let i=0; i < M; i++ ) { + for( let j=0; j < i; j++ ) { + const qIJ = Q[qOff + M*i+j]; + Q[qOff + M*i+j] = Q[qOff + M*j+i]; + Q[qOff + M*j+i] = qIJ; + }} + } + assert( l === CS.length, `WTF: ${l} != ${CS.length}` ); + + const q = Tensor.make(qShape, {values: Q}, dtype); + const r = Tensor.make(rShape, {values: R}, dtype); + const cs = Tensor.make([CS.length], {values: CS}, dtype); + + return [q,r,cs]; +} + +/** Computes the backpropagation full QR Decomposition using + * memoized Givens rotation angles in the process. + */ +function qrFullBackpropKernel( + q: Tensor, dq: Tensor, r: Tensor, dr: Tensor, cs: Tensor +): Tensor +{ + if( q.rank !== dq.rank ) { + throw new Error( + `qrFullBackprop(): q.rank == ${q.rank} != ${dq.rank} == dq.rank` + ); + } + if( q.rank !== dr.rank ) { + throw new Error( + `qrFullBackprop(): q.rank == ${q.rank} != ${dr.rank} == dr.rank` + ); + } + if( q.rank !== r.rank ) { + throw new Error( + `qrFullBackprop(): q.rank == ${q.rank} != ${ r.rank} == r.rank` + ); + } + + if( cs.rank !== 1 ) { + throw new Error(`qrFullBackprop(): cs.rank == ${cs.rank} != 1`); + } + + const rank = q.rank; + + if( rank < 2 ) { + throw new Error( + `qrFullBackprop(): input must have rank >= 2, got rank ${rank}.` + ); + } + + for( let i=rank-2; i-- > 0; ) + { + if( q.shape[i] !== dq.shape[i] ) { + throw new Error( + 'qrFullBackprop(): ' + + `q.shape[${i}] == ${q.shape[i]} != ${dq.shape[i]} == dq.shape[${i}]` + ); + } + if( q.shape[i] !== dr.shape[i] ) { + throw new Error( + 'qrFullBackprop(): ' + + `q.shape[${i}] == ${q.shape[i]} != ${dr.shape[i]} == dr.shape[${i}]` + ); + } + if( q.shape[i] !== r.shape[i] ) { + throw new Error( + 'qrFullBackprop(): ' + + `q.shape[${i}] == ${q.shape[i]} != ${ r.shape[i]} == r.shape[${i}]` + ); + } + } + + if( q.shape[rank-2] !== q.shape[rank-1] ) { + throw new Error( + 'qrFullBackprop(): ' + + `q.shape[-2] == ${q.shape[rank-2]} != ${ q.shape[rank-1]} == q.shape[-1]` + ); + } + if( q.shape[rank-2] !== dq.shape[rank-1] ) { + throw new Error( + 'qrFullBackprop(): ' + + `q.shape[-2] == ${q.shape[rank-2]} != ${dq.shape[rank-1]} == dq.shape[-1]` + ); + } + if( q.shape[rank-2] !== dq.shape[rank-2] ) { + throw new Error( + 'qrFullBackprop(): ' + + `q.shape[-2] == ${q.shape[rank-2]} != ${dq.shape[rank-2]} == dq.shape[-2]` + ); + } + if( r.shape[rank-2] !== q.shape[rank-1] ) { + throw new Error( + 'qrFullBackprop(): ' + + `r.shape[-2] == ${r.shape[rank-2]} != ${ q.shape[rank-1]} == q.shape[-1]` + ); + } + if( r.shape[rank-1] !== dr.shape[rank-1] ) { + throw new Error( + 'qrFullBackprop(): ' + + `r.shape[-1] == ${r.shape[rank-1]} != ${dr.shape[rank-1]} == dr.shape[-1]` + ); + } + if( r.shape[rank-2] !== dr.shape[rank-2] ) { + throw new Error( + 'qrFullBackprop(): ' + + `r.shape[-2] == ${r.shape[rank-2]} != ${dr.shape[rank-2]} == dr.shape[-2]` + ); + } + + if( q.dtype !== dq.dtype ) { + throw new Error( + `qrFullBackprop(): q.dtype == ${q.dtype} != ${dq.dtype} == dq.dtype` + ); + } + if( q.dtype !== dr.dtype ) { + throw new Error( + `qrFullBackprop(): q.dtype == ${q.dtype} != ${dr.dtype} == dr.dtype` + ); + } + if( q.dtype !== r.dtype ) { + throw new Error( + `qrFullBackprop(): q.dtype == ${q.dtype} != ${r.dtype} == r.dtype` + ); + } + if( q.dtype !== cs.dtype ) { + throw new Error( + `qrFullBackprop(): q.dtype == ${q.dtype} != ${cs.dtype} == cs.dtype` + ); + } + + if( q.dtype !== 'float32' ) { + throw new Error( + `qrFullBackprop(): only float32 currently supported as dtype.` + ); + } + + const dtype ='float32', + // tslint:disable + DTypeArray = Float32Array, + // tslint:enable + dAShape = Array.from( r.shape ), + [M,N] = dAShape.slice(-2); + const Q = DTypeArray.from( q.dataSync() ); q = undefined; + const dQ = DTypeArray.from( dq.dataSync() ); dq = undefined; + const R = DTypeArray.from( r.dataSync() ); r = undefined; + const dR = DTypeArray.from( dr.dataSync() ); dr = undefined; + const CS = cs.dataSync(); + Object.freeze(dAShape); + + let l = CS.length; + for( let rOff=R.length, + qOff=Q.length; qOff > 0; ) + { + qOff -= M*M; + rOff -= M*N; + + // TRANSPOSE Q (for cache locality) + for( let i=0; i < M; i++ ) { + for( let j=0; j < i; j++ ) { + const qIJ = Q[qOff + M*i+j]; + Q[qOff + M*i+j] = Q[qOff + M*j+i]; + Q[qOff + M*j+i] = qIJ; + }} + + // TRANSPOSE dQ (for cache locality) + for( let i=0; i < M; i++ ) { + for( let j=0; j < i; j++ ) { + const dQij = dQ[qOff + M*i+j]; + dQ[qOff + M*i+j] = dQ[qOff + M*j+i]; + dQ[qOff + M*j+i] = dQij; + }} + + // BEGIN QR DECOMPOSITION + for( let i=M; --i > 0; ) { const J = Math.min(i,N); + for( let j=J; j-- > 0; ) + { + // DETERMINE GIVENS ROTATION cos AND sin + const s = CS[--l]; if( 0 === s ) { continue; } + const c = CS[--l], + norm = R[rOff + N*j+j]; + + // ROTATE ROWS IN R + for( let k=j; k < N; k++ ) + { const rJK = R[rOff + N*j+k], + rIK = R[rOff + N*i+k]; + R[rOff + N*j+k] = c*rJK - s*rIK; + R[rOff + N*i+k] = s*rJK + c*rIK; + } + + // ROTATE ROWS IN Qᵀ + for( let k=0; k <= i; k++ ) + { const qJK = Q[qOff + M*j+k], + qIK = Q[qOff + M*i+k]; + Q[qOff + M*j+k] = c*qJK - s*qIK; + Q[qOff + M*i+k] = s*qJK + c*qIK; + } + + const rIJ = R[rOff + N*i+j] / norm, + rJJ = R[rOff + N*j+j] / norm, + dCdJ = +rIJ*rIJ / norm, + dCdI = -rIJ*rJJ / norm, + dSdJ = -rJJ*rIJ / norm, + dSdI = +rJJ*rJJ / norm; + let dj = 0.0, + di = 0.0; + + // ROTATE ROWS IN dR + for( let k=j; k < N; k++ ) + { const dRjk = dR[rOff + N*j+k], + dRik = dR[rOff + N*i+k]; + dR[rOff + N*j+k] = c*dRjk - s*dRik; + dR[rOff + N*i+k] = s*dRjk + c*dRik; + + const rJK = R[rOff + N*j+k], + rIK = R[rOff + N*i+k]; + + dj += dRjk*(rIK*dSdJ + rJK*dCdJ) + dRik*(rIK*dCdJ - rJK*dSdJ); + di += dRjk*(rIK*dSdI + rJK*dCdI) + dRik*(rIK*dCdI - rJK*dSdI); + } + + // ROTATE ROWS IN dQᵀ + for( let k=0; k <= i; k++ ) + { const dQjk = dQ[qOff + M*j+k], + dQik = dQ[qOff + M*i+k]; + dQ[qOff + M*j+k] = c*dQjk - s*dQik; + dQ[qOff + M*i+k] = s*dQjk + c*dQik; + + const qJK = Q[qOff + M*j+k], + qIK = Q[qOff + M*i+k]; + + dj += dQjk*(qIK*dSdJ + qJK*dCdJ) + dQik*(qIK*dCdJ - qJK*dSdJ); + di += dQjk*(qIK*dSdI + qJK*dCdI) + dQik*(qIK*dCdI - qJK*dSdI); + } + + dR[rOff + N*j+j] += dj; + dR[rOff + N*i+j] += di; + }} // END QR DECOMPOSITION + } + assert( 0 === l, `WTF: ${l} != 0` ); + + return Tensor.make(dAShape,{values: dR},dtype); +} + /** - * Compute QR decomposition of m-by-n matrix using Householder transformation. + * Compute QR decomposition of m-by-n matrix using Givens rotations. * - * Implementation based on - * [http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf] - * (http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf) + * See: http://www.math.usm.edu/lambers/mat610/sum10/lecture9.pdf + * + * ```js + * const a = tf.tensor2d([[1, 2], [3, 4]]); + * let [q, r] = tf.linalg.qr(a); + * console.log('Q'); + * q.print(); + * console.log('R'); + * r.print(); + * console.log('Orthogonalized'); + * q.dot(q.transpose()).print() // should be nearly the identity matrix. + * console.log('Reconstructed'); + * q.dot(r).print(); // should be nearly [[1, 2], [3, 4]]; + * ``` * * ```js * const a = tf.tensor2d([[1, 2], [3, 4]]); @@ -150,115 +914,82 @@ function gramSchmidt_(xs: Tensor1D[]|Tensor2D): Tensor1D[]|Tensor2D { * subheading:'Linear Algebra', * namespace:'linalg'} */ -function qr_(x: Tensor, fullMatrices = false): [Tensor, Tensor] { - if (x.rank < 2) { +function qr_( a: Tensor, fullMatrices = false ): [Tensor, Tensor] { + if( a.rank < 2 ) { throw new Error( - `qr() requires input tensor to have a rank >= 2, but got rank ${ - x.rank}`); - } else if (x.rank === 2) { - return qr2d(x as Tensor2D, fullMatrices); - } else { - // Rank > 2. - // TODO(cais): Below we split the input into individual 2D tensors, - // perform QR decomposition on them and then stack the results back - // together. We should explore whether this can be parallelized. - const outerDimsProd = x.shape.slice(0, x.shape.length - 2) - .reduce((value, prev) => value * prev); - const x2ds = unstack( - x.reshape([ - outerDimsProd, x.shape[x.shape.length - 2], - x.shape[x.shape.length - 1] - ]), - 0); - const q2ds: Tensor2D[] = []; - const r2ds: Tensor2D[] = []; - x2ds.forEach(x2d => { - const [q2d, r2d] = qr2d(x2d as Tensor2D, fullMatrices); - q2ds.push(q2d); - r2ds.push(r2d); - }); - const q = stack(q2ds, 0).reshape(x.shape); - const r = stack(r2ds, 0).reshape(x.shape); - return [q, r]; + `qr() requires input tensor to have a rank >= 2, but got rank ${a.rank}` + ); + } + if( a.dtype.startsWith('complex') ) { + throw new Error(`qr() not yet supported for complex tensors.`); } -} -function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] { - return ENV.engine.tidy(() => { - if (x.shape.length !== 2) { - throw new Error( - `qr2d() requires a 2D Tensor, but got a ${x.shape.length}D Tensor.`); - } + const [m,n] = a.shape.slice(-2); - const m = x.shape[0]; - const n = x.shape[1]; - - let q = eye(m) as Tensor2D; // Orthogonal transform so far. - let r = x.clone(); // Transformed matrix so far. - - const one2D = tensor2d([[1]], [1, 1]); - let w: Tensor2D = one2D.clone(); - - const iters = m >= n ? n : m; - for (let j = 0; j < iters; ++j) { - // This tidy within the for-loop ensures we clean up temporary - // tensors as soon as they are no longer needed. - const rTemp = r; - const wTemp = w; - const qTemp = q; - [w, r, q] = ENV.engine.tidy((): [Tensor2D, Tensor2D, Tensor2D] => { - // Find H = I - tau * w * w', to put zeros below R(j, j). - const rjEnd1 = r.slice([j, j], [m - j, 1]); - const normX = rjEnd1.norm(); - const rjj = r.slice([j, j], [1, 1]); - const s = rjj.sign().neg() as Tensor2D; - const u1 = rjj.sub(s.mul(normX)) as Tensor2D; - const wPre = rjEnd1.div(u1); - if (wPre.shape[0] === 1) { - w = one2D.clone(); - } else { - w = one2D.concat( - wPre.slice([1, 0], [wPre.shape[0] - 1, wPre.shape[1]]) as - Tensor2D, - 0); - } - const tau = s.matMul(u1).div(normX).neg() as Tensor2D; - - // -- R := HR, Q := QH. - const rjEndAll = r.slice([j, 0], [m - j, n]); - const tauTimesW = tau.mul(w) as Tensor2D; - if (j === 0) { - r = rjEndAll.sub(tauTimesW.matMul(w.transpose().matMul(rjEndAll))); - } else { - r = r.slice([0, 0], [j, n]) - .concat( - rjEndAll.sub(tauTimesW.matMul( - w.transpose().matMul(rjEndAll))) as Tensor2D, - 0) as Tensor2D; - } - const qAllJEnd = q.slice([0, j], [m, q.shape[1] - j]); - if (j === 0) { - q = qAllJEnd.sub(qAllJEnd.matMul(w).matMul(tauTimesW.transpose())); - } else { - q = q.slice([0, 0], [m, j]) - .concat( - qAllJEnd.sub(qAllJEnd.matMul(w).matMul( - tauTimesW.transpose())) as Tensor2D, - 1) as Tensor2D; + if( m === n || m > n && !fullMatrices ) + { + // FIXME: What if R is (nearly) singular? + return ENV.engine.runKernel( + (backend,saveFunc) => { + const [q,r] = qrEcoDecompKernel(a); + saveFunc(q); + saveFunc(r); + return [q,r]; + }, + {a}, + ([dq,dr], [q,r]) => ({ + a: () => { + // TODO: is tidy required here? + // tslint:disable + // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_grad.py#L160 + // tslint:enable + const qdq = matMul(q,dq, true, false), + rdr = matMul(r,dr, false, true), + qdq_ = qdq.sub( adjoint(qdq) ), + rdr_ = rdr.sub( adjoint(rdr) ), + tril = bandPart( add(qdq_,rdr_), -1, 0 ); + + const triSolv = (x: Tensor,r: Tensor) => adjoint( + triangularSolve(r, adjoint(x), /*lower=*/false, /*adjoint_r*/false) + ); + + const gradA = matMul( q, dr.add( triSolv(tril,r) ) ), + gradB = triSolv( dq.sub( matMul(q,qdq) ), r ); + + return add(gradA,gradB); } - return [w, r, q]; - }); - dispose([rTemp, wTemp, qTemp]); - } + }) + ) as [Tensor, Tensor]; + } - if (!fullMatrices && m > n) { - q = q.slice([0, 0], [m, n]); - r = r.slice([0, 0], [n, n]); - } + let [q,r] = ENV.engine.runKernel( + (backend,saveFunc) => { + const [q,r,cs] = qrFullDecompKernel(a); + saveFunc(q); + saveFunc(r); + saveFunc(cs); + return [q,r]; + }, + {a}, + ([dq,dr], [q,r,cs]) => ({ + a: () => ENV.engine.runKernel( + (backend,saveFunc) => qrFullBackpropKernel(q,dq, r,dr, cs), + { $dq: dq, $dr: dr } + ) + }) + ); + + if( ! fullMatrices && m > n ) { + const end = a.shape.slice(); + q = q.slice([0, 0], end); end[end.length-2] = n; + r = r.slice([0, 0], end); + } - return [q, r]; - }) as [Tensor2D, Tensor2D]; + return [q,r]; } +export const adjoint = op({adjoint_}); +export const bandPart = op({bandPart_}); export const gramSchmidt = op({gramSchmidt_}); export const qr = op({qr_}); +export const triangularSolve = op({triangularSolve_}); diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts index bfbb5ef62b..6fdb39e310 100644 --- a/src/ops/linalg_ops_test.ts +++ b/src/ops/linalg_ops_test.ts @@ -16,12 +16,114 @@ */ import * as tf from '../index'; +import {ENV} from '../environment'; import {describeWithFlags} from '../jasmine_util'; -import {Tensor1D, Tensor2D} from '../tensor'; -import {ALL_ENVS, expectArraysClose, WEBGL_ENVS} from '../test_util'; +import {Scalar, Tensor, Tensor1D, Tensor2D} from '../tensor'; +import {ALL_ENVS, CPU_ENVS, expectArraysClose, expectArraysEqual, WEBGL_ENVS} from '../test_util'; import {scalar, tensor1d, tensor2d, tensor3d, tensor4d} from './ops'; +/** Returns a random integer in the range of [from,until). + */ +const randInt = (from: number, until: number) => { + return Math.floor(Math.random()*(until-from)) + from; +}; + +/** + * Computes the gradients using finite differences. Current + * implmentation uses an O(h⁴) central difference. + * + * SEE: https://en.wikipedia.org/wiki/Finite_difference + * + * FIXME this is terribly imprecise... wish there was + * double precision support *hint hint*. + */ +const numDiff = (f: (x: Tensor) => Scalar) => (a: Tensor) => { + if( a.dtype !== 'float32' ) { + throw new Error(`numDiff(): dtype=${a.dtype} not supported.`); + } + + const aData = Float32Array.from( a.dataSync() ); + + const eps = Math.sqrt( ENV.get('EPSILON') ); + + return ENV.engine.tidy(() => { + + const dA = new Float32Array( aData.length ); + + for( let i=0; i < aData.length; i++ ) + { // use central difference + const x = aData[i], + h = Math.max( Math.abs(x)*eps, eps ); + + const g = ( x: number ) => ENV.engine.tidy( () => { + aData[i] = x; + + const b = Tensor.make(a.shape, {values: aData}); + const scalar = f(b); + + if( scalar.rank !== 0 ) { + throw new Error('f() returned a non-scalar value.'); + } + + return scalar.dataSync()[0]; + }); + + // https://www.geometrictools.com/Documentation/FiniteDifferences.pdf + dA[i] = (-g(x+2*h) + 8*g(x+h) - 8*g(x-h) + g(x-2*h) ) / (12*h); + aData[i] = x; // <- undo modifications + } + + return Tensor.make(a.shape,{values: dA}); + }); +}; + +/** + * An tensor equivalency assertion that uses a comparison operator + * that is very similar to NumPy's `is_close()` function. + */ +function expectTensorsRelativelyClose( + actual: Tensor, expected: Tensor, rtol?: number, atol?: number +): void +{ + if( expected.shape.some( (s,i) => s !== actual.shape[i] ) ) { + throw new Error( + `Shapes [${actual.shape}] and [${expected.shape}] do not match.` + ); + } + + if( null == atol ) { atol = ENV.get('TEST_EPSILON'); } + if( null == rtol ) { rtol = ENV.get('TEST_EPSILON'); } + + const act = actual.dataSync(), + exp = expected.dataSync(); + + const isClose = (x: number, y: number) => { + x = Math.abs(x); + y = Math.abs(y); + return Math.abs(x-y) <= atol + rtol/2*(x+y); + }; + + for( let i=act.length; i-- > 0; ) { + if( ! isClose(act[i],exp[i]) ) + { + console.log( 'actual:'); actual.print(); + console.log('expected:'); expected.print(); + const idx = [], + shape = actual.shape; + for( let j=i, d=shape.length; d-- > 0; ) + { + const size = shape[d]; + idx.unshift(j % size); + j = Math.trunc(j / size); + } + throw new Error( + `actual[${idx}] = ${act[i]} != ${exp[i]} = expected[${idx}]` + ); + } + } +} + describeWithFlags('gramSchmidt-tiny', ALL_ENVS, () => { it('2x2, Array of Tensor1D', () => { const xs: Tensor1D[] = [ @@ -94,137 +196,375 @@ describeWithFlags('gramSchmidt-non-tiny', WEBGL_ENVS, () => { }); }); -describeWithFlags('qr', ALL_ENVS, () => { - it('1x1', () => { - const x = tensor2d([[10]], [1, 1]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose(q, tensor2d([[-1]], [1, 1])); - expectArraysClose(r, tensor2d([[-10]], [1, 1])); +describeWithFlags('adjoint', ALL_ENVS, () => { + it('2x3', () => { + const a = tf.tensor2d([[1,2,3], + [4,5,6]], [2,3]), + aT = tf.tensor2d([[1,4], + [2,5], + [3,6]],[3,2]); + // FIXME: shouldn't tf.transpose be lossless? + // Yet this fails on Travis with `expectArraysEqual`... + expectArraysClose( tf.linalg.adjoint(a), aT ); }); - - it('2x2', () => { - const x = tensor2d([[1, 3], [-2, -4]], [2, 2]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - q, tensor2d([[-0.4472, -0.8944], [0.8944, -0.4472]], [2, 2])); - expectArraysClose(r, tensor2d([[-2.2361, -4.9193], [0, -0.8944]], [2, 2])); + it('3x2x1', () => { + const a = tf.tensor3d([[[1],[2]], + [[3],[4]], + [[5],[6]]], [3,2,1]), + aT = tf.tensor3d([[[1,2]], + [[3,4]], + [[5,6]]], [3,1,2]); + expectArraysClose( tf.linalg.adjoint(a), aT ); }); +}); - it('2x2x2', () => { - const x = tensor3d([[[-1, -3], [2, 4]], [[1, 3], [-2, -4]]], [2, 2, 2]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - q, - tensor3d( - [ - [[-0.4472, -0.8944], [0.8944, -0.4472]], - [[-0.4472, -0.8944], [0.8944, -0.4472]] - ], - [2, 2, 2])); - expectArraysClose( - r, - tensor3d( - [ - [[2.2361, 4.9193], [0, 0.8944]], - [[-2.2361, -4.9193], [0, -0.8944]] - ], - [2, 2, 2])); - }); +describeWithFlags('bandPart', ALL_ENVS, () => { + const la = tf.linalg; - it('2x1x2x2', () => { - const x = - tensor4d([[[[-1, -3], [2, 4]]], [[[1, 3], [-2, -4]]]], [2, 1, 2, 2]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - q, - tensor4d( - [ - [[[-0.4472, -0.8944], [0.8944, -0.4472]]], - [[[-0.4472, -0.8944], [0.8944, -0.4472]]], - ], - [2, 1, 2, 2])); - expectArraysClose( - r, - tensor4d( - [ - [[[2.2361, 4.9193], [0, 0.8944]]], - [[[-2.2361, -4.9193], [0, -0.8944]]] - ], - [2, 1, 2, 2])); - }); + // FIXME: shouldn't 1*x be lossless? + // It's even in the IEEE spec somewhere... + // Yet this fails on Travis with `expectArraysEqual`... + const expectArraysEqual = expectArraysClose; - it('3x3', () => { - const x = tensor2d([[1, 3, 2], [-2, 0, 7], [8, -9, 4]], [3, 3]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - q, - tensor2d( - [ - [-0.1204, 0.8729, 0.4729], [0.2408, -0.4364, 0.8669], - [-0.9631, -0.2182, 0.1576] - ], - [3, 3])); - expectArraysClose( - r, - tensor2d( - [[-8.3066, 8.3066, -2.4077], [0, 4.5826, -2.1822], [0, 0, 7.6447]], - [3, 3])); - }); + it('3x4', () => { + const a = tf.tensor2d([ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9,10,11,12] + ]); + expectArraysEqual( + la.bandPart(a,0,0), + tf.tensor2d([[1, 0, 0, 0], + [0, 6, 0, 0], + [0, 0,11, 0]]) + ); + expectArraysEqual( + la.bandPart(a,0,1), + tf.tensor2d([[1, 2, 0, 0], + [0, 6, 7, 0], + [0, 0,11,12]]) + ); + expectArraysEqual( + la.bandPart(a,0,2), + tf.tensor2d([[1, 2, 3, 0], + [0, 6, 7, 8], + [0, 0,11,12]]) + ); + expectArraysEqual( + la.bandPart(a,0,2), + tf.tensor2d([[1, 2, 3, 0], + [0, 6, 7, 8], + [0, 0,11,12]]) + ); + for( const numUpper of [3,4,-1,-2] ) { + expectArraysEqual( + la.bandPart(a,0,numUpper), + tf.tensor2d([[1, 2, 3, 4], + [0, 6, 7, 8], + [0, 0,11,12]]) + ); + } - it('3x2, fullMatrices = default false', () => { - const x = tensor2d([[1, 2], [3, -3], [-2, 1]], [3, 2]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - q, - tensor2d( - [[-0.2673, 0.9221], [-0.8018, -0.3738], [0.5345, -0.0997]], - [3, 2])); - expectArraysClose(r, tensor2d([[-3.7417, 2.4054], [0, 2.8661]], [2, 2])); - }); + expectArraysEqual( + la.bandPart(a,1,0), + tf.tensor2d([[1, 0, 0, 0], + [5, 6, 0, 0], + [0,10,11, 0]]) + ); + expectArraysEqual( + la.bandPart(a,1,1), + tf.tensor2d([[1, 2, 0, 0], + [5, 6, 7, 0], + [0,10,11,12]]) + ); + expectArraysEqual( + la.bandPart(a,1,2), + tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [0,10,11,12]]) + ); + expectArraysEqual( + la.bandPart(a,1,2), + tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [0,10,11,12]]) + ); + for( const numUpper of [3,4,-1,-2] ) { + expectArraysEqual( + la.bandPart(a,1,numUpper), + tf.tensor2d([[1, 2, 3, 4], + [5, 6, 7, 8], + [0,10,11,12]]) + ); + } - it('3x2, fullMatrices = true', () => { - const x = tensor2d([[1, 2], [3, -3], [-2, 1]], [3, 2]); - const [q, r] = tf.linalg.qr(x, true); - expectArraysClose( - q, - tensor2d( - [ - [-0.2673, 0.9221, 0.2798], [-0.8018, -0.3738, 0.4663], - [0.5345, -0.0997, 0.8393] - ], - [3, 3])); - expectArraysClose( - r, tensor2d([[-3.7417, 2.4054], [0, 2.8661], [0, 0]], [3, 2])); + for( const numLower of [2,3,-1,-2]) + { + expectArraysEqual( + la.bandPart(a,numLower,0), + tf.tensor2d([[1, 0, 0, 0], + [5, 6, 0, 0], + [9,10,11, 0]]) + ); + expectArraysEqual( + la.bandPart(a,numLower,1), + tf.tensor2d([[1, 2, 0, 0], + [5, 6, 7, 0], + [9,10,11,12]]) + ); + expectArraysEqual( + la.bandPart(a,numLower,2), + tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [9,10,11,12]]) + ); + expectArraysEqual( + la.bandPart(a,numLower,2), + tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [9,10,11,12]]) + ); + for( const numUpper of [3,4,-1,-2] ) { + expectArraysEqual( + la.bandPart(a,numLower,numUpper), + tf.tensor2d([[1, 2, 3, 4], + [5, 6, 7, 8], + [9,10,11,12]]) + ); + } + } +// following test is only required for custom backend implementations +// +// for( const numUpper of [0,1,2,3,4,-1,-2] ) { +// for( const numLower of [0,1,2,3, -1,-2] ) { +// const w = tf.randomUniform(a.shape), +// f = (x: Tensor) => { +// return la.bandPart(x,numLower,numUpper).mul(w).mean() as Scalar; +// }, +// g = numDiff(f), +// h = tf.grad(f); +// expectArraysClose( g(a), h(a) ); +// }} }); +}); - it('2x3, fullMatrices = default false', () => { - const x = tensor2d([[1, 2, 3], [-3, -2, 1]], [2, 3]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - q, - tensor2d([[-0.3162278, -0.9486833], [0.9486833, -0.31622773]], [2, 2])); - expectArraysClose( - r, - tensor2d( - [[-3.162, -2.5298, -2.3842e-07], [0, -1.2649, -3.162]], [2, 3]), - ); - }); +describeWithFlags('triangularSolve', CPU_ENVS, () => { + const la = tf.linalg; - it('2x3, fullMatrices = true', () => { - const x = tensor2d([[1, 2, 3], [-3, -2, 1]], [2, 3]); - const [q, r] = tf.linalg.qr(x, true); - expectArraysClose( - q, - tensor2d([[-0.3162278, -0.9486833], [0.9486833, -0.31622773]], [2, 2])); - expectArraysClose( - r, - tensor2d( - [[-3.162, -2.5298, -2.3842e-07], [0, -1.2649, -3.162]], [2, 3]), + const testWith = (L: Tensor, y: Tensor) => { + const test = (adjoint: boolean) => + { + let tril = la.bandPart(L,-1, 0), + triu = la.bandPart(L, 0,-1); + if( adjoint ) { + tril = la.adjoint(tril); + triu = la.adjoint(triu); + } + for( const lower of [true,undefined] ) + { + const x = la.triangularSolve(L,y, lower, adjoint); + const [a,b] = [y,tril.matMul(x)]; + expectArraysClose(a,b); + } + const x = la.triangularSolve(L,y, /*lower=*/false, adjoint); + const [a,b] = [y,triu.matMul(x)]; +// const [a,b] = broadcastMatrices( y, triu.matMul(x) ); + expectArraysClose(a,b); + + for( const lower of [false,true,undefined] ) + { + const w = tf.randomUniform(y.shape,-1,+1), + f = (L: Tensor, y: Tensor) => { + return la.triangularSolve(L,y,lower).mul(w).mean() as Scalar; + }, + [g1,g2] = tf.grads(f)([L,y]), + h1 = numDiff( (L: Tensor) => f(L,y) )(L), + h2 = numDiff( (y: Tensor) => f(L,y) )(y); + expectArraysClose(g1,h1); + expectArraysClose(g2,h2); + } + }; + test(undefined); + test(false); + test(true); + }; + + it('3x3', () => testWith( + tf.tensor2d([[1,2,3], + [4,5,6], + [7,8,9]]), + tf.tensor2d([[10,11], + [12,13], + [14,15]]) + )); + + for( let run=0; run < 128; run++ ) + { + const lShape = Array.from({ length: randInt(2,5) }, () => randInt(1,7) ), + yShape = lShape.slice(); + lShape[lShape.length-1] = lShape[lShape.length-2]; + + // RUN TEST + it(`random#${run}_${lShape.join('x')}_${yShape.join('x')}`, () => { + const ONE = tf.scalar(1), + TWO = tf.scalar(2); + const y = tf.randomUniform(yShape,-1,+1); + let L: Tensor = tf.randomUniform(lShape,-1,+1); + // SET THE DIAGONAL TO BE FAR FROM ZERO + const i = tf.range(0,lShape[lShape.length-2]).reshape([-1,1]), + j = tf.range(0,lShape[lShape.length-1]), + diag = tf.equal(i,j).cast('float32'), + magn = tf.randomNormal (lShape, /*mean=*/1,/*stdDev=*/0.1), + sign = tf.randomUniform(lShape, 0,2, 'int32') + .cast('float32').mul(TWO).sub(ONE); + L = tf.add( + diag.sub(ONE).mul(L), // <- off-diagonal + diag.mul(sign).mul(magn) // <- diagonal + ); + L = tf.clone(L); + testWith(L,y); + }); + } +}); + +describeWithFlags('qr', CPU_ENVS, () => { + const testWith = (a: Tensor) => { + const [m,n] = a.shape.slice(-2), + l = Math.min(m,n), + // Indices of matrix transpose. + T = Array.from({ length: a.rank }, (_,i) => i ); + T[T.length-2] = T.length-1; + T[T.length-1] = T.length-2; + + for( const fullMatrices of [undefined,false,true] ) + { + const tril = (() => { + const [p,q] = fullMatrices ? [m,n] : [l,n], + i = tf.range(0,p).reshape([p,1]), + j = tf.range(0,q).reshape([1,q]); + return i.greater(j).cast('float32'); + })(); + const EYE = (() => { + const d = fullMatrices ? m : l; + return tf.stack( + Array.from( + { length: a.shape.slice(0,-2).reduce( (x,y) => x*y, 1 ) }, + () => tf.eye(d) + ) + ).reshape([...a.shape.slice(0,-2),d,d]); + })(); + const [q,r] = tf.linalg.qr(a,fullMatrices); + + // TEST SHAPE OF Q + expectArraysEqual( q.shape.slice(0,-1), a.shape.slice(0,-1) ); + expectArraysEqual( q.shape.slice( -1), fullMatrices ? [m ] : [l ] ); + + // TEST SHAPE OF R + expectArraysEqual( r.shape.slice(0,-2), a.shape.slice(0,-2) ); + expectArraysEqual( r.shape.slice( -2), fullMatrices ? [m,n] : [l,n] ); + + // TEST DECOMPOSITION (Q @ R == A) + try { + expectArraysClose( q.matMul(r), a ); + } catch(err) { + console.log('A'); a.print(); + console.log('Q'); q.print(); + console.log('R'); r.print(); + throw err; + } + + const qT = q.transpose(T); + + // TEST ORTHOGONALITY OF Q + if( fullMatrices || n >= m ) { + expectArraysClose( tf.matMul(q,qT), EYE ); + } + expectArraysClose( tf.matMul(qT,q), EYE ); + + // TEST TRIANGULARITY OF R + expectArraysEqual( tril.mul(r), tf.zeros(r.shape) ); + + // TEST GRADIENTS + const wQ = tf.randomUniform(q.shape,-1,+1), + wR = tf.randomUniform(r.shape,-1,+1), + f = (a: Tensor) => { + const [q,r] = tf.linalg.qr(a,fullMatrices); + return tf.add( + q.mul(wQ).mean(), + r.mul(wR).mean() + ) as Scalar; + }; + const g = numDiff(f); + const h = tf.grad(f); + try { + expectTensorsRelativelyClose(g(a), h(a), /*rtol=*/1e-2, /*atol=*/1e-2); + } + catch(err) { + console.log('fullMatrices:', fullMatrices); + console.log('A:'); a .print(); +// const [q,r] = tf.linalg.qr(a,fullMatrices); +// console.log('Q:'); q .print(); +// console.log('R:'); r .print(); +// console.log('G:'); g(a).print(); +// console.log('H:'); h(a).print(); + throw err; + } + } + }; + + it('1x1', () => testWith( tensor2d([[10]], [1, 1]) ) ); + + it('2x2', () => testWith( tensor2d([[ 1, 3], + [-2,-4]], [2, 2]) ) ); + + it('2x2x2', () => testWith( tensor3d([[[-1,-3], + [ 2, 4]], + [[ 1, 3], + [-2,-4]]], [2, 2, 2]) ) ); + + it('2x1x2x2', () => testWith( tensor4d([[[[-1,-3], + [ 2, 4]]], + [[[ 1, 3], + [-2,-4]]]], [2, 1, 2, 2]) ) ); + + it('3x3', () => testWith( tensor2d([[ 1, 3, 2], + [-2, 0, 7], + [ 8,-9, 4]], [3, 3]) ) ); + + it('3x2', () => testWith( tensor2d([[ 1, 2], + [ 3,-3], + [-2, 1]], [3, 2]) ) ); + + it('2x3', () => testWith( tensor2d([[ 1, 2, 3], + [-3,-2, 1]], [2, 3]) ) ); + + for( let run=0; run < 128; run++ ) + { + const shape = Array.from({ length: randInt(2,5) }, () => randInt(1,7) ); + it( + `random#${run}_${shape.join('x')}`, + () => testWith( tf.randomUniform(shape,-1,+1) ) ); + } + + it('Is reasonably fast', () => { + // TODO is there a better way to test this with a timeout? + const N = 128, + A = tf.randomUniform([N,N],-1,+1), + wQ = tf.randomUniform([N,N],-1,+1), + wR = tf.randomUniform([N,N],-1,+1), + f = (a: Tensor) => { + const [q,r] = tf.linalg.qr(a); + return q.mul(wQ).mean().add( r.mul(wR).mean() ); + }; + const g = tf.grad(f); + // following hopefully prevents g(A) from being JITes/Optimized away... + expectArraysClose( g(A), g(A) ); }); it('Does not leak memory', () => { - const x = tensor2d([[1, 3], [-2, -4]], [2, 2]); + const x = tensor2d([[ 1, 3], + [-2,-4]], [2, 2]); // The first call to qr creates and keeps internal singleton tensors. // Subsequent calls should always create exactly two tensors. tf.linalg.qr(x);