From d1a2cfeb032a6ff9ffe651bcd395fdb9fe3751df Mon Sep 17 00:00:00 2001 From: Ved Patwardhan <54766411+vedpatwardhan@users.noreply.github.com> Date: Wed, 13 Mar 2024 12:05:24 +0000 Subject: [PATCH 1/6] fix: updated _parse_query to use ivy's set_item to use get_item as it's optimized now --- ivy/functional/ivy/general.py | 272 +++------------------------------- 1 file changed, 20 insertions(+), 252 deletions(-) diff --git a/ivy/functional/ivy/general.py b/ivy/functional/ivy/general.py index 97b49427c93ee..5d917baeae57f 100644 --- a/ivy/functional/ivy/general.py +++ b/ivy/functional/ivy/general.py @@ -3,7 +3,6 @@ # global import gc import inspect -import itertools import math from functools import wraps from numbers import Number @@ -2802,19 +2801,11 @@ def get_item( query = ivy.nonzero(query, as_tuple=False) ret = ivy.gather_nd(x, query) else: - query, target_shape, vector_inds = _parse_query( - query, ivy.shape(x, as_array=True) - ) - if vector_inds is not None: - x = ivy.permute_dims( - x, - axes=[ - *vector_inds, - *[i for i in range(len(x.shape)) if i not in vector_inds], - ], - ) - ret = ivy.gather_nd(x, query) - ret = ivy.reshape(ret, target_shape) if target_shape != list(ret.shape) else ret + indices, target_shape = _parse_query(query, ivy.shape(x, as_array=True)) + if indices is None: + return ivy.empty(target_shape, dtype=x.dtype) + ret = ivy.gather_nd(x, indices) + ret = ivy.reshape(ret, target_shape) return ret @@ -2889,9 +2880,7 @@ def set_item( query = ivy.tile(query, (x.shape[0],)) indices = ivy.nonzero(query, as_tuple=False) else: - indices, target_shape, _ = _parse_query( - query, ivy.shape(x, as_array=True), scatter=True - ) + indices, target_shape, _ = _parse_query(query, ivy.shape(x, as_array=True)) if indices is None: return x val = val.astype(x.dtype) @@ -2909,247 +2898,26 @@ def set_item( } -def _parse_query(query, x_shape, scatter=False): - query = (query,) if not isinstance(query, tuple) else query +def _parse_query(query, x_shape): + query = query if isinstance(query, tuple) else (query,) - # sequence and integer queries are dealt with as array queries - query = [ivy.array(q) if isinstance(q, (tuple, list, int)) else q for q in query] + # array containing all of x's flat indices + x_ = ivy.arange(0, _numel(x_shape)).reshape(x_shape) - # check if non-slice queries are in consecutive positions - # if so, they have to be moved to the front - # https://numpy.org/neps/nep-0021-advanced-indexing.html#mixed-indexing - non_slice_q_idxs = [i for i, q in enumerate(query) if ivy.is_array(q)] - to_front = ( - len(non_slice_q_idxs) > 1 - and any(ivy.diff(non_slice_q_idxs) != 1) - and non_slice_q_idxs[-1] < len(x_shape) - ) + # use numpy's __getitem__ to get the queried indices + x_idxs = x_[query] + target_shape = x_idxs.shape - # extract newaxis queries - new_axes = [i for i, q in enumerate(query) if q is None] - query = [q for q in query if q is not None] - query = [Ellipsis] if query == [] else query - - # parse ellipsis - ellipsis_inds = None - if any(q is Ellipsis for q in query): - query, ellipsis_inds = _parse_ellipsis(query, len(x_shape)) - - # broadcast array queries - array_inds = [i for i, v in enumerate(query) if ivy.is_array(v)] - if array_inds: - array_queries = ivy.broadcast_arrays( - *[v for i, v in enumerate(query) if i in array_inds] - ) - array_queries = [ - ivy.nonzero(q, as_tuple=False)[0] if ivy.is_bool_dtype(q) else q - for q in array_queries - ] - array_queries = [ - ( - ivy.where(arr < 0, arr + x_shape[i], arr).astype(ivy.int64) - if arr.size - else arr.astype(ivy.int64) - ) - for arr, i in zip(array_queries, array_inds) - ] - for idx, arr in zip(array_inds, array_queries): - query[idx] = arr - - # convert slices to range arrays - query = [ - _parse_slice(q, x_shape[i]).astype(ivy.int64) if isinstance(q, slice) else q - for i, q in enumerate(query) - ] + if 0 in x_idxs.shape or 0 in x_shape: + return None, target_shape - # fill in missing queries - if len(query) < len(x_shape): - query += [ivy.arange(0, s, 1).astype(ivy.int64) for s in x_shape[len(query) :]] + # convert the flat indices to multi-D indices + x_idxs = ivy.unravel_index(x_idxs, x_shape) - # calculate target_shape, i.e. the shape the gathered/scattered values should have - if len(array_inds) and to_front: - target_shape = ( - [list(array_queries[0].shape)] - + [list(query[i].shape) for i in range(len(query)) if i not in array_inds] - + [[] for _ in range(len(array_inds) - 1)] - ) - elif len(array_inds): - target_shape = ( - [list(query[i].shape) for i in range(0, array_inds[0])] - + [list(ivy.shape(array_queries[0], as_array=True))] - + [[] for _ in range(len(array_inds) - 1)] - + [list(query[i].shape) for i in range(array_inds[-1] + 1, len(query))] - ) - else: - target_shape = [list(q.shape) for q in query] - if ellipsis_inds is not None: - target_shape = ( - target_shape[: ellipsis_inds[0]] - + [target_shape[ellipsis_inds[0] : ellipsis_inds[1]]] - + target_shape[ellipsis_inds[1] :] - ) - for i, ax in enumerate(new_axes): - if len(array_inds) and to_front: - ax -= sum(1 for x in array_inds if x < ax) - 1 - ax = ax + i - target_shape = [*target_shape[:ax], 1, *target_shape[ax:]] - target_shape = _deep_flatten(target_shape) - - # calculate the indices mesh (indices in gather_nd/scatter_nd format) - query = [ivy.expand_dims(q) if not len(q.shape) else q for q in query] - if len(array_inds): - array_queries = [ - ( - arr.reshape((-1,)) - if len(arr.shape) > 1 - else ivy.expand_dims(arr) if not len(arr.shape) else arr - ) - for arr in array_queries - ] - array_queries = ivy.stack(array_queries, axis=1) - if len(array_inds) == len(query): # advanced indexing - indices = array_queries.reshape((*target_shape, len(x_shape))) - elif len(array_inds) == 0: # basic indexing - indices = ivy.stack(ivy.meshgrid(*query, indexing="ij"), axis=-1).reshape( - (*target_shape, len(x_shape)) - ) - else: # mixed indexing - if to_front: - post_array_queries = ( - ivy.stack( - ivy.meshgrid( - *[v for i, v in enumerate(query) if i not in array_inds], - indexing="ij", - ), - axis=-1, - ).reshape((-1, len(query) - len(array_inds))) - if len(array_inds) < len(query) - else ivy.empty((1, 0)) - ) - indices = ivy.array( - [ - (*arr, *post) - for arr, post in itertools.product( - array_queries, post_array_queries - ) - ] - ).reshape((*target_shape, len(x_shape))) - else: - pre_array_queries = ( - ivy.stack( - ivy.meshgrid( - *[v for i, v in enumerate(query) if i < array_inds[0]], - indexing="ij", - ), - axis=-1, - ).reshape((-1, array_inds[0])) - if array_inds[0] > 0 - else ivy.empty((1, 0)) - ) - post_array_queries = ( - ivy.stack( - ivy.meshgrid( - *[v for i, v in enumerate(query) if i > array_inds[-1]], - indexing="ij", - ), - axis=-1, - ).reshape((-1, len(query) - 1 - array_inds[-1])) - if array_inds[-1] < len(query) - 1 - else ivy.empty((1, 0)) - ) - indices = ivy.array( - [ - (*pre, *arr, *post) - for pre, arr, post in itertools.product( - pre_array_queries, array_queries, post_array_queries - ) - ] - ).reshape((*target_shape, len(x_shape))) - - return ( - indices.astype(ivy.int64), - target_shape, - array_inds if len(array_inds) and to_front else None, - ) - - -def _parse_ellipsis(so, ndims): - pre = list() - for s in so: - if s is Ellipsis: - break - pre.append(s) - post = list() - for s in reversed(so): - if s is Ellipsis: - break - post.append(s) - ret = list( - pre - + [slice(None, None, None) for _ in range(ndims - len(pre) - len(post))] - + list(reversed(post)) - ) - return ret, (len(pre), ndims - len(post)) - - -def _parse_slice(idx, s): - step = 1 if idx.step is None else idx.step - if step > 0: - start = 0 if idx.start is None else idx.start - if start >= s: - stop = start - else: - if start <= -s: - start = 0 - elif start < 0: - start = start + s - stop = s if idx.stop is None else idx.stop - if stop > s: - stop = s - elif start <= -s: - stop = 0 - elif stop < 0: - stop = stop + s - else: - start = s - 1 if idx.start is None else idx.start - if start < -s: - stop = start - else: - if start >= s: - start = s - 1 - elif start < 0: - start = start + s - if idx.stop is None: - stop = -1 - else: - stop = idx.stop - if stop > s: - stop = s - elif stop < -s: - stop = -1 - elif stop == -s: - stop = 0 - elif stop < 0: - stop = stop + s - q_i = ivy.arange(start, stop, step) - q_i = [q for q in q_i if 0 <= q < s] - q_i = ( - ivy.array(q_i) - if len(q_i) or start == stop or idx.stop is not None - else ivy.arange(0, s, 1) - ) - return q_i - - -def _deep_flatten(iterable): - def _flatten_gen(iterable): - for item in iterable: - if isinstance(item, list): - yield from _flatten_gen(item) - else: - yield item + # stack the multi-D indices to bring them to gather_nd/scatter_nd format + x_idxs = ivy.stack(x_idxs, axis=-1).astype(ivy.int64) - return list(_flatten_gen(iterable)) + return x_idxs, target_shape def _numel(shape): From 2f06cc8b77ba19c8f9d7366ddddfd9392d029eec Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Wed, 13 Mar 2024 15:22:35 +0000 Subject: [PATCH 2/6] remove to scalar call from _numel --- ivy/functional/ivy/general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/ivy/general.py b/ivy/functional/ivy/general.py index 5d917baeae57f..7ea4222748c30 100644 --- a/ivy/functional/ivy/general.py +++ b/ivy/functional/ivy/general.py @@ -2922,7 +2922,7 @@ def _parse_query(query, x_shape): def _numel(shape): shape = tuple(shape) - return ivy.prod(shape).to_scalar() if shape != () else 1 + return ivy.prod(shape) if shape != () else 1 def _broadcast_to(input, target_shape): From c492aa93b25186bc2df7dc9c29dc0f7ecc2751e3 Mon Sep 17 00:00:00 2001 From: Ved Patwardhan <54766411+vedpatwardhan@users.noreply.github.com> Date: Thu, 14 Mar 2024 16:42:30 +0530 Subject: [PATCH 3/6] Update general.py --- ivy/functional/ivy/general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/ivy/general.py b/ivy/functional/ivy/general.py index 7ea4222748c30..519a834181e8d 100644 --- a/ivy/functional/ivy/general.py +++ b/ivy/functional/ivy/general.py @@ -2801,7 +2801,7 @@ def get_item( query = ivy.nonzero(query, as_tuple=False) ret = ivy.gather_nd(x, query) else: - indices, target_shape = _parse_query(query, ivy.shape(x, as_array=True)) + indices, target_shape = _parse_query(query, ivy.shape(x)) if indices is None: return ivy.empty(target_shape, dtype=x.dtype) ret = ivy.gather_nd(x, indices) From 3e37d220313cf9f210ea66b09881aa2a3d42b9fd Mon Sep 17 00:00:00 2001 From: Ved Patwardhan <54766411+vedpatwardhan@users.noreply.github.com> Date: Thu, 14 Mar 2024 19:53:13 +0530 Subject: [PATCH 4/6] Update general.py --- ivy/functional/ivy/general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/ivy/general.py b/ivy/functional/ivy/general.py index 519a834181e8d..4a87582ccb95b 100644 --- a/ivy/functional/ivy/general.py +++ b/ivy/functional/ivy/general.py @@ -2880,7 +2880,7 @@ def set_item( query = ivy.tile(query, (x.shape[0],)) indices = ivy.nonzero(query, as_tuple=False) else: - indices, target_shape, _ = _parse_query(query, ivy.shape(x, as_array=True)) + indices, target_shape, _ = _parse_query(query, ivy.shape(x)) if indices is None: return x val = val.astype(x.dtype) From 1173cd7ce87050cefab056838603e9742397137d Mon Sep 17 00:00:00 2001 From: Ved Patwardhan <54766411+vedpatwardhan@users.noreply.github.com> Date: Thu, 14 Mar 2024 14:50:28 +0000 Subject: [PATCH 5/6] more fixes to set_item --- ivy/functional/ivy/general.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ivy/functional/ivy/general.py b/ivy/functional/ivy/general.py index 4a87582ccb95b..91d4e6c29dd7f 100644 --- a/ivy/functional/ivy/general.py +++ b/ivy/functional/ivy/general.py @@ -2801,7 +2801,7 @@ def get_item( query = ivy.nonzero(query, as_tuple=False) ret = ivy.gather_nd(x, query) else: - indices, target_shape = _parse_query(query, ivy.shape(x)) + indices, target_shape = _parse_query(query, ivy.shape(x, as_array=True)) if indices is None: return ivy.empty(target_shape, dtype=x.dtype) ret = ivy.gather_nd(x, indices) @@ -2880,7 +2880,7 @@ def set_item( query = ivy.tile(query, (x.shape[0],)) indices = ivy.nonzero(query, as_tuple=False) else: - indices, target_shape, _ = _parse_query(query, ivy.shape(x)) + indices, target_shape = _parse_query(query, ivy.shape(x, as_array=True)) if indices is None: return x val = val.astype(x.dtype) @@ -2908,7 +2908,7 @@ def _parse_query(query, x_shape): x_idxs = x_[query] target_shape = x_idxs.shape - if 0 in x_idxs.shape or 0 in x_shape: + if 0 in x_idxs.shape or int(ivy.prod(x_shape)) == 0: return None, target_shape # convert the flat indices to multi-D indices From dea0111b828cbdd457f6cb41701fa3971c7ef966 Mon Sep 17 00:00:00 2001 From: Ved Patwardhan <54766411+vedpatwardhan@users.noreply.github.com> Date: Fri, 15 Mar 2024 06:42:51 +0000 Subject: [PATCH 6/6] made a few more changes to set_item and ivy.shape, got ivy.shape to return a tensor by default if executing in graph mode to avoid breaking torch --- ivy/functional/backends/tensorflow/general.py | 2 +- ivy/functional/ivy/general.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ivy/functional/backends/tensorflow/general.py b/ivy/functional/backends/tensorflow/general.py index 0346b00b9641e..2a57998e4d753 100644 --- a/ivy/functional/backends/tensorflow/general.py +++ b/ivy/functional/backends/tensorflow/general.py @@ -406,7 +406,7 @@ def shape( *, as_array: bool = False, ) -> Union[tf.Tensor, ivy.Shape, ivy.Array]: - if as_array: + if as_array or not tf.executing_eagerly(): return ivy.array(tf.shape(x), dtype=ivy.default_int_dtype()) else: return ivy.Shape(x.shape) diff --git a/ivy/functional/ivy/general.py b/ivy/functional/ivy/general.py index 91d4e6c29dd7f..9d68cc059502c 100644 --- a/ivy/functional/ivy/general.py +++ b/ivy/functional/ivy/general.py @@ -2801,7 +2801,7 @@ def get_item( query = ivy.nonzero(query, as_tuple=False) ret = ivy.gather_nd(x, query) else: - indices, target_shape = _parse_query(query, ivy.shape(x, as_array=True)) + indices, target_shape = _parse_query(query, ivy.shape(x)) if indices is None: return ivy.empty(target_shape, dtype=x.dtype) ret = ivy.gather_nd(x, indices) @@ -2880,7 +2880,7 @@ def set_item( query = ivy.tile(query, (x.shape[0],)) indices = ivy.nonzero(query, as_tuple=False) else: - indices, target_shape = _parse_query(query, ivy.shape(x, as_array=True)) + indices, target_shape = _parse_query(query, ivy.shape(x)) if indices is None: return x val = val.astype(x.dtype) @@ -2922,7 +2922,7 @@ def _parse_query(query, x_shape): def _numel(shape): shape = tuple(shape) - return ivy.prod(shape) if shape != () else 1 + return int(ivy.prod(shape)) if shape != () else 1 def _broadcast_to(input, target_shape):