|
13 | 13 |
|
14 | 14 |
|
15 | 15 | GRID_SEARCH_ERROR = -3 |
| 16 | +LEFT_OUT_OF_BOUNDS = -2 |
| 17 | +RIGHT_OUT_OF_BOUNDS = -1 |
| 18 | + |
| 19 | + |
| 20 | +def _search_1d_array( |
| 21 | + arr: np.array, |
| 22 | + x: float, |
| 23 | +) -> tuple[int, int]: |
| 24 | + """ |
| 25 | + Searches for particle locations in a 1D array and returns barycentric coordinate along dimension. |
| 26 | +
|
| 27 | + Assumptions: |
| 28 | + - array is strictly monotonically increasing. |
| 29 | +
|
| 30 | + Parameters |
| 31 | + ---------- |
| 32 | + arr : np.array |
| 33 | + 1D array to search in. |
| 34 | + x : float |
| 35 | + Position in the 1D array to search for. |
| 36 | +
|
| 37 | + Returns |
| 38 | + ------- |
| 39 | + array of int |
| 40 | + Index of the element just before the position x in the array. Note that this index is -2 if the index is left out of bounds and -1 if the index is right out of bounds. |
| 41 | + array of float |
| 42 | + Barycentric coordinate. |
| 43 | + """ |
| 44 | + # TODO v4: We probably rework this to deal with 0D arrays before this point (as we already know field dimensionality) |
| 45 | + if len(arr) < 2: |
| 46 | + return np.zeros(shape=x.shape, dtype=np.int32), np.zeros_like(x) |
| 47 | + index = np.searchsorted(arr, x, side="right") - 1 |
| 48 | + # Use broadcasting to avoid repeated array access |
| 49 | + arr_index = arr[index] |
| 50 | + arr_next = arr[np.clip(index + 1, 1, len(arr) - 1)] # Ensure we don't go out of bounds |
| 51 | + bcoord = (x - arr_index) / (arr_next - arr_index) |
| 52 | + |
| 53 | + # TODO check how we can avoid searchsorted when grid spacing is uniform |
| 54 | + # dx = arr[1] - arr[0] |
| 55 | + # index = ((x - arr[0]) / dx).astype(int) |
| 56 | + # index = np.clip(index, 0, len(arr) - 2) |
| 57 | + # bcoord = (x - arr[index]) / dx |
| 58 | + |
| 59 | + index = np.where(x < arr[0], LEFT_OUT_OF_BOUNDS, index) |
| 60 | + index = np.where(x >= arr[-1], RIGHT_OUT_OF_BOUNDS, index) |
| 61 | + |
| 62 | + return np.atleast_1d(index), np.atleast_1d(bcoord) |
16 | 63 |
|
17 | 64 |
|
18 | 65 | def _search_time_index(field: Field, time: datetime): |
|
0 commit comments