
Releases: NVIDIA/warp

v1.9.0

05 Sep 03:54

Warp 1.9 ships with a rewritten marching cubes implementation, compatibility with the CUDA 13 toolkit, and new functions for ahead-of-time module compilation. The programming model has also been enhanced with more flexible indexing for composite types, direct IntEnum support, and the ability to initialize local arrays in kernels.

New Features

Differentiable marching cubes

A fully differentiable wp.MarchingCubes implementation, contributed by @mikacuy and @nmwsharp, has been added. This version is written entirely in Warp, replacing the previous native CUDA C++ implementation and enabling it to run on both CPU and GPU devices. The implementation also addresses a long-standing off-by-one bug (#324). For more details, see the updated documentation.

Functions for module compilation and loading

We have added wp.compile_aot_module() and wp.load_aot_module() for more flexible ahead-of-time (AOT) compilation.

These functions accept a strip_hash argument; setting strip_hash=True removes the unique hashes from compiled module and function names. This makes it possible to distribute pre-compiled modules without shipping the original Python source code.

See the documentation on ahead-of-time compilation workflows for more details. In future releases, we plan to continue to expand Warp's support for ahead-of-time workflows.

CUDA 13 Support

CUDA Toolkit 13.0 was released in early August.

PyPI Distribution: Warp wheels on PyPI and NVIDIA PyPI will continue to be built with CUDA 12.8 to provide a transition period for users upgrading their CUDA drivers.

CUDA 13.0 Compatibility: Users requiring Warp compiled against CUDA 13.x have two options:

  • Build Warp from source
  • Install pre-built wheels from GitHub releases

Driver Compatibility: CUDA 12.8 Warp wheels can run on systems with CUDA 13.x drivers thanks to CUDA's backward compatibility.

Performance Improvements

Graph-capturable linear solvers

The iterative linear solvers in warp.optim.linear (CG, BiCGSTAB, GMRES) are now fully compatible with CUDA graph capture. This adds support for device-side convergence checking via wp.capture_while(), enabling full CUDA graph capture when check_every=0. Users can now choose between traditional host-side convergence checks or fully graph-capturable device-side termination.
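The trade-off between host-side and device-side convergence checks can be illustrated with a plain NumPy conjugate-gradient loop. This is a conceptual sketch, not Warp's implementation: the hypothetical `cg_sketch` function borrows the `check_every` name only as an analogy. With `check_every > 0`, the residual is compared against the tolerance only every `check_every` iterations, which amortizes the cost of a host synchronization but may run extra iterations. With `check_every=0`, the sketch tests the condition on every iteration, the way a device-side loop condition evaluated by `wp.capture_while()` would.

```python
import numpy as np


def cg_sketch(A, b, tol=1e-8, maxiter=1000, check_every=10):
    """Conjugate gradient for a symmetric positive-definite matrix A.

    check_every > 0: test convergence only every `check_every` iterations
                     (analogy for an amortized host-side check).
    check_every == 0: test convergence on every iteration
                      (analogy for a device-side loop condition).
    Returns the solution and the number of iterations performed.
    """
    x = np.zeros_like(b)
    r = b - A @ x
    p = r.copy()
    rs_old = r @ r
    for it in range(1, maxiter + 1):
        Ap = A @ p
        alpha = rs_old / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if rs_new == 0.0:
            return x, it  # exact solution reached: stop to avoid a 0/0 update
        if check_every == 0 or it % check_every == 0:
            if np.sqrt(rs_new) < tol:
                return x, it
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x, maxiter


A = np.array([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
b = np.array([1.0, 2.0, 3.0])
x_dev, it_dev = cg_sketch(A, b, check_every=0)
x_host, it_host = cg_sketch(A, b, check_every=10)
print(it_dev, it_host)
```

The every-iteration variant stops as soon as the residual drops below `tol`, while the amortized variant typically keeps iterating until the next multiple of `check_every`.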

Automatic tiling for sparse linear algebra

warp.sparse now supports arbitrary-sized blocks and can leverage tile-based computations for certain matrix types. The system automatically chooses between tiled and non-tiled execution using heuristics based on matrix characteristics (block sizes, sparsity patterns, and workload dimensions). Note that the heuristic for choosing between tiled and non-tiled variants is still being refined, and that it can be manually overridden by providing the tile_size parameter to bsr_mm or bsr_mv.

Automatic tiling for finite element quadrature

warp.fem.integrate now leverages tile-based computations for quadrature point accumulation, with automatic tile size selection based on workload characteristics. The system automatically chooses between tiled and non-tiled execution to optimize performance based on the integration problem size and complexity.

Programming Model Updates

Slice and negative indexing improvements for composite types

We have enhanced the support for slice operations and negative indexing across all composite types (vectors, matrices, quaternions, and transforms).

import warp as wp

m = wp.matrix_from_rows(
    wp.vec3(1.0, 2.0, 3.0),
    wp.vec3(4.0, 5.0, 6.0),
    wp.vec3(7.0, 8.0, 9.0),
)
subm = m[:-1, 1:]
print(subm)
# [[2.0, 3.0],
#  [5.0, 6.0]]
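The printed result follows NumPy's slicing rules, so NumPy can serve as a quick host-side reference for predicting what a composite-type slice or negative index returns. This assumes Warp keeps NumPy's conventions for the cases shown: `:-1` drops the last entry, `1:` drops the first, and negative indices count from the end. The snippet uses NumPy only, not Warp.

```python
import numpy as np

m = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
subm = m[:-1, 1:]  # all rows but the last, all columns but the first
last = m[-1, -1]   # negative indices count from the end
print(subm)
print(last)
```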

Support for IntEnum and IntFlag inside kernels

It is now possible to directly reference IntEnum and IntFlag values inside Warp functions and kernels. Previously, workarounds involving wp.static() were required.

from enum import IntEnum

import warp as wp

class JointType(IntEnum):
    PRISMATIC = 0
    REVOLUTE = 1
    BALL = 2

@wp.kernel
def count_revolute_joints(
    joint_types: wp.array(dtype=JointType),
    counter: wp.array(dtype=int)
):
    tid = wp.tid()
    joint = joint_types[tid]

    # No longer requires wp.static(JointType.REVOLUTE.value)
    if joint == JointType.REVOLUTE:
        wp.atomic_add(counter, 0, 1)

Improved support for wp.array() views inside kernels

This enhancement allows kernels to create array views by accessing the ptr attribute of an array.

import warp as wp

@wp.kernel
def kernel_array_from_ptr(arr_orig: wp.array2d(dtype=wp.float32)):
    arr = wp.array(ptr=arr_orig.ptr, shape=(2, 3), dtype=wp.float32)
    arr[0, 0] = 1.0
    arr[0, 1] = 2.0
    arr[0, 2] = 3.0

Additionally, these in-kernel views now support dynamic shapes and struct types.

Support for initializing fixed-size arrays inside kernels

It is now possible to allocate local arrays of a fixed size in kernels using wp.zeros(). The resulting arrays are allocated in registers, providing fast access and avoiding global memory overhead.

Previously, developers needed to create vectors to achieve a similar capability, e.g. v = wp.vector(length=8, dtype=float), but this came with various limitations.

import warp as wp

@wp.kernel
def kernel_with_local_array():
    local_arr = wp.zeros(8, dtype=wp.float32)  # Allocated in registers
    # ... use local_arr

Indexed tile operations

Warp now provides three new indexed tile operations that enable more flexible memory access patterns beyond simple contiguous tile operations. These functions allow you to load, store, and perform atomic operations on tiles using custom index mappings along specified axes.

import warp as wp

x = wp.array(
    [
        [0.77395605, 0.43887844, 0.85859792, 0.69736803],
        [0.09417735, 0.97562235, 0.7611397, 0.78606431],
        [0.12811363, 0.45038594, 0.37079802, 0.92676499],
    ],
    dtype=float,
)

indices = wp.array([0, 2], dtype=int)


@wp.kernel
def indexed_data_lookup(data: wp.array2d(dtype=float), indices: wp.array(dtype=int)):
    # [0 2] = tile(shape=(2), storage=shared)
    indices_tile = wp.tile_load(indices, shape=(2,))

    # [[0.773956 0.438878 0.858598 0.697368]
    #  [0.128114 0.450386 0.370798 0.926765]] = tile(shape=(2,4), storage=register)
    data_rows_tile = wp.tile_load_indexed(data, indices_tile, axis=0, shape=(2, 4))
    print(data_rows_tile)

    # [[0.773956 0.858598]
    #  [0.0941774 0.76114]
    #  [0.128114 0.370798]] = tile(shape=(3,2), storage=register)
    data_columns_tile = wp.tile_load_indexed(data, indices_tile, axis=1, shape=(3, 2))


wp.launch_tiled(indexed_data_lookup, dim=1, inputs=[x, indices], block_dim=2)
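To check the printed tiles, you can express the same selections with NumPy's `np.take()` on the host. `axis=0` gathers whole rows at the given indices and `axis=1` gathers whole columns, which matches the expected values in the kernel comments. This is a NumPy reference only; it does not use any tile functions.

```python
import numpy as np

x = np.array(
    [
        [0.77395605, 0.43887844, 0.85859792, 0.69736803],
        [0.09417735, 0.97562235, 0.7611397, 0.78606431],
        [0.12811363, 0.45038594, 0.37079802, 0.92676499],
    ]
)
indices = np.array([0, 2])

rows = np.take(x, indices, axis=0)     # shape (2, 4): rows 0 and 2
columns = np.take(x, indices, axis=1)  # shape (3, 2): columns 0 and 2
print(rows)
print(columns)
```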

Fixed nested matrix component support

Warp now properly supports writing to individual matrix elements stored within struct fields. Previously, operations like struct.matrix[1, 2] = value would result in a compile-time error.

import warp as wp

@wp.struct
class MatStruct:
    m: wp.mat44

@wp.kernel
def kernel_nested_mat(out: wp.array(dtype=MatStruct)):
    s = MatStruct()
    s.m[1, 2] = 3.0  # This now works correctly (no longer raises a WarpCodegenError)
    s.m[2][2] = 5.0  # This has also been fixed (used to silently fail)
    out[0] = s

Announcements

Known limitations

Early testing on NVIDIA Jetson Thor indicates that launching CPU kernels may sometimes result in segmentation faults. GPU kernel launches are unaffected. We believe this can be resolved by building Warp from source against LLVM/Clang version 18 or newer.

Upcoming removals

The following features have been deprecated in prior releases and will be removed in v1.10 (early November):

  • warp.sim - Use the Newton engine.
  • Constructing a wp.matrix() from column vectors - Use wp.matrix_from_rows() or wp.matrix_from_cols() instead.
  • wp.select() - Use wp.where() instead (note: different argument order).
  • wp.matrix(pos, quat, scale) - Use wp.transform_compose() instead.

Platform support

  • We plan to drop support for Intel macOS (x86-64) in a future release (tentatively planned for v1.10).

Acknowledgments

We thank the following contributors for their valuable contributions to this release:

  • @liblaf for fixing an issue with using warp.jax_experimental.ffi.jax_callable() with a function annotated with the -> None return type (#893).
  • @matthewdcong for providing an updated version of NanoVDB compatible with CUDA 13 (#888).
  • @YuyangLee for contributing an early prototype that helped shape the strip_hash=True option for the new ahead-of-time compilation functions (#661).

Full Changelog

For a curated list of all changes in this release, please see the v1.9.0 section in CHANGELOG.md.

v1.9.0rc1

20 Aug 15:59
d641e89
Pre-release

Release candidate for Isaac Lab testing.

v1.8.1

01 Aug 17:41
ad1092b

As is typical for a patch release, this release primarily contains bug fixes.

However, to support the adoption of Warp by the MuJoCo MJX physics engine, it also includes new features and deprecations limited to the jax_experimental module. We are flagging this deviation from our standard versioning practices to ensure clarity. Normal versioning practices will resume with the next release.

Full Changelog

Deprecated

  • This is the final release that will provide builds for or support the CUDA 11.x Toolkit and driver. Starting with v1.9.0, Warp will require CUDA 12.x or newer.
  • Deprecate the graph_compatible boolean flag in jax_callable() in favor of the new graph_mode argument with GraphMode enum (#848).

Added

  • Add documentation for creating and manipulating Warp structured arrays using NumPy (#852).
  • Add documentation for wp.indexedarray() (#468).
  • Support input-output aliasing in JAX FFI (#815).
  • Support capturing jax_callable() using Warp via the new graph_mode parameter (GraphMode.WARP), enabling capture of graphs with conditional nodes that cannot be used as subgraphs in a JAX capture (#848).

Fixed

  • Fix tape.zero() to correctly reset gradient arrays in nested structs (#807).
  • Fix incorrect adjoints for div(scalar, vec), div(scalar, mat), and div(scalar, quat), and other miscellaneous issues with adjoints (#831).
  • Fix a module-hashing issue for functions or kernels using static expressions that cannot be resolved at the time of declaration (#830).
  • Fix a bug in which changes to wp.config.mode were not being picked up after module initialization (#856).
  • Fix a bug where CUDA modules could get prematurely unloaded when conditional graph nodes are used.
  • Fix a compile-time regression for kernels using matmul, Cholesky, and FFT solvers by upgrading to libmathdx 0.2.2 (#809).
  • Fix potential uninitialized memory issues in wp.tile_sort() (#836).
  • Fix wp.tile_min() and wp.tile_argmin() to return correct values for large tiles with low occupancy (#725).
  • Fix codegen errors associated with adjoint of wp.tile_sum() when using shared tiles (#822).
  • Fix driver entry point error for cuDeviceGetUuid caused by using an incorrect version (#851).
  • Fix an issue that caused Warp to request PTX generation from NVRTC for architectures unsupported by the compiler (#858).
  • Fix a regression where wp.sparse.bsr_from_triplets() ignored the prune_numerical_zeros=False setting (#832).
  • Fix missing cloth-body contact in wp.sim.VBDIntegrator with handle_self_contact=False (#862).
  • Fix a bug causing potential infinite loops in the color balancing calculation (#816).
  • Fix box-box collision by computing the contact normal at the closest point of approach instead of at the center of the source box (#839).
  • Fix the OpenGL renderer not correctly displaying colors for box shapes (#810).
  • Fix a bug in OpenGLRenderer where meshes with different scale attributes were incorrectly instanced, causing them all to be rendered with the same scale (#828).

v1.8.0

01 Jul 18:32

Changelog

[1.8.0] - 2025-07-01

Added

  • Add wp.map() to map a function over arrays and add math operators for Warp arrays (docs, #694).
  • Add support for dynamic control flow in CUDA graphs, see wp.capture_if() and wp.capture_while() (docs, #597).
  • Add wp.capture_debug_dot_print() to write a DOT file describing the structure of a captured CUDA graph (#746).
  • Add the Device.sm_count property to get the number of streaming multiprocessors on a CUDA device (#584).
  • Add wp.block_dim() to query the number of threads in the current block inside a kernel (#695).
  • Add wp.atomic_cas() and wp.atomic_exch() built-ins for atomic compare-and-swap and exchange operations (#767).
  • Add support for profiling GPU runtime module compilation using the global wp.config.compile_time_trace setting or the module-level "compile_time_trace" option. When used, JSON files in the Trace Event format will be written in the kernel cache, which can be opened in a viewer like chrome://tracing/ (docs, #609).
  • Add support for returning multiple values from native functions like wp.svd3() and wp.quat_to_axis_angle() (#503).
  • Add support for passing tiles to user wp.func functions (#682).
  • Add wp.tile_squeeze() to remove axes of length one (#662).
  • Add wp.tile_reshape() to reshape a tile (#663).
  • Add wp.tile_astype() to return a new tile with the same data but a different data type (#683).
  • Add support for in-place tile add and subtract operations (#518).
  • Add support for in-place tile-component addition and subtraction (#659).
  • Add support for 2D solves using wp.tile_cholesky_solve() (#773).
  • Add wp.tile_scan_inclusive() and wp.tile_scan_exclusive() for performing inclusive and exclusive scans over tiles (#731).
  • Support attribute indexing for quaternions on the right-hand side of expressions (#625).
  • Add wp.transform_compose() and wp.transform_decompose() for converting between transforms and 4x4 matrices with 3D scale information (#576).
  • Add various wp.transform syntax operations for loading and storing (#710).
  • Add the as_spheres parameter to UsdRenderer.render_points() in order to choose whether to render the points as USD spheres using a point instancer or as simple USD points (#634).
  • Add support for animating visibility of objects in the USD renderer (#598).
  • Add wp.sim.VBDIntegrator.rebuild_bvh() to rebuild the BVH used for detecting self-contacts.
  • Add damping terms to wp.sim.VBDIntegrator collisions, with the strength controlled by Model.soft_contact_kd.
  • Improve consistency of the wp.fem.lookup() operator across geometries and add filtering parameters (#618).
  • Add two examples demonstrating shape optimization using warp.fem: fem/example_elastic_shape_optimization.py and fem/example_darcy_ls_optimization.py (#698).
  • Add a py.typed marker file (per PEP 561) to the package to formally support static type checking by downstream users (#780).

Removed

  • Remove wp.mlp() (deprecated in v1.6.0). Use tile primitives instead.
  • Remove wp.autograd.plot_kernel_jacobians() (deprecated in v1.4.0). Use wp.autograd.jacobian_plot() instead.
  • Remove the length and owner keyword arguments from wp.array() constructor (deprecated in v1.6.0). Use the shape and deleter keywords instead.
  • Remove the kernel keyword argument from wp.autograd.jacobian() and wp.autograd.jacobian_fd() (deprecated in v1.6.0). Use the function keyword argument instead.
  • Remove the outputs keyword argument from wp.autograd.jacobian_plot() (deprecated in v1.6.0).

Changed

  • Deprecate the warp.sim module (planned for removal in v1.10). It will be superseded by the upcoming Newton library, a separate package with a new API. Migrating will require code changes; a future guide will be provided (current draft). See the GitHub announcement for details (#735).
  • Deprecate the wp.matrix(pos, quat, scale) built-in function. Use wp.transform_compose() instead (#576).
  • Improve support for tuples in kernels (#506).
  • Return a constant value from len() where possible.
  • Rename the internal function wp.types.type_length() to wp.types.type_size().
  • Rename wp.tile_cholesky_solve() input parameters to align with its docstring (#726).
  • Change wp.tile_upper_solve() and wp.tile_lower_solve() to use libmathdx 0.2.1 TRSM solver (#773).
  • Skip adjoint compilation for wp.tile_matmul() if enable_backward is disabled (#644).
  • Allow tile reductions to work with non-scalar tile types (#771).
  • Permit data-type preservation with preserve_type=True when tiling a value across the block with wp.tile() (#772).
  • Make wp.sparse.bsr_[set_]from_triplets differentiable with respect to the input triplet values (#760).
  • Expose new warp.fem operators: node_count, node_index, element_coordinates, element_closest_point.
  • Change wp.sim.VBDIntegrator rigid-body-contact handling to use only the shape's friction coefficient, rather than averaging the shape's and the cloth's coefficients.
  • Limit usage of the wp.assign_copy() hidden built-in to the kernel scope.
  • Describe the distinction between inputs and outputs arguments in the Kernel documentation.
  • Reduce the overhead of wp.launch() by avoiding costly native API calls (#774).
  • Improve error reporting when calling @wp.func-decorated functions from the Python scope (#521).

Fixed

  • Fix missing documentation for geometric structs (#674).
  • Fix the type annotations in various tile functions (#714).
  • Fix incorrect stride initialization in tiles returned from functions taking transposed tiles as input (#722).
  • Fix adjoint generation for user functions that return a tile (#749).
  • Fix tile-based solvers failing to accept and return transposed tiles (#768).
  • Fix the "Formal parameter space overflowed" error during wp.sim.VBDIntegrator kernel compilation for the backward pass in CUDA 11 Warp builds. This was resolved by decoupling collision and elasticity evaluations into separate kernels, which also increases parallelism and speeds up the solver (#442).
  • Fix an issue with graph coloring on an empty graph (#509).
  • Fix an integer overflow bug in the native graph coloring module (#718).
  • Fix UsdRenderer.render_points() not supporting multiple colors (#634).
  • Fix an inconsistency in the wp.fem module regarding the orientation of 2D geometry side normals (#629).
  • Fix premature unloading of CUDA modules used in JAX FFI graph captures (#782).

v1.7.2.post1

31 May 20:37
4ad2090

Changelog

[1.7.2] - 2025-05-31

Added

  • Add missing adjoint method for tile assign operations (#680).
  • Document that += and -= invoke wp.atomic_add() and wp.atomic_sub(), respectively (#505).
  • Add a publications list of academic and research projects leveraging Warp (#686).

Changed

  • Disallow class inheritance for wp.struct and document the restriction; a RuntimeError is now raised (#656).
  • Warn when an incompatible data type conversion is detected when constructing an array using the __cuda_array_interface__ (#624, #670).
  • Relax the exact version requirement in omni.warp towards omni.warp.core (#702).
  • Rename the "Kernel Reference" documentation page to "Built-Ins Reference", with each built-in now having annotations to denote whether they are accessible only from the kernel scope or also from the Python runtime scope (#532).

Fixed

  • Fix an issue where arrays stored in structs could be garbage collected without updating the struct ctype (#720).
  • Fix an issue with preserving the base class of nested struct attributes (#574).
  • Allow recovering from out-of-memory errors during wp.Volume allocation (#611).
  • Fix 2D tile load when source array and tile have incompatible strides (#688).
  • Fix compilation errors with wp.tile_atomic_add() (#681).
  • Fix wp.svd2() for inputs with duplicate singular values and improve its accuracy (#679).
  • Fix OpenGLRenderer.update_shape_instance() not having color buffers created for the shape instances.
  • Fix text rendering in wp.render.OpenGLRenderer (#704).
  • Fix assembly of rigid body inertia in ModelBuilder.collapse_fixed_joints() (#631).
  • Fix UsdRenderer.render_points() erroring out when passed 4 or fewer points (#708).
  • Fix wp.atomic_*() built-ins not working with some types (#733).
  • Fix garbage-collection issues with JAX FFI callbacks (#711).

v1.7.1

01 May 06:03

Changelog

[1.7.1] - 2025-04-30

Added

  • Add example of a distributed Jacobi solver using mpi4py in warp/examples/distributed/example_jacobi_mpi.py (#475).

Changed

  • Improve repr() for Warp types, including adding repr() for wp.array.
  • Change the USD renderer to use framesPerSecond for time sampling instead of timeCodesPerSecond to avoid playback speed issues in some viewers (#617).
  • Model.rigid_contact_tids now holds -1 at inactive contact indices, which makes it possible to retrieve the vertex index of a mesh collision; see test_collision.py (#623).
  • Improve handling of deprecated JAX features (#613).

Fixed

  • Fix a code generation bug involving return statements in Warp kernels, which could cause some threads to be skipped when running on the GPU (#594).
  • Fix constructing DeformedGeometry from wp.fem.Trimesh3D geometries (#614).
  • Fix lookup operator for wp.fem.Trimesh3D (#618).
  • Include the block dimension in the LTO file hash for the Cholesky solver (#639).
  • Fix tile loads for small tiles with aligned source memory (#622).
  • Fix length/shape matching for vectors and matrices from the Python scope.
  • Fix the dtype parameter missing for wp.quaternion().
  • Fix invalid dtype comparison when using the wp.matrix()/wp.vector()/wp.quaternion() constructors with literal values and an explicit dtype argument (#651).
  • Fix incorrect thread index lookup for the backward pass of wp.sim.collide() (#459).
  • Fix a bug where wp.sim.ModelBuilder adds springs with -1 as vertex indices (#621).
  • Fix center-of-mass and inertia computation for mesh shapes (#251).
  • Fix computation of body center of mass to account for shape orientation (#648).
  • Fix show_joints not working with wp.sim.render.SimRenderer set to render to USD (#510).
  • Fix the jitter for the OgnParticlesFromMesh node not being computed correctly.
  • Fix documentation of atol and rtol arguments to wp.autograd.gradcheck() and wp.autograd.gradcheck_tape() (#508).

v1.7.0

30 Mar 21:15
a81f7e7

Changelog

[1.7.0] - 2025-03-30

Added

  • Support JAX foreign function interface (FFI) (docs, #511).
  • Support Python/SASS correlation in Nsight Compute reports by emitting #line directives in CUDA-C code. This setting is controlled by wp.config.line_directives and is True by default (docs, #437).
  • Support vec4f grid construction in wp.Volume.allocate_by_tiles().
  • Add 2D SVD wp.svd2() (#436).
  • Add wp.randu() for random uint32 generation.
  • Add matrix construction functions wp.matrix_from_cols() and wp.matrix_from_rows() (#278).
  • Add wp.transform_from_matrix() to obtain a transform from a 4x4 matrix (#211).
  • Add wp.where() to select between two arguments conditionally using a more intuitive argument order (cond, value_if_true, value_if_false) (#469).
  • Add wp.get_mempool_used_mem_current() and wp.get_mempool_used_mem_high() to query the current and high-water-mark usage of the memory pool allocator (#446).
  • Add Stream.is_complete and Event.is_complete properties to query completion status (#435).
  • Support timing events inside of CUDA graphs (#556).
  • Add LTO cache to speed up compilation times for kernels using MathDx-based tile functions. Use wp.clear_lto_cache() to clear the LTO cache (#507).
  • Add example demonstrating gradient checkpointing for fluid optimization in warp/examples/optim/example_fluid_checkpoint.py.
  • Add a hinge-angle-based bending force to wp.sim.VBDIntegrator.
  • Add an example to show mesh sampling using a CDF (#476).

Changed

  • Breaking: Remove CUTLASS dependency and wp.matmul() functionality (including batched version). Users should use tile primitives for matrix multiplication operations instead.
  • Deprecate constructing a matrix from vectors using wp.matrix().
  • Deprecate wp.select() in favor of wp.where(). Users should update their code to use wp.where(cond, value_if_true, value_if_false) instead of wp.select(cond, value_if_false, value_if_true).
  • wp.sim.Control no longer has a model attribute (#487).
  • wp.sim.Control.reset() is deprecated and now only zeros-out the controls (previously restored controls to initial model state). Use wp.sim.Control.clear() instead.
  • Vector/matrix/quaternion component assignment operations (e.g., v[0] = x) now compile and run faster in the backward pass. Note: For correct gradient computation, each component should only be assigned once.
  • @wp.kernel now has an optional module argument that allows passing a wp.context.Module to the kernel or, if set to "unique", lets Warp create a new unique module just for this kernel. The default behavior of using the current module is unchanged.
  • Default PTX architecture is now automatically determined by the devices present in the system, ensuring optimal compatibility and performance (#537).
  • Structs now have a trivial default constructor, allowing for wp.tile_reduce() on tiles with struct data types.
  • Extend wp.tile_broadcast() to support broadcasting to 1D, 3D, and 4D shapes (in addition to existing 2D support).
  • wp.fem.integrate() and wp.fem.interpolate() may now perform parallel evaluation of quadrature points within elements.
  • wp.fem.interpolate() can now build Jacobian sparse matrices of interpolated functions with respect to a trial field.
  • Multiple wp.sparse routines (bsr_set_from_triplets, bsr_assign, bsr_axpy, bsr_mm) now accept a masked flag to discard any non-zero not already present in the destination matrix.
  • wp.sparse.bsr_assign() no longer requires source and destination block shapes to evenly divide each other.
  • Extend wp.expect_near() to support all vectors and quaternions.
  • Extend wp.quat_from_matrix() to support 4x4 matrices.
  • Update the OgnClothSimulate node to use the VBD integrator (#512).
  • Remove the globalScale parameter from the OgnClothSimulate node.

Fixed

  • Fix an out-of-bounds access bug caused by an unbalanced BVH tree (#536).
  • Fix an error where the offset was incorrectly added to -1 elements in edge_indices when adding a ModelBuilder to another (#557).

v1.6.2

08 Mar 00:23

Changelog

[1.6.2] - 2025-03-07

Changed

  • Update project license from NVIDIA Software License to Apache License, Version 2.0 (see LICENSE.md).

v1.6.1

03 Mar 16:27

Changelog

[1.6.1] - 2025-03-03

Added

  • Document wp.Launch objects (docs, #428).
  • Document how overwriting previously computed results can lead to incorrect gradients (docs, #525).

Fixed

  • Fix unaligned loads with offset 2D tiles in wp.tile_load().
  • Fix FP64 accuracy of thread-level matrix-matrix multiplications (#489).
  • Fix wp.array() not initializing from arrays defining a CUDA array interface when the target device is CPU (#523).
  • Fix wp.Launch objects not storing and replaying adjoint kernel launches (#449).
  • Fix wp.config.verify_autograd_array_access failing to detect overwrites in generic Warp functions (#493).
  • Fix an error on Windows when closing an OpenGLRenderer app (#488).
  • Fix per-vertex colors not being correctly written out to USD meshes when a constant color is being passed (#480).
  • Fix an error in capturing the wp.sim.VBDIntegrator with CUDA graphs when handle_self_contact is enabled (#441).
  • Fix an error in AABB computation in wp.collide.TriMeshCollisionDetector.
  • Fix URDF-imported planar joints not being set with the intended target_ke, target_kd, and mode parameters (#454).
  • Fix ModelBuilder.add_builder() to use correct offsets for ModelBuilder.joint_parent and ModelBuilder.joint_child (#432).
  • Fix underallocation of contact points for box–sphere and box–capsule collisions.
  • Fix wp.randi() documentation to show correct output range of [-2^31, 2^31).

v1.6.0

03 Feb 23:34
7f25bbf

Changelog

[1.6.0] - 2025-02-03

Added

  • Add preview of Tile Cholesky factorization and solve APIs through wp.tile_cholesky(), tile_cholesky_solve()
    and tile_diag_add() (preview APIs are subject to change).
  • Support for loading tiles from arrays whose shapes are not multiples of the tile dimensions.
    Out-of-bounds reads will be zero-filled and out-of-bounds writes will be skipped.
  • Support for higher-dimensional (up to 4D) tile shapes and memory operations.
  • Add intersection-free self-contact support in wp.sim.VBDIntegrator by passing handle_self_contact=True.
    See warp/examples/sim/example_cloth_self_contact.py for a usage example.
  • Add functions wp.norm_l1(), wp.norm_l2(), wp.norm_huber(), wp.norm_pseudo_huber(), and wp.smooth_normalize()
    for vector types to a new wp.math module.
  • wp.sim.SemiImplicitIntegrator and wp.sim.FeatherstoneIntegrator now have an optional friction_smoothing
    constructor argument (defaults to 1.0) that controls softness of the friction norm computation.
  • Support assert statements in kernels (docs).
    Assertions can only be triggered in "debug" mode (GH-366).
  • Support CUDA IPC on Linux. Call the ipc_handle() method to get an IPC handle for a wp.Event or a wp.array,
    and call wp.from_ipc_handle() or wp.event_from_ipc_handle() in another process to open the handle
    (docs).
  • Add per-module option to disable fused floating point operations, use wp.set_module_options({"fuse_fp": False})
    (GH-379).
  • Add per-module option to add CUDA-C line information for profiling, use wp.set_module_options({"lineinfo": True}).
  • Support operator overloading for wp.struct objects by defining wp.func functions
    (GH-392).
  • Add built-in function wp.len() to retrieve the number of elements for vectors, quaternions, matrices, and arrays
    (GH-389).
  • Add warp/examples/optim/example_softbody_properties.py as an optimization example for soft-body properties
    (GH-419).
  • Add warp/examples/tile/example_tile_walker.py, which reworks the existing example_walker.py
    to use Warp's tile API for matrix multiplication.
  • Add warp/examples/tile/example_tile_nbody.py as an example of an N-body simulation using Warp tile primitives.

Changed

  • Breaking: Change wp.tile_load() and wp.tile_store() indexing behavior so that indices are now specified in
    terms of array elements instead of tile multiples.
  • Breaking: Tile operations now take shape and offset parameters as tuples,
    e.g.: wp.tile_load(array, shape=(m,n), offset=(i,j)).
  • Breaking: Change exception types and error messages thrown by tile functions for improved consistency.
  • Add an implicit tile synchronization whenever a shared memory tile's data is reinitialized (e.g. in dynamic loops).
    This could result in lower performance.
  • wp.Bvh constructor now supports various construction algorithms via the constructor argument, including
    "sah" (Surface Area Heuristic), "median", and "lbvh" (docs).
  • Improve the query efficiency of wp.Bvh and wp.Mesh.
  • Improve memory consumption, compilation and runtime performance when using in-place vector/matrix assignments in
    kernels that have enable_backward set to False (GH-332).
  • Vector/matrix/quaternion component += and -= operations compile and run faster in the backward pass
    (GH-332).
  • Name files in the kernel cache according to their directory. Previously, all files began with
    module_codegen (GH-431).
  • Avoid recompilation of modules when changing block_dim.
  • wp.autograd.gradcheck_tape() now has additional optional arguments reverse_launches and skip_to_launch_index.
  • wp.autograd.gradcheck(), wp.autograd.jacobian(), and wp.autograd.jacobian_fd() now also accept
    arbitrary Python functions that have Warp arrays as inputs and outputs.
  • update_vbo_transforms kernel launches in the OpenGL renderer are no longer recorded onto the tape.
  • Skip emitting backward functions/kernels in the generated C++/CUDA code when enable_backward is set to False.
  • Emit deprecation warnings for the use of the owner and length keywords in the wp.array initializer.
  • Emit deprecation warnings for the use of wp.mlp(), wp.matmul(), and wp.batched_matmul().
    Use tile primitives instead.

Fixed

  • Fix unintended modification of non-Warp arrays during the backward pass (GH-394).
  • Fix so that wp.Tape.zero() zeroes gradients passed via the grads parameter in wp.Tape.backward()
    (GH-407).
  • Fix errors during graph capture caused by module unloading (GH-401).
  • Fix potential memory corruption errors when allocating arrays with strides (GH-404).
  • Fix wp.array() not respecting the target dtype and shape when the given data is another array with a CUDA interface
    (GH-363).
  • Make negative constants evaluate to compile-time constants (GH-403).
  • Fix ImportError exception being thrown during interpreter shutdown on Windows when using the OpenGL renderer
    (GH-412).
  • Fix the OpenGL renderer not working when multiple instances exist at the same time (GH-385).
  • Fix AttributeError crash in the OpenGL renderer when moving the camera (GH-426).
  • Fix the OpenGL renderer not correctly displaying duplicate capsule, cone, and cylinder shapes
    (GH-388).
  • Fix the overriding of wp.sim.ModelBuilder default parameters (GH-429).
  • Fix indexing of wp.tile_extract() when the block dimension is smaller than the tile size.
  • Fix scale and rotation issues with the rock geometry used in the granular collision SDF example
    (GH-409).
  • Fix autodiff Jacobian computation in wp.autograd.jacobian() where in some cases gradients were not zeroed-out properly.
  • Fix plotting issues in wp.autograd.jacobian_plot().
  • Fix the len() operator returning the total size of a matrix instead of its first dimension.
  • Fix gradient instability in rigid-body contact handling for wp.sim.SemiImplicitIntegrator and
    wp.sim.FeatherstoneIntegrator (GH-349).
  • Fix overload resolution of generic Warp functions with default arguments.
  • Fix rendering of arrows with different up_axis and color values in OpenGLRenderer (GH-448).