Skip to content

Commit

Permalink
Replaced arraySync with slice + concat
Browse files Browse the repository at this point in the history
  • Loading branch information
Kriyszig committed Oct 7, 2019
1 parent 5e1b7b8 commit a3552ce
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tfjs-core/src/ops/linalg_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import {split} from './concat_split';
import {norm} from './norm';
import {op} from './operation';
import {sum} from './reduction_ops';
import {tensor, tensor2d} from './tensor_ops';
import {tensor, tensor1d, tensor2d} from './tensor_ops';

/**
* Copy a tensor setting everything outside a central band in each innermost
Expand Down Expand Up @@ -71,7 +71,8 @@ function bandPart_(x: Tensor, numLower: number, numUpper: number): Tensor {
if (totalElements === 0) {
return tensor([], x.shape);
}
const parted: number[] = x.flatten().arraySync();
const flattened: Tensor1D = x.flatten();
let band: Tensor1D = tensor1d([]);
const rows = (x.rank < 2) ? 1 : x.shape[x.rank - 2];
const cols = x.shape[x.rank - 1];

Expand All @@ -80,12 +81,14 @@ function bandPart_(x: Tensor, numLower: number, numUpper: number): Tensor {
for (let k = 0; k < cols; ++k) {
if ((numLower > -1 && k < j - numLower) ||
(numUpper > -1 && k > j + numUpper)) {
parted[i + j * rows + k] = 0;
band = band.concat(tensor([0]));
} else {
band = band.concat(flattened.slice(i + j * rows + k, 1));
}
}
}
}
return tensor(parted, x.shape);
return band.reshape(x.shape);
});
}

Expand Down

0 comments on commit a3552ce

Please sign in to comment.