v1.9.0

Warp 1.9 ships with a rewritten marching cubes implementation, compatibility with the CUDA 13 toolkit, and new functions for ahead-of-time module compilation. The programming model has also been enhanced with more flexible indexing for composite types, direct IntEnum support, and the ability to initialize local arrays in kernels.

New Features

Differentiable marching cubes

A fully differentiable wp.MarchingCubes implementation, contributed by @mikacuy and @nmwsharp, has been added. This version is written entirely in Warp, replacing the previous native CUDA C++ implementation and enabling it to run on both CPU and GPU devices. The implementation also addresses a long-standing off-by-one bug (#324). For more details, see the updated documentation.
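
For reference, here is a minimal sketch of extracting an isosurface with wp.MarchingCubes; the grid size, signed-distance field, and buffer capacities below are illustrative assumptions rather than recommended values.

import warp as wp

N = 64

@wp.kernel
def sphere_sdf(field: wp.array3d(dtype=float)):
    i, j, k = wp.tid()
    # signed distance to a sphere of radius 16 centered in the grid
    p = wp.vec3(float(i), float(j), float(k)) - wp.vec3(32.0, 32.0, 32.0)
    field[i, j, k] = wp.length(p) - 16.0

field = wp.zeros(shape=(N, N, N), dtype=float, device="cuda")
wp.launch(sphere_sdf, dim=(N, N, N), inputs=[field])

# max_verts/max_tris are illustrative capacities for the output buffers
mc = wp.MarchingCubes(nx=N, ny=N, nz=N, max_verts=100_000, max_tris=100_000, device="cuda")
mc.surface(field=field, threshold=0.0)

# extracted triangle mesh; with the new implementation the outputs are
# differentiable with respect to the input field
print(mc.verts.shape, mc.indices.shape)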

Functions for module compilation and loading

We have added wp.compile_aot_module() and wp.load_aot_module() for more flexible ahead-of-time (AOT) compilation.

These functions include a strip_hash=True argument, which removes the unique hashes from compiled module and function names. This makes it possible to distribute pre-compiled modules without shipping the original Python source code.

See the documentation on ahead-of-time compilation workflows for more details. We plan to continue expanding this support in future releases.
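
As a rough sketch of the intended workflow (the module name and directory arguments below are assumptions for illustration; only strip_hash=True is taken from this release's notes, so consult the AOT documentation for the exact signatures):

import warp as wp

import my_kernels  # hypothetical module containing @wp.kernel definitions

# compile ahead of time; strip_hash=True removes the unique hashes from
# compiled names (the output-directory argument here is an assumption)
wp.compile_aot_module(my_kernels, "aot_cache/", strip_hash=True)

# later, e.g. on a deployment machine without the original Python source
wp.load_aot_module("my_kernels", "aot_cache/")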

CUDA 13 Support

CUDA Toolkit 13.0 was released in early August.

PyPI Distribution: Warp wheels on PyPI and NVIDIA PyPI will continue to be built with CUDA 12.8 to provide a transition period for users upgrading their CUDA drivers.

CUDA 13.0 Compatibility: Users requiring Warp compiled against CUDA 13.x have two options:

  • Build Warp from source
  • Install pre-built wheels from GitHub releases

Driver Compatibility: CUDA 12.8 Warp wheels can run on systems with CUDA 13.x drivers thanks to CUDA's backward compatibility.

Performance Improvements

Graph-capturable linear solvers

The iterative linear solvers in warp.optim.linear (CG, BiCGSTAB, GMRES) are now fully compatible with CUDA graph capture. Convergence can be checked on the device via wp.capture_while(), enabling full graph capture when check_every=0 is passed. Users can choose between traditional host-side convergence checks and fully graph-capturable device-side termination.
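
For example, a solve can now be captured end to end; the system setup below is a placeholder and assumes A has been filled with a symmetric positive-definite matrix:

import warp as wp
from warp.optim.linear import cg

n = 1024
A = wp.zeros((n, n), dtype=float, device="cuda")  # placeholder: fill with an SPD matrix
b = wp.ones(n, dtype=float, device="cuda")
x = wp.zeros(n, dtype=float, device="cuda")

# check_every=0 moves the convergence test onto the device, so no host
# synchronization interrupts the capture
with wp.ScopedCapture() as capture:
    cg(A, b, x, tol=1.0e-6, maxiter=100, check_every=0)

wp.capture_launch(capture.graph)  # replay the entire solve as one graph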

Automatic tiling for sparse linear algebra

warp.sparse now supports arbitrary-sized blocks and can leverage tile-based computations for certain matrix types. The system automatically chooses between tiled and non-tiled execution using heuristics based on matrix characteristics (block sizes, sparsity patterns, and workload dimensions). This heuristic is still being refined and can be overridden manually by passing the tile_size parameter to bsr_mm or bsr_mv.
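
As an illustration (the matrix contents are placeholders, and the tile_size value shown is an arbitrary assumption):

import warp as wp
import warp.sparse as sparse

# a 3x3 grid of blocks with 8x8 blocks on the diagonal (placeholder values)
rows = wp.array([0, 1, 2], dtype=int)
cols = wp.array([0, 1, 2], dtype=int)
vals = wp.ones(shape=(3, 8, 8), dtype=float)

A = sparse.bsr_from_triplets(3, 3, rows, cols, vals)

C = sparse.bsr_mm(A, A)                # heuristic picks tiled vs. non-tiled
C = sparse.bsr_mm(A, A, tile_size=64)  # explicit override (value is an assumption)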

Automatic tiling for finite element quadrature

warp.fem.integrate now leverages tile-based computations for quadrature point accumulation, automatically selecting tile sizes and choosing between tiled and non-tiled execution based on the size and complexity of the integration problem.
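
No API changes are required to benefit; a typical bilinear-form assembly such as the sketch below (illustrative resolution and polynomial degree) picks up the tiled path automatically when the heuristic favors it:

import warp as wp
import warp.fem as fem

@fem.integrand
def mass_form(s: fem.Sample, u: fem.Field, v: fem.Field):
    # quadrature-point contribution of the mass bilinear form
    return u(s) * v(s)

geo = fem.Grid2D(res=wp.vec2i(32, 32))
space = fem.make_polynomial_space(geo, degree=2)
domain = fem.Cells(geometry=geo)
u = fem.make_trial(space, domain=domain)
v = fem.make_test(space, domain=domain)

matrix = fem.integrate(mass_form, fields={"u": u, "v": v})  # assembled sparse (BSR) matrix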

Programming Model Updates

Slice and negative indexing improvements for composite types

We have enhanced the support for slice operations and negative indexing across all composite types (vectors, matrices, quaternions, and transforms).

import warp as wp

m = wp.matrix_from_rows(
    wp.vec3(1.0, 2.0, 3.0),
    wp.vec3(4.0, 5.0, 6.0),
    wp.vec3(7.0, 8.0, 9.0),
)
subm = m[:-1, 1:]
print(subm)
# [[2.0, 3.0],
#  [5.0, 6.0]]

Support for IntEnum and IntFlag inside kernels

It is now possible to directly reference IntEnum and IntFlag values inside Warp functions and kernels. Previously, workarounds involving wp.static() were required.

from enum import IntEnum

import warp as wp

class JointType(IntEnum):
    PRISMATIC = 0
    REVOLUTE = 1
    BALL = 2

@wp.kernel
def count_revolute_joints(
    joint_types: wp.array(dtype=JointType),
    counter: wp.array(dtype=int)
):
    tid = wp.tid()
    joint = joint_types[tid]

    # No longer requires wp.static(JointType.REVOLUTE.value)
    if joint == JointType.REVOLUTE:
        wp.atomic_add(counter, 0, 1)

Improved support for wp.array() views inside kernels

Kernels can now create array views by passing an existing array's ptr attribute to the wp.array() constructor.

@wp.kernel
def kernel_array_from_ptr(arr_orig: wp.array2d(dtype=wp.float32)):
    arr = wp.array(ptr=arr_orig.ptr, shape=(2, 3), dtype=wp.float32)
    arr[0, 0] = 1.0
    arr[0, 1] = 2.0
    arr[0, 2] = 3.0

Additionally, these in-kernel views now support dynamic shapes and struct types.

Support for initializing fixed-size arrays inside kernels

It is now possible to allocate local arrays of a fixed size in kernels using wp.zeros(). The resulting arrays are allocated in registers, providing fast access and avoiding global memory overhead.

Previously, developers needed to create vectors to achieve a similar capability, e.g. v = wp.vector(length=8, dtype=float), but this came with various limitations.

@wp.kernel
def kernel_with_local_array():
    local_arr = wp.zeros(8, dtype=wp.float32)  # Allocated in registers
    for i in range(8):
        local_arr[i] = wp.float32(i) * 2.0  # fast access, no global memory traffic

Indexed tile operations

Warp now provides three new indexed tile operations that enable more flexible memory access patterns beyond simple contiguous tile operations. These functions allow you to load, store, and perform atomic operations on tiles using custom index mappings along specified axes.

import warp as wp

x = wp.array(
    [
        [0.77395605, 0.43887844, 0.85859792, 0.69736803],
        [0.09417735, 0.97562235, 0.7611397, 0.78606431],
        [0.12811363, 0.45038594, 0.37079802, 0.92676499],
    ],
    dtype=float,
)

indices = wp.array([0, 2], dtype=int)


@wp.kernel
def indexed_data_lookup(data: wp.array2d(dtype=float), indices: wp.array(dtype=int)):
    # [0 2] = tile(shape=(2), storage=shared)
    indices_tile = wp.tile_load(indices, shape=(2,))

    # [[0.773956 0.438878 0.858598 0.697368]
    #  [0.128114 0.450386 0.370798 0.926765]] = tile(shape=(2,4), storage=register)
    data_rows_tile = wp.tile_load_indexed(data, indices_tile, axis=0, shape=(2, 4))
    print(data_rows_tile)

    # [[0.773956 0.858598]
    #  [0.0941774 0.76114]
    #  [0.128114 0.370798]] = tile(shape=(3,2), storage=register)
    data_columns_tile = wp.tile_load_indexed(data, indices_tile, axis=1, shape=(3, 2))
    print(data_columns_tile)


wp.launch_tiled(indexed_data_lookup, dim=1, inputs=[x, indices], block_dim=2)

Fixed nested matrix component support

Warp now properly supports writing to individual matrix elements stored within struct fields. Previously, operations like struct.matrix[1, 2] = value would result in a compile-time error.

@wp.struct
class MatStruct:
    m: wp.mat44

@wp.kernel
def kernel_nested_mat(out: wp.array(dtype=MatStruct)):
    s = MatStruct()
    s.m[1, 2] = 3.0  # This now works correctly (no longer raises a WarpCodegenError)
    s.m[2][2] = 5.0  # This has also been fixed (used to silently fail)
    out[0] = s

Announcements

Known limitations

Early testing on NVIDIA Jetson Thor indicates that launching CPU kernels may sometimes result in segmentation faults. GPU kernel launches are unaffected. We believe this can be resolved by building Warp from source against LLVM/Clang version 18 or newer.

Upcoming removals

The following features have been deprecated in prior releases and will be removed in v1.10 (early November):

  • warp.sim - Use the Newton engine.
  • Constructing a wp.matrix() from column vectors - Use wp.matrix_from_rows() or wp.matrix_from_cols() instead.
  • wp.select() - Use wp.where() instead (note the different argument order; see the sketch after this list).
  • wp.matrix(pos, quat, scale) - Use wp.transform_compose() instead.
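
For the wp.select() migration specifically, the condition stays first but the value arguments swap places:

import warp as wp

@wp.kernel
def clamp_negative(x: wp.array(dtype=float)):
    tid = wp.tid()
    # deprecated: wp.select(cond, value_if_false, value_if_true)
    # y = wp.select(x[tid] > 0.0, 0.0, x[tid])
    # replacement: wp.where(cond, value_if_true, value_if_false)
    y = wp.where(x[tid] > 0.0, x[tid], 0.0)
    x[tid] = y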

Platform support

  • We plan to drop support for Intel macOS (x86-64) in a future release (tentatively planned for v1.10).

Acknowledgments

We thank the following contributors for their valuable contributions to this release:

  • @liblaf for fixing an issue when using warp.jax_experimental.ffi.jax_callable() with functions annotated with a -> None return type (#893).
  • @matthewdcong for providing an updated version of NanoVDB compatible with CUDA 13 (#888).
  • @YuyangLee for contributing an early prototype that helped shape the strip_hash=True option for the new ahead-of-time compilation functions (#661).

Full Changelog

For a curated list of all changes in this release, please see the v1.9.0 section in CHANGELOG.md.