diff --git a/src/ops/linalg_ops.ts b/src/ops/linalg_ops.ts index 7a15b94e16..5336fd6b16 100644 --- a/src/ops/linalg_ops.ts +++ b/src/ops/linalg_ops.ts @@ -22,13 +22,17 @@ import {ENV} from '../environment'; import {dispose} from '../globals'; import {Tensor, Tensor1D, Tensor2D} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; import {assert} from '../util'; import {eye, squeeze, stack, unstack} from './array_ops'; +import {sub} from './binary_ops'; import {split} from './concat_split'; +import {logicalAnd, where} from './logical_ops'; import {norm} from './norm'; import {op} from './operation'; import {sum} from './reduction_ops'; -import {tensor2d} from './tensor_ops'; +import {range, scalar, tensor2d, zeros} from './tensor_ops'; /** * Gram-Schmidt orthogonalization. @@ -260,5 +264,91 @@ function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] { }) as [Tensor2D, Tensor2D]; } +/** + * Copies a tensor of matrices, setting everything outside a central band + * in each matrix to zero. + * + * ```js + * >>> const a = tf.tensor2d([[11, 12, 13, 14], + * ... [21, 22, 23, 24], + * ... [31, 32, 33, 34], + * ... [41, 42, 43, 44]]); + * >>> tf.linalg.bandPart(a,0,2).print(); + * [[11, 12, 13, 0], + * [ 0, 22, 23, 24], + * [ 0, 0, 33, 34], + * [ 0, 0, 0, 44]] + * + * >>> tf.linalg.bandPart(a,1,-1).print(); + * [[11, 12, 13, 14], + * [21, 22, 23, 24], + * [ 0, 32, 33, 34], + * [ 0, 0, 43, 44]] + * ``` + * + * @param a Tensor of matrices from which the band part is extracted. + * @param numLower The number of subdiagonal lines to be copied. + * If set to `-1`, all entries below the diagonal are + * copied. + * @param numUpper The number of superdiagonal lines to be copied. + * If set to `-1`, all entries above the diagonal are + * copied. + */ +/** + * @doc {heading:'Operations', + * subheading:'Linear Algebra', + * namespace:'linalg'} + */ +function bandPart_( + a: T|TensorLike, numLower: number, numUpper: number +): T +{ + if( numLower%1 !== 0 ){ + throw new Error(`bandPart(): numLower=${numLower} not an integer.`); + } + if( numUpper%1 !== 0 ){ + throw new Error(`bandPart(): numUpper=${numUpper} not an integer.`); + } + + return ENV.engine.tidy( () => { + const $a = convertToTensor(a,'a','bandPart'); + a = undefined; + + if( $a.rank < 2 ) { + throw new Error(`bandPart(): a.rank = ${$a.rank} < 2.`); + } + + const shape = $a.shape, + [M,N] = $a.shape.slice(-2); + + if( !(numLower <= M) ) { + throw new Error(`bandPart() check failed: numLower <= #rows.` ); + } + if( !(numUpper <= N) ) { + throw new Error(`bandPart() check failed: numUpper <= #columns.`); + } + + if( numLower < 0 ) { numLower = M; } + if( numUpper < 0 ) { numUpper = N; } + + const i = range(0,M, 1, 'int32').reshape([-1,1]), + j = range(0,N, 1, 'int32'); + + const inBand = logicalAnd( + sub(i,j).lessEqual( scalar(numLower,'int32') ), + sub(j,i).lessEqual( scalar(numUpper,'int32') ) + ); + + const zero = zeros([M,N], $a.dtype); + + return stack( + unstack( $a.reshape([-1,M,N]) ).map( + mat => where(inBand, mat, zero) + ) + ).reshape(shape) as T; + }); +} + export const gramSchmidt = op({gramSchmidt_}); +export const bandPart = op({bandPart_}); export const qr = op({qr_}); diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts index bfbb5ef62b..f6e731d7c0 100644 --- a/src/ops/linalg_ops_test.ts +++ b/src/ops/linalg_ops_test.ts @@ -18,7 +18,7 @@ import * as tf from '../index'; import {describeWithFlags} from '../jasmine_util'; import {Tensor1D, Tensor2D} from '../tensor'; -import {ALL_ENVS, expectArraysClose, WEBGL_ENVS} from '../test_util'; +import {ALL_ENVS, CPU_ENVS, expectArraysClose, expectArraysEqual, WEBGL_ENVS} from '../test_util'; import {scalar, tensor1d, tensor2d, tensor3d, tensor4d} from './ops'; @@ -241,3 +241,121 @@ describeWithFlags('qr', ALL_ENVS, () => { expect(() => tf.linalg.qr(x2)).toThrowError(/rank >= 2.*got rank 1/); }); }); + +for( const ENV of [CPU_ENVS, WEBGL_ENVS] ) +{ + const expectArrayEq = Object.is(ENV, CPU_ENVS) + ? expectArraysEqual + : expectArraysClose; + + describeWithFlags('bandPart', ENV, () => { + const la = tf.linalg; + + it('works for 3x4 example', () => { + const a = tf.tensor2d([[1, 2, 3, 4], + [5, 6, 7, 8], + [9,10,11,12]]); + expectArrayEq( + la.bandPart(a,0,0), + tf.tensor2d([[1, 0, 0, 0], + [0, 6, 0, 0], + [0, 0,11, 0]]) + ); + expectArrayEq( + la.bandPart(a,0,1), + tf.tensor2d([[1, 2, 0, 0], + [0, 6, 7, 0], + [0, 0,11,12]]) + ); + expectArrayEq( + la.bandPart(a,0,2), + tf.tensor2d([[1, 2, 3, 0], + [0, 6, 7, 8], + [0, 0,11,12]]) + ); + expectArrayEq( + la.bandPart(a,0,2), + tf.tensor2d([[1, 2, 3, 0], + [0, 6, 7, 8], + [0, 0,11,12]]) + ); + for( const numUpper of [3,4,-1,-2] ) { + expectArrayEq( + la.bandPart(a,0,numUpper), + tf.tensor2d([[1, 2, 3, 4], + [0, 6, 7, 8], + [0, 0,11,12]]) + ); + } + + expectArrayEq( + la.bandPart(a,1,0), + tf.tensor2d([[1, 0, 0, 0], + [5, 6, 0, 0], + [0,10,11, 0]]) + ); + expectArrayEq( + la.bandPart(a,1,1), + tf.tensor2d([[1, 2, 0, 0], + [5, 6, 7, 0], + [0,10,11,12]]) + ); + expectArrayEq( + la.bandPart(a,1,2), + tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [0,10,11,12]]) + ); + expectArrayEq( + la.bandPart(a,1,2), + tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [0,10,11,12]]) + ); + for( const numUpper of [3,4,-1,-2] ) { + expectArrayEq( + la.bandPart(a,1,numUpper), + tf.tensor2d([[1, 2, 3, 4], + [5, 6, 7, 8], + [0,10,11,12]]) + ); + } + + for( const numLower of [2,3,-1,-2]) + { + expectArrayEq( + la.bandPart(a,numLower,0), + tf.tensor2d([[1, 0, 0, 0], + [5, 6, 0, 0], + [9,10,11, 0]]) + ); + expectArrayEq( + la.bandPart(a,numLower,1), + tf.tensor2d([[1, 2, 0, 0], + [5, 6, 7, 0], + [9,10,11,12]]) + ); + expectArrayEq( + la.bandPart(a,numLower,2), + tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [9,10,11,12]]) + ); + expectArrayEq( + la.bandPart(a,numLower,2), + tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [9,10,11,12]]) + ); + for( const numUpper of [3,4,-1,-2] ) { + expectArrayEq( + la.bandPart(a,numLower,numUpper), + tf.tensor2d([[1, 2, 3, 4], + [5, 6, 7, 8], + [9,10,11,12]]) + ); + } + } + }); + }); +}