Skip to content

Commit ab88f0f

Browse files
committed
Merge branch 'allow_dynamic_shape_in_wp_array_creation_from_ptr' into 'main'
Allow non-constant shapes when creating an array in a kernel using wp.array(ptr=... See merge request omniverse/warp!1415
2 parents 6075242 + 717c3d2 commit ab88f0f

File tree

3 files changed

+27
-2
lines changed

3 files changed

+27
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
libraries ([GH-792](https://github.com/NVIDIA/warp/issues/792)).
1717
- Update `wp.MarchingCubes` to a pure-Warp implementation, allowing cross-platform support and differentiability.
1818
([GH-788](https://github.com/NVIDIA/warp/issues/788)).
19+
- Constructing wp.array objects from a pointer inside Warp kernels (e.g., wp.array(ptr=..., shape=...)) no longer requires the shape to be a compile-time constant, allowing for greater flexibility.
1920

2021
### Fixed
2122

warp/builtins.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5544,7 +5544,7 @@ def array_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any
55445544
return array(dtype=Scalar)
55455545

55465546
dtype = arg_values["dtype"]
5547-
shape = extract_tuple(arg_values["shape"], as_constant=True)
5547+
shape = extract_tuple(arg_values["shape"], as_constant=False)
55485548
return array(dtype=dtype, ndim=len(shape))
55495549

55505550

@@ -5554,7 +5554,7 @@ def array_dispatch_func(input_types: Mapping[str, type], return_type: Any, args:
55545554
# to the underlying C++ function's runtime and template params.
55555555

55565556
dtype = return_type.dtype
5557-
shape = extract_tuple(args["shape"], as_constant=True)
5557+
shape = extract_tuple(args["shape"], as_constant=False)
55585558

55595559
func_args = (args["ptr"], *shape)
55605560
template_args = (dtype,)

warp/tests/test_array.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2917,6 +2917,27 @@ def test_kernel_array_from_ptr(test, device):
29172917
assert_np_equal(arr.numpy(), np.array(((1.0, 2.0, 3.0), (0.0, 0.0, 0.0))))
29182918

29192919

2920+
@wp.kernel
2921+
def kernel_array_from_ptr_variable_shape(
2922+
ptr: wp.uint64,
2923+
shape_x: int,
2924+
shape_y: int,
2925+
):
2926+
arr = wp.array(ptr=ptr, shape=(shape_x, shape_y), dtype=wp.float32)
2927+
arr[0, 0] = 1.0
2928+
arr[0, 1] = 2.0
2929+
if shape_y > 2:
2930+
arr[0, 2] = 3.0
2931+
2932+
2933+
def test_kernel_array_from_ptr_variable_shape(test, device):
2934+
arr = wp.zeros(shape=(2, 3), dtype=wp.float32, device=device)
2935+
wp.launch(kernel_array_from_ptr_variable_shape, dim=(1,), inputs=(arr.ptr, 2, 2), device=device)
2936+
assert_np_equal(arr.numpy(), np.array(((1.0, 2.0, 0.0), (0.0, 0.0, 0.0))))
2937+
wp.launch(kernel_array_from_ptr_variable_shape, dim=(1,), inputs=(arr.ptr, 2, 3), device=device)
2938+
assert_np_equal(arr.numpy(), np.array(((1.0, 2.0, 3.0), (0.0, 0.0, 0.0))))
2939+
2940+
29202941
def test_array_from_int32_domain(test, device):
29212942
wp.zeros(np.array([1504, 1080, 520], dtype=np.int32), dtype=wp.float32, device=device)
29222943

@@ -3185,6 +3206,9 @@ def test_array_new_del(self):
31853206
add_function_test(TestArray, "test_array_inplace_non_diff_ops", test_array_inplace_non_diff_ops, devices=devices)
31863207
add_function_test(TestArray, "test_direct_from_numpy", test_direct_from_numpy, devices=["cpu"])
31873208
add_function_test(TestArray, "test_kernel_array_from_ptr", test_kernel_array_from_ptr, devices=devices)
3209+
add_function_test(
3210+
TestArray, "test_kernel_array_from_ptr_variable_shape", test_kernel_array_from_ptr_variable_shape, devices=devices
3211+
)
31883212

31893213
add_function_test(TestArray, "test_array_from_int32_domain", test_array_from_int32_domain, devices=devices)
31903214
add_function_test(TestArray, "test_array_from_int64_domain", test_array_from_int64_domain, devices=devices)

0 commit comments

Comments
 (0)