@@ -37,27 +37,36 @@ infiniStatus_t calculateArgWhere(
3737 const void *x) {
3838
3939 const Tdata *x_data = reinterpret_cast <const Tdata *>(x);
40- // int64_t *y_data = reinterpret_cast<int64_t *>(y);
41- std::vector<size_t > positions;
42- // #pragma omp parallel for
40+
41+ std::vector<int64_t > positions;
42+ const size_t ndim = info.shapes .size ();
43+
4344 for (size_t i = 0 ; i < info.num_elements ; i++) {
44- size_t pos = 0 , tem = i;
45- std::vector<size_t > position (info.strides .size ());
46- for (size_t j = info.strides .size () - 1 ; j >= 0 ; j--) {
47- position[j] = tem % info.shapes [j];
48- tem /= info.shapes [j];
49- pos += position[j] * info.strides [j];
45+ size_t pos = 0 ;
46+ size_t tmp = i;
47+
48+ std::vector<int64_t > coord (ndim);
49+
50+ // unravel index
51+ for (size_t j = ndim; j-- > 0 ;) {
52+ coord[j] = tmp % info.shapes [j];
53+ tmp /= info.shapes [j];
54+ pos += coord[j] * info.strides [j];
5055 }
51- if (fabs (x_data[pos] - 0 .0f ) > 1e-5 ) {
52- for (auto p : position) {
53- positions.push_back (p);
56+
57+ // PyTorch semantics: != 0
58+ if (x_data[pos] != Tdata (0 )) {
59+ for (size_t j = 0 ; j < ndim; j++) {
60+ positions.push_back (coord[j]);
5461 }
5562 }
5663 }
5764
65+ *count = positions.size () / ndim;
66+
5867 *y = new int64_t [positions.size ()];
5968 memcpy (*y, positions.data (), positions.size () * sizeof (int64_t ));
60- *count = positions. size () / info. strides . size ();
69+
6170 return INFINI_STATUS_SUCCESS;
6271}
6372
0 commit comments