diff --git a/tfjs-core/src/ops/linalg_ops.ts b/tfjs-core/src/ops/linalg_ops.ts index 38e4dcb6104..5aa41476806 100644 --- a/tfjs-core/src/ops/linalg_ops.ts +++ b/tfjs-core/src/ops/linalg_ops.ts @@ -28,7 +28,7 @@ import {split} from './concat_split'; import {norm} from './norm'; import {op} from './operation'; import {sum} from './reduction_ops'; -import {tensor, tensor2d} from './tensor_ops'; +import {tensor, tensor1d, tensor2d} from './tensor_ops'; /** * Copy a tensor setting everything outside a central band in each innermost @@ -71,7 +71,8 @@ function bandPart_(x: Tensor, numLower: number, numUpper: number): Tensor { if (totalElements === 0) { return tensor([], x.shape); } - const parted: number[] = x.flatten().arraySync(); + const flattened: Tensor1D = x.flatten(); + let band: Tensor1D = tensor1d([]); const rows = (x.rank < 2) ? 1 : x.shape[x.rank - 2]; const cols = x.shape[x.rank - 1]; @@ -80,12 +81,14 @@ function bandPart_(x: Tensor, numLower: number, numUpper: number): Tensor { for (let k = 0; k < cols; ++k) { if ((numLower > -1 && k < j - numLower) || (numUpper > -1 && k > j + numUpper)) { - parted[i + j * rows + k] = 0; + band = band.concat(tensor([0])); + } else { + band = band.concat(flattened.slice(i + j * rows + k, 1)); } } } } - return tensor(parted, x.shape); + return band.reshape(x.shape); }); }