Skip to content

Commit

Permalink
Use structural pattern matching for jcal
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Feb 2, 2024
1 parent ed05df4 commit 5573dd2
Showing 1 changed file with 47 additions and 46 deletions.
93 changes: 47 additions & 46 deletions src/jaxsim/math/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,70 +32,71 @@ def jcalc(
else:
raise ValueError(jtyp)

if code is JointType.F:
raise ValueError("Fixed joints shouldn't be here")
match code:
case JointType.F:
raise ValueError("Fixed joints shouldn't be here")

if code is JointType.R:
jtyp: JointGenericAxis
case JointType.R:
jtyp: JointGenericAxis

Xj = Adjoint.from_rotation_and_translation(
rotation=Rotation.from_axis_angle(vector=q * jtyp.axis), inverse=True
)
Xj = Adjoint.from_rotation_and_translation(
rotation=Rotation.from_axis_angle(vector=q * jtyp.axis), inverse=True
)

S = jnp.vstack(jnp.hstack([jnp.zeros(3), jtyp.axis.squeeze()]))
S = jnp.vstack(jnp.hstack([jnp.zeros(3), jtyp.axis.squeeze()]))

elif code is JointType.P:
jtyp: JointGenericAxis
case JointType.P:
jtyp: JointGenericAxis

Xj = Adjoint.from_rotation_and_translation(
translation=jnp.array(q * jtyp.axis), inverse=True
)
Xj = Adjoint.from_rotation_and_translation(
translation=jnp.array(q * jtyp.axis), inverse=True
)

S = jnp.vstack(jnp.hstack([jtyp.axis.squeeze(), jnp.zeros(3)]))
S = jnp.vstack(jnp.hstack([jtyp.axis.squeeze(), jnp.zeros(3)]))

elif code is JointType.Rx:
Xj = Adjoint.from_rotation_and_translation(
rotation=Rotation.x(theta=q), inverse=True
)
case JointType.Rx:
Xj = Adjoint.from_rotation_and_translation(
rotation=Rotation.x(theta=q), inverse=True
)

S = jnp.vstack([0, 0, 0, 1.0, 0, 0])
S = jnp.vstack([0, 0, 0, 1.0, 0, 0])

elif code is JointType.Ry:
Xj = Adjoint.from_rotation_and_translation(
rotation=Rotation.y(theta=q), inverse=True
)
case JointType.Ry:
Xj = Adjoint.from_rotation_and_translation(
rotation=Rotation.y(theta=q), inverse=True
)

S = jnp.vstack([0, 0, 0, 0, 1.0, 0])
S = jnp.vstack([0, 0, 0, 0, 1.0, 0])

elif code is JointType.Rz:
Xj = Adjoint.from_rotation_and_translation(
rotation=Rotation.z(theta=q), inverse=True
)
case JointType.Rz:
Xj = Adjoint.from_rotation_and_translation(
rotation=Rotation.z(theta=q), inverse=True
)

S = jnp.vstack([0, 0, 0, 0, 0, 1.0])
S = jnp.vstack([0, 0, 0, 0, 0, 1.0])

elif code is JointType.Px:
Xj = Adjoint.from_rotation_and_translation(
translation=jnp.array([q, 0.0, 0.0]), inverse=True
)
case JointType.Px:
Xj = Adjoint.from_rotation_and_translation(
translation=jnp.array([q, 0.0, 0.0]), inverse=True
)

S = jnp.vstack([1.0, 0, 0, 0, 0, 0])
S = jnp.vstack([1.0, 0, 0, 0, 0, 0])

elif code is JointType.Py:
Xj = Adjoint.from_rotation_and_translation(
translation=jnp.array([0.0, q, 0.0]), inverse=True
)
case JointType.Py:
Xj = Adjoint.from_rotation_and_translation(
translation=jnp.array([0.0, q, 0.0]), inverse=True
)

S = jnp.vstack([0, 1.0, 0, 0, 0, 0])
S = jnp.vstack([0, 1.0, 0, 0, 0, 0])

elif code is JointType.Pz:
Xj = Adjoint.from_rotation_and_translation(
translation=jnp.array([0.0, 0.0, q]), inverse=True
)
case JointType.Pz:
Xj = Adjoint.from_rotation_and_translation(
translation=jnp.array([0.0, 0.0, q]), inverse=True
)

S = jnp.vstack([0, 0, 1.0, 0, 0, 0])
S = jnp.vstack([0, 0, 1.0, 0, 0, 0])

else:
raise ValueError(code)
case _:
raise ValueError(code)

return Xj, S

0 comments on commit 5573dd2

Please sign in to comment.