Skip to content

Commit

Permalink
fix: use grid-stride looping for kernels with variable-length loops (#3130)
Browse files Browse the repository at this point in the history

* fix: use grid-stride looping

* feat: add awkward_ListArray_getitem_next_range_carrylength kernel

* feat: add awkward_ListArray_getitem_next_range kernel

* test: add integration tests

* ignore 'Jitify is performing a one-time only warm-up' messages

---------

Co-authored-by: Jim Pivarski <[email protected]>
  • Loading branch information
ManasviGoyal and jpivarski authored May 29, 2024
1 parent 53dc0e2 commit 0b9f6f4
Show file tree
Hide file tree
Showing 28 changed files with 1,137 additions and 190 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,26 @@ ERROR awkward_ListArray_getitem_next_range(
int64_t step) {
int64_t k = 0;
tooffsets[0] = 0;
if (step > 0) {
for (int64_t i = 0; i < lenstarts; i++) {
int64_t length = fromstops[i] - fromstarts[i];
int64_t regular_start = start;
int64_t regular_stop = stop;
awkward_regularize_rangeslice(&regular_start, &regular_stop, step > 0,
start != kSliceNone, stop != kSliceNone,
length);
for (int64_t i = 0; i < lenstarts; i++) {
int64_t length = fromstops[i] - fromstarts[i];
int64_t regular_start = start;
int64_t regular_stop = stop;
awkward_regularize_rangeslice(&regular_start, &regular_stop, step > 0,
start != kSliceNone, stop != kSliceNone,
length);
if (step > 0) {
for (int64_t j = regular_start; j < regular_stop; j += step) {
tocarry[k] = fromstarts[i] + j;
k++;
}
tooffsets[i + 1] = (C)k;
}
}
else {
for (int64_t i = 0; i < lenstarts; i++) {
int64_t length = fromstops[i] - fromstarts[i];
int64_t regular_start = start;
int64_t regular_stop = stop;
awkward_regularize_rangeslice(&regular_start, &regular_stop, step > 0,
start != kSliceNone, stop != kSliceNone,
length);
else {
for (int64_t j = regular_start; j > regular_stop; j += step) {
tocarry[k] = fromstarts[i] + j;
k++;
}
tooffsets[i + 1] = (C)k;
}
tooffsets[i + 1] = (C)k;
}
return success();
}
Expand Down
2 changes: 2 additions & 0 deletions dev/generate-kernel-signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
"awkward_ListArray_getitem_next_array_advanced",
"awkward_ListArray_getitem_next_array",
"awkward_ListArray_getitem_next_at",
"awkward_ListArray_getitem_next_range",
"awkward_ListArray_getitem_next_range_carrylength",
"awkward_ListArray_getitem_next_range_counts",
"awkward_ListArray_rpad_and_clip_length_axis1",
"awkward_ListArray_rpad_axis1",
Expand Down
2 changes: 2 additions & 0 deletions dev/generate-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,8 @@ def gencpuunittests(specdict):
"awkward_ListArray_getitem_next_array_advanced",
"awkward_ListArray_getitem_next_array",
"awkward_ListArray_getitem_next_at",
"awkward_ListArray_getitem_next_range",
"awkward_ListArray_getitem_next_range_carrylength",
"awkward_ListArray_getitem_next_range_counts",
"awkward_ListArray_rpad_and_clip_length_axis1",
"awkward_ListArray_rpad_axis1",
Expand Down
46 changes: 16 additions & 30 deletions kernel-specification.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1762,44 +1762,30 @@ kernels:
):
k = 0
tooffsets[0] = 0
if step > 0:
for i in range(lenstarts):
length = fromstops[i] - fromstarts[i]
regular_start = start
regular_stop = stop
regular_start, regular_stop = awkward_regularize_rangeslice(
regular_start,
regular_stop,
step > 0,
start != kSliceNone,
stop != kSliceNone,
length,
)
j = regular_start
for i in range(lenstarts):
length = fromstops[i] - fromstarts[i]
regular_start = start
regular_stop = stop
regular_start, regular_stop = awkward_regularize_rangeslice(
regular_start,
regular_stop,
step > 0,
start != kSliceNone,
stop != kSliceNone,
length,
)
j = regular_start
if step > 0:
while j < regular_stop:
tocarry[k] = fromstarts[i] + j
k = k + 1
j += step
tooffsets[i + 1] = k
else:
for i in range(lenstarts):
length = fromstops[i] - fromstarts[i]
regular_start = start
regular_stop = stop
regular_start, regular_stop = awkward_regularize_rangeslice(
regular_start,
regular_stop,
step > 0,
start != kSliceNone,
stop != kSliceNone,
length,
)
j = regular_start
else:
while j > regular_stop:
tocarry[k] = fromstarts[i] + j
k = k + 1
j += step
tooffsets[i + 1] = k
tooffsets[i + 1] = k
automatic-tests: false

- name: awkward_ListArray_getitem_next_range_carrylength
Expand Down
195 changes: 195 additions & 0 deletions kernel-test-data.json
Original file line number Diff line number Diff line change
Expand Up @@ -13806,6 +13806,21 @@
"carrylength": [7]
}
},
{
"error": false,
"message": "",
"inputs": {
"fromstarts": [0, 2, 2, 3, 5],
"fromstops": [2, 2, 3, 5, 7],
"lenstarts": 5,
"start": 7,
"stop": 0,
"step": -1
},
"outputs": {
"carrylength": [3]
}
},
{
"error": false,
"message": "",
Expand All @@ -13820,6 +13835,186 @@
"outputs": {
"carrylength": [0]
}
},
{
"error": false,
"message": "",
"inputs": {
"fromstarts": [0, 2, 2, 3, 5],
"fromstops": [2, 2, 3, 5, 7],
"lenstarts": 5,
"start": 0,
"stop": 6,
"step": 2
},
"outputs": {
"carrylength": [4]
}
},
{
"error": false,
"message": "",
"inputs": {
"fromstarts": [0, 2, 2, 3, 5],
"fromstops": [2, 2, 3, 5, 7],
"lenstarts": 5,
"start": 7,
"stop": 0,
"step": -2
},
"outputs": {
"carrylength": [3]
}
}
]
},
{
"name": "awkward_ListArray_getitem_next_range",
"status": true,
"tests": [
{
"error": false,
"message": "",
"inputs": {
"fromstarts": [],
"fromstops": [],
"lenstarts": 0,
"start": 0,
"stop": 0,
"step": 0
},
"outputs": {
"tooffsets": [0],
"tocarry": []
}
},
{
"error": false,
"message": "",
"inputs": {
"fromstarts": [0, 2, 2, 3, 5],
"fromstops": [2, 2, 3, 5, 7],
"lenstarts": 5,
"start": 0,
"stop": 3,
"step": 1
},
"outputs": {
"tooffsets": [0, 2, 2, 3, 5, 7],
"tocarry": [0, 1, 2, 3, 4, 5, 6]
}
},
{
"error": false,
"message": "",
"inputs": {
"fromstarts": [0, 2, 2, 3, 5],
"fromstops": [2, 2, 3, 5, 7],
"lenstarts": 5,
"start": 7,
"stop": 0,
"step": -1
},
"outputs": {
"tooffsets": [0, 1, 1, 1, 2, 3],
"tocarry": [1, 4, 6]
}
},
{
"error": false,
"message": "",
"inputs": {
"fromstarts": [0, 2, 2, 3, 5],
"fromstops": [2, 2, 3, 5, 7],
"lenstarts": 5,
"start": 0,
"stop": 2,
"step": 0
},
"outputs": {
"tooffsets": [0, 0, 0, 0, 0, 0],
"tocarry": []
}
},
{
"error": false,
"message": "",
"inputs": {
"fromstarts": [0, 3, 3, 5, 7, 9],
"fromstops": [3, 3, 5, 7, 9, 11],
"lenstarts": 6,
"start": 0,
"stop": 6,
"step": 2
},
"outputs": {
"tooffsets": [0, 2, 2, 3, 4, 5, 6],
"tocarry": [0, 2, 3, 5, 7, 9]
}
},
{
"error": false,
"message": "",
"inputs": {
"fromstarts": [0, 3, 3, 5, 7, 9],
"fromstops": [3, 3, 5, 7, 9, 11],
"lenstarts": 6,
"start": 2,
"stop": 6,
"step": 2
},
"outputs": {
"tooffsets": [0, 1, 1, 1, 1, 1, 1],
"tocarry": [2]
}
},
{
"error": false,
"message": "",
"inputs": {
"fromstarts": [0, 3, 3, 5, 7, 9],
"fromstops": [3, 3, 5, 7, 9, 11],
"lenstarts": 6,
"start": 6,
"stop": 2,
"step": -2
},
"outputs": {
"tooffsets": [0, 0, 0, 0, 0, 0, 0],
"tocarry": []
}
},
{
"error": false,
"message": "",
"inputs": {
"fromstarts": [0, 3, 3, 5, 7, 9],
"fromstops": [3, 3, 5, 7, 9, 11],
"lenstarts": 6,
"start": 0,
"stop": 6,
"step": -2
},
"outputs": {
"tooffsets": [0, 0, 0, 0, 0, 0, 0],
"tocarry": []
}
},
{
"error": false,
"message": "",
"inputs": {
"fromstarts": [0, 2, 2, 3, 5],
"fromstops": [2, 2, 3, 5, 7],
"lenstarts": 5,
"start": 7,
"stop": 0,
"step": -2
},
"outputs": {
"tooffsets": [0, 1, 1, 1, 2, 3],
"tocarry": [1, 4, 6]
}
}
]
},
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ filterwarnings = [
"ignore:The NumPy module was reloaded:UserWarning",
"ignore:.*np\\.MachAr.*:DeprecationWarning",
"ignore:module 'sre_.*' is deprecated:DeprecationWarning",
"ignore:Jitify is performing a one-time only warm-up",
]
log_cli_level = "info"

Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_connect/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def fetch_template_specializations(kernel_dict):
"awkward_ListArray_getitem_jagged_carrylen",
"awkward_ListArray_getitem_jagged_descend",
"awkward_ListArray_getitem_jagged_numvalid",
"awkward_ListArray_getitem_next_range",
"awkward_ListArray_getitem_next_range_carrylength",
"awkward_ListArray_min_range",
"awkward_ListArray_rpad_and_clip_length_axis1",
"awkward_ListArray_rpad_axis1",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ awkward_IndexedArray_getitem_nextcarry_a(
C j = fromindex[thread_id];
if (j < 0 || j >= lencontent) {
RAISE_ERROR(INDEXEDARRAY_GETITEM_NEXTCARRY_ERRORS::IND_OUT_OF_RANGE)
} else if (j >= 0) {
} else {
scan_in_array[thread_id] = 1;
}
}
Expand All @@ -55,7 +55,7 @@ awkward_IndexedArray_getitem_nextcarry_b(
C j = fromindex[thread_id];
if (j < 0 || j >= lencontent) {
RAISE_ERROR(INDEXEDARRAY_GETITEM_NEXTCARRY_ERRORS::IND_OUT_OF_RANGE)
} else if (j >= 0) {
} else {
tocarry[scan_in_array[thread_id] - 1] = j;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ awkward_IndexedArray_ranges_carry_next_64_a(

if (thread_id < length) {
stride = fromstops[thread_id] - fromstarts[thread_id];
for (int64_t j = 0; j < stride; j++) {
for (int64_t j = threadIdx.y; j < stride; j += blockDim.y) {
if (!(index[fromstarts[thread_id] + j] < 0)) {
scan_in_array[fromstarts[thread_id] + j] = 1;
}
Expand All @@ -54,7 +54,7 @@ awkward_IndexedArray_ranges_carry_next_64_b(

if (thread_id < length) {
stride = fromstops[thread_id] - fromstarts[thread_id];
for (int64_t j = 0; j < stride; j++) {
for (int64_t j = threadIdx.y; j < stride; j += blockDim.y) {
if (!(index[fromstarts[thread_id] + j] < 0)) {
tocarry[scan_in_array[fromstarts[thread_id] + j] - 1] = index[fromstarts[thread_id] + j];
}
Expand Down
Loading

0 comments on commit 0b9f6f4

Please sign in to comment.