diff --git a/tfjs-core/src/ops/linalg_ops.ts b/tfjs-core/src/ops/linalg_ops.ts
index 37a963846b..5aa4147680 100644
--- a/tfjs-core/src/ops/linalg_ops.ts
+++ b/tfjs-core/src/ops/linalg_ops.ts
@@ -28,7 +28,69 @@ import {split} from './concat_split';
 import {norm} from './norm';
 import {op} from './operation';
 import {sum} from './reduction_ops';
-import {tensor2d} from './tensor_ops';
+import {tensor, tensor1d, tensor2d} from './tensor_ops';
+
+/**
+ * Copy a tensor, setting everything outside a central band in each innermost
+ * matrix to zero.
+ *
+ * The band part is computed as follows: assume the input has `k` dimensions
+ * `[I, J, K, ..., M, N]`. Then the output is a tensor of the same shape where
+ * `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`,
+ * with the indicator function
+ * `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)`
+ * `&& (num_upper < 0 || (n-m) <= num_upper)`.
+ *
+ * ```js
+ * const x = tf.tensor2d([[ 0,  1,  2, 3],
+ *                        [-1,  0,  1, 2],
+ *                        [-2, -1,  0, 1],
+ *                        [-3, -2, -1, 0]]);
+ * let y = tf.linalg.bandPart(x, 1, -1);
+ * y.print();
+ * let z = tf.linalg.bandPart(x, 2, 1);
+ * z.print();
+ * ```
+ *
+ * @param x Rank `k` tensor.
+ * @param numLower Number of subdiagonals to keep.
+ *   If negative, keep the entire lower triangle.
+ * @param numUpper Number of superdiagonals to keep.
+ *   If negative, keep the entire upper triangle.
+ * @returns Rank `k` tensor of the same shape as the input:
+ *   the extracted banded tensor.
+ */
+/**
+ * @doc {heading:'Operations',
+ *       subheading:'Linear Algebra',
+ *       namespace:'linalg'}
+ */
+function bandPart_(x: Tensor, numLower: number, numUpper: number): Tensor {
+  return ENGINE.tidy(() => {
+    const totalElements = x.shape.reduce((a, b) => a * b);
+    if (totalElements === 0) {
+      return tensor([], x.shape);
+    }
+    const flattened: Tensor1D = x.flatten();
+    let band: Tensor1D = tensor1d([]);
+    const rows = (x.rank < 2) ? 1 : x.shape[x.rank - 2];
+    const cols = x.shape[x.rank - 1];
+
+    for (let i = 0; i < totalElements; i += (rows * cols)) {
+      for (let j = 0; j < rows; ++j) {
+        for (let k = 0; k < cols; ++k) {
+          if ((numLower > -1 && k < j - numLower) ||
+              (numUpper > -1 && k > j + numUpper)) {
+            band = band.concat(tensor1d([0]));
+          } else {
+            band = band.concat(flattened.slice(i + j * cols + k, 1));
+          }
+        }
+      }
+    }
+    return band.reshape(x.shape);
+  });
+}
 
 /**
  * Gram-Schmidt orthogonalization.
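For context only (not part of this patch): the `in_band(m, n)` predicate documented above can also be built once as a broadcast mask over the innermost `[M, N]` matrices, instead of slicing and concatenating one element at a time. The sketch below is illustrative, assumes a rank >= 2 input, and uses a hypothetical helper name `bandPartViaMask`; it relies only on standard tfjs-core ops (`range`, `sub`, `lessEqual`, `logicalAnd`, `cast`, `mul`).

```js
// Illustrative sketch only; not the implementation added in this PR.
// Builds the in_band(m, n) mask for the innermost [M, N] matrices and
// multiplies it into x; the mask broadcasts over any leading batch dims.
function bandPartViaMask(x, numLower, numUpper) {
  const M = x.shape[x.rank - 2];  // assumes x.rank >= 2
  const N = x.shape[x.rank - 1];
  const lower = numLower < 0 ? M : numLower;  // negative => keep whole triangle
  const upper = numUpper < 0 ? N : numUpper;
  const m = tf.range(0, M).reshape([M, 1]);   // row indices
  const n = tf.range(0, N).reshape([1, N]);   // column indices
  const inBand = m.sub(n).lessEqual(lower)                 // (m - n) <= num_lower
                     .logicalAnd(n.sub(m).lessEqual(upper));  // (n - m) <= num_upper
  return x.mul(inBand.cast(x.dtype));
}
```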
@@ -263,5 +325,6 @@ function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] {
   }) as [Tensor2D, Tensor2D];
 }
 
+export const bandPart = op({bandPart_});
 export const gramSchmidt = op({gramSchmidt_});
 export const qr = op({qr_});
diff --git a/tfjs-core/src/ops/linalg_ops_test.ts b/tfjs-core/src/ops/linalg_ops_test.ts
index 3988425317..c396bfe5fd 100644
--- a/tfjs-core/src/ops/linalg_ops_test.ts
+++ b/tfjs-core/src/ops/linalg_ops_test.ts
@@ -17,11 +17,87 @@
 
 import * as tf from '../index';
 import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
-import {Tensor1D, Tensor2D} from '../tensor';
+import {Tensor1D, Tensor2D, Tensor3D} from '../tensor';
 import {expectArraysClose} from '../test_util';
 
 import {scalar, tensor1d, tensor2d, tensor3d, tensor4d} from './ops';
 
+describeWithFlags('bandPart', ALL_ENVS, () => {
+  it('bandPart to keep tensor unchanged', async () => {
+    const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]);
+    expectArraysClose(
+        await tf.linalg.bandPart(x, -1, -1).array(),
+        [[1, 1, 1], [1, 1, 1], [1, 1, 1]]);
+  });
+
+  it('bandPart for upper triangular matrix', async () => {
+    const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]);
+    expectArraysClose(
+        await tf.linalg.bandPart(x, 0, -1).array(),
+        [[1, 1, 1], [0, 1, 1], [0, 0, 1]]);
+  });
+
+  it('bandPart for lower triangular matrix', async () => {
+    const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]);
+    expectArraysClose(
+        await tf.linalg.bandPart(x, -1, 0).array(),
+        [[1, 0, 0], [1, 1, 0], [1, 1, 1]]);
+  });
+
+  it('bandPart for diagonal elements', async () => {
+    const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]);
+    expectArraysClose(
+        await tf.linalg.bandPart(x, 0, 0).array(),
+        [[1, 0, 0], [0, 1, 0], [0, 0, 1]]);
+  });
+
+  it('bandPart for lower triangular elements', async () => {
+    const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]);
+    expectArraysClose(
+        await tf.linalg.bandPart(x, 1, 0).array(),
+        [[1, 0, 0], [1, 1, 0], [0, 1, 1]]);
+  });
+
+  it('bandPart for upper triangular elements', async () => {
+    const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]);
+    expectArraysClose(
+        await tf.linalg.bandPart(x, 0, 1).array(),
+        [[1, 1, 0], [0, 1, 1], [0, 0, 1]]);
+  });
+
+  it('bandPart for 4X4 matrix - tensorflow python examples', async () => {
+    const x: Tensor2D = tensor2d(
+        [[0, 1, 2, 3], [-1, 0, 1, 2], [-2, -1, 0, 1], [-3, -2, -1, 0]]);
+    expectArraysClose(
+        await tf.linalg.bandPart(x, 1, -1).array(),
+        [[0, 1, 2, 3], [-1, 0, 1, 2], [0, -1, 0, 1], [0, 0, -1, 0]]);
+    expectArraysClose(
+        await tf.linalg.bandPart(x, 2, 1).array(),
+        [[0, 1, 0, 0], [-1, 0, 1, 0], [-2, -1, 0, 1], [0, -2, -1, 0]]);
+  });
+
+  it('bandPart for 3 dimensional matrix', async () => {
+    const x: Tensor3D = tensor3d([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]);
+    expectArraysClose(
+        await tf.linalg.bandPart(x, 0, 0).array(),
+        [[[1, 0], [0, 1]], [[1, 0], [0, 1]]]);
+  });
+
+  it('bandPart for 2X3X3', async () => {
+    const x: Tensor3D = tensor3d(
+        [[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]);
+    expectArraysClose(
+        await tf.linalg.bandPart(x, 1, 2).array(),
+        [[[1, 1, 1], [1, 1, 1], [0, 1, 1]], [[1, 1, 1], [1, 1, 1], [0, 1, 1]]]);
+  });
+
+  it('bandPart for 1D tensor', async () => {
+    const x: Tensor1D = tensor1d([1, 1, 1, 1, 1]);
+    expectArraysClose(
+        await tf.linalg.bandPart(x, 1, 2).array(), [1, 1, 1, 0, 0]);
+  });
+});
+
 describeWithFlags('gramSchmidt-tiny', ALL_ENVS, () => {
   it('2x2, Array of Tensor1D', async () => {
     const xs: Tensor1D[] = [