Skip to content

Commit

Permalink
Update typing module
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Jan 31, 2024
1 parent 3be1ed4 commit 0096caf
Showing 1 changed file with 35 additions and 30 deletions.
65 changes: 35 additions & 30 deletions src/jaxsim/typing.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,44 @@
from typing import Any, Dict, Hashable, List, NamedTuple, Tuple, Union
from typing import Any, Hashable, NamedTuple

import jax.numpy as jnp
import numpy.typing as npt
import jax

# =========
# JAX types
FloatJax = Union[jnp.float16, jnp.float32, jnp.float64]
IntJax = Union[
jnp.int8,
jnp.int16,
jnp.int32,
jnp.int64,
jnp.uint8,
jnp.uint16,
jnp.uint32,
jnp.uint64,
]
ArrayJax = jnp.ndarray
TensorJax = jnp.ndarray
# =========

ScalarJax = jax.Array
IntJax = ScalarJax
BoolJax = ScalarJax
FloatJax = ScalarJax

ArrayJax = jax.Array
VectorJax = ArrayJax
MatrixJax = ArrayJax
PyTree = Union[
TensorJax,
Dict[Hashable, "PyTree"],
List["PyTree"],
NamedTuple,
Tuple["PyTree"],
None,
Any,
]

PyTree = (
dict[Hashable, "PyTree"]
| list["PyTree"]
| NamedTuple
| tuple["PyTree"]
| None
| Any
)

# =======================
# Mixed JAX / NumPy types
Array = Union[npt.NDArray, ArrayJax]
Tensor = Union[npt.NDArray, ArrayJax]
# =======================

Array = jax.typing.ArrayLike
Vector = Array
Matrix = Array
Bool = Union[bool, ArrayJax]
Int = Union[int, IntJax]
Float = Union[float, FloatJax]

Int = int | IntJax
Bool = bool | ArrayJax
Float = float | FloatJax

ArrayLike = Array
VectorLike = Vector
MatrixLike = Matrix
IntLike = Int
BoolLike = Bool
FloatLike = Float

0 comments on commit 0096caf

Please sign in to comment.