|
| 1 | +import concurrent.futures |
1 | 2 | import threading |
| 3 | +import string |
2 | 4 |
|
3 | 5 | import numpy as np |
4 | 6 | import pytest |
@@ -165,3 +167,90 @@ def closure(b): |
165 | 167 | x = np.repeat(x0, 2, axis=0)[::2] |
166 | 168 |
|
167 | 169 | run_threaded(closure, max_workers=10, pass_barrier=True) |
| 170 | + |
| 171 | + |
| 172 | +def test_structured_advanced_indexing(): |
| 173 | + # Test that copyswap(n) used by integer array indexing is threadsafe |
| 174 | + # for structured datatypes, see gh-15387. This test can behave randomly. |
| 175 | + |
| 176 | + # Create a deeply nested dtype to make a failure more likely: |
| 177 | + dt = np.dtype([("", "f8")]) |
| 178 | + dt = np.dtype([("", dt)] * 2) |
| 179 | + dt = np.dtype([("", dt)] * 2) |
| 180 | + # The array should be large enough to likely run into threading issues |
| 181 | + arr = np.random.uniform(size=(6000, 8)).view(dt)[:, 0] |
| 182 | + |
| 183 | + rng = np.random.default_rng() |
| 184 | + |
| 185 | + def func(arr): |
| 186 | + indx = rng.integers(0, len(arr), size=6000, dtype=np.intp) |
| 187 | + arr[indx] |
| 188 | + |
| 189 | + tpe = concurrent.futures.ThreadPoolExecutor(max_workers=8) |
| 190 | + futures = [tpe.submit(func, arr) for _ in range(10)] |
| 191 | + for f in futures: |
| 192 | + f.result() |
| 193 | + |
| 194 | + assert arr.dtype is dt |
| 195 | + |
| 196 | + |
| 197 | +def test_structured_threadsafety2(): |
| 198 | + # Nonzero (and some other functions) should be threadsafe for |
| 199 | + # structured datatypes, see gh-15387. This test can behave randomly. |
| 200 | + from concurrent.futures import ThreadPoolExecutor |
| 201 | + |
| 202 | + # Create a deeply nested dtype to make a failure more likely: |
| 203 | + dt = np.dtype([("", "f8")]) |
| 204 | + dt = np.dtype([("", dt)]) |
| 205 | + dt = np.dtype([("", dt)] * 2) |
| 206 | + # The array should be large enough to likely run into threading issues |
| 207 | + arr = np.random.uniform(size=(5000, 4)).view(dt)[:, 0] |
| 208 | + |
| 209 | + def func(arr): |
| 210 | + arr.nonzero() |
| 211 | + |
| 212 | + tpe = ThreadPoolExecutor(max_workers=8) |
| 213 | + futures = [tpe.submit(func, arr) for _ in range(10)] |
| 214 | + for f in futures: |
| 215 | + f.result() |
| 216 | + |
| 217 | + assert arr.dtype is dt |
| 218 | + |
| 219 | + |
| 220 | +def test_stringdtype_multithreaded_access_and_mutation( |
| 221 | + dtype, random_string_list): |
| 222 | + # this test uses an RNG and may crash or cause deadlocks if there is a |
| 223 | + # threading bug |
| 224 | + rng = np.random.default_rng(0x4D3D3D3) |
| 225 | + |
| 226 | + chars = list(string.ascii_letters + string.digits) |
| 227 | + chars = np.array(chars, dtype="U1") |
| 228 | + ret = rng.choice(chars, size=100 * 10, replace=True) |
| 229 | + random_string_list = ret.view("U100") |
| 230 | + |
| 231 | + def func(arr): |
| 232 | + rnd = rng.random() |
| 233 | + # either write to random locations in the array, compute a ufunc, or |
| 234 | + # re-initialize the array |
| 235 | + if rnd < 0.25: |
| 236 | + num = np.random.randint(0, arr.size) |
| 237 | + arr[num] = arr[num] + "hello" |
| 238 | + elif rnd < 0.5: |
| 239 | + if rnd < 0.375: |
| 240 | + np.add(arr, arr) |
| 241 | + else: |
| 242 | + np.add(arr, arr, out=arr) |
| 243 | + elif rnd < 0.75: |
| 244 | + if rnd < 0.875: |
| 245 | + np.multiply(arr, np.int64(2)) |
| 246 | + else: |
| 247 | + np.multiply(arr, np.int64(2), out=arr) |
| 248 | + else: |
| 249 | + arr[:] = random_string_list |
| 250 | + |
| 251 | + with concurrent.futures.ThreadPoolExecutor(max_workers=8) as tpe: |
| 252 | + arr = np.array(random_string_list, dtype=dtype) |
| 253 | + futures = [tpe.submit(func, arr) for _ in range(500)] |
| 254 | + |
| 255 | + for f in futures: |
| 256 | + f.result() |
0 commit comments