Skip to content

Commit

Permalink
DROP! Update Notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Jun 3, 2024
1 parent e52a673 commit 742c1c6
Showing 1 changed file with 164 additions and 14 deletions.
178 changes: 164 additions & 14 deletions examples/Parallel_computing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"# @title Imports and setup\n",
"import sys\n",
"import os\n",
"import pathlib\n",
"\n",
"# Deactivate GPU to avoid out of memory errors\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
Expand Down Expand Up @@ -54,8 +55,15 @@
"\n",
"import jaxsim.typing as jtp\n",
"from jaxsim import logging\n",
"\n",
"from jaxsim.mujoco import MujocoVideoRecorder, MujocoModelHelper, RodModelToMjcf\n",
"from jaxsim.api.common import VelRepr\n",
"\n",
"from jaxsim.mujoco import (\n",
" MujocoVideoRecorder,\n",
" MujocoModelHelper,\n",
" RodModelToMjcf,\n",
" SdfToMjcf,\n",
" UrdfToMjcf,\n",
")\n",
"\n",
"logging.set_logging_level(logging.LoggingLevel.INFO)\n",
"logging.info(f\"Running on {jax.devices()}\")"
Expand All @@ -79,14 +87,21 @@
"# @title Create a sphere model\n",
"model_sdf_string = rod.Sdf(\n",
" version=\"1.7\",\n",
" model=BoxBuilder(x=0.30, y=0.30, z=0.30, mass=1.0, name=\"box\")\n",
" # model=BoxBuilder(x=0.30, y=0.30, z=0.30, mass=1.0, name=\"box\")\n",
" model=SphereBuilder(radius=0.15, mass=1.0, name=\"sphere\")\n",
" .build_model()\n",
" .add_link()\n",
" .add_inertial()\n",
" .add_visual()\n",
" .add_collision()\n",
" .build(),\n",
").serialize(pretty=True)"
").serialize(pretty=True)\n",
"# import urllib\n",
"\n",
"# url = \"https://raw.githubusercontent.com/icub-tech-iit/ergocub-gazebo-simulations/master/models/stickBot/model.urdf\"\n",
"\n",
"# model_sdf_string = urllib.request.urlopen(url).read().decode()\n",
"# # model_sdf_string = pathlib.Path(\"/home/flferretti/git/element_rl-for-codesign/assets/model/hopper.sdf\")"
]
},
{
Expand All @@ -109,21 +124,53 @@
"source": [
"import jaxsim.api as js\n",
"from jaxsim import integrators\n",
"import jaxsim\n",
"\n",
"dt = 0.001\n",
"integration_time = 1500\n",
"\n",
"model = js.model.JaxSimModel.build_from_model_description(\n",
" model_description=model_sdf_string\n",
" model_description=model_sdf_string,\n",
" contact_model=js.rigid_contacts.RigidContacts(),\n",
" is_urdf=True,\n",
")\n",
"\n",
"model = js.model.reduce(\n",
" model=model,\n",
" considered_joints=tuple(\n",
" [\n",
" j\n",
" for j in model.joint_names()\n",
" if \"camera\" not in j\n",
" and \"neck\" not in j\n",
" and \"wrist\" not in j\n",
" and \"thumb\" not in j\n",
" and \"index\" not in j\n",
" and \"middle\" not in j\n",
" and \"ring\" not in j\n",
" and \"pinkie\" not in j\n",
" and \"elbow\" not in j\n",
" and \"shoulder\" not in j\n",
" and \"hip\" not in j\n",
" and \"knee\" not in j\n",
" and \"lidar\" not in j\n",
" and \"torso\" not in j\n",
" ]\n",
" ),\n",
")\n",
"model = js.model.reduce(model=model, considered_joints=tuple())\n",
"\n",
"data = js.data.JaxSimModelData.build(\n",
" model=model, velocity_representation=VelRepr.Inertial\n",
")\n",
"data = js.data.JaxSimModelData.build(model=model)\n",
"integrator = integrators.fixed_step.RungeKutta4SO3.build(\n",
" dynamics=js.ode.wrap_system_dynamics_for_integration(\n",
" model=model,\n",
" data=data,\n",
" system_dynamics=js.ode.system_dynamics,\n",
" ),\n",
")\n",
"# with jax.disable_jit():\n",
"integrator_state = integrator.init(x0=data.state, t0=0.0, dt=dt)"
]
},
Expand All @@ -133,12 +180,12 @@
"metadata": {},
"outputs": [],
"source": [
"mcjf_string, assets = RodModelToMjcf.convert(rod_model=model_sdf_string.model)\n",
"mcjf_string, assets = UrdfToMjcf.convert(urdf=model_sdf_string)\n",
"mj_helper = MujocoModelHelper.build_from_xml(\n",
" mjcf_description=mcjf_string, assets=assets\n",
")\n",
"recorder = MujocoVideoRecorder(\n",
" model=mj_helper.model, assets=mj_helper.data, fps=int(1 / dt), width=640, height=480\n",
" model=mj_helper.model, data=mj_helper.data, fps=int(1 / dt), width=640, height=480\n",
")"
]
},
Expand Down Expand Up @@ -224,20 +271,71 @@
"\n",
" data = data.reset_base_position(base_position=pose)\n",
" x_t_i = []\n",
" forces = []\n",
"\n",
" S = jnp.block([jnp.zeros(shape=(model.dofs(), 6)), jnp.eye(model.dofs())]).T\n",
" τ = jnp.zeros(model.dofs())\n",
"\n",
" # l_foot = model.link_names().index(\"l_ankle_2\")\n",
" # r_foot = model.link_names().index(\"r_ankle_2\")\n",
"\n",
" for _ in range(integration_time):\n",
" F = []\n",
"\n",
" h = js.model.free_floating_bias_forces(model=model, data=data)\n",
"\n",
" M = js.model.free_floating_mass_matrix(model=model, data=data)\n",
"\n",
" J̇ν = js.model.link_bias_accelerations(model=model, data=data)\n",
"\n",
" M_inv = jnp.linalg.inv(M)\n",
"\n",
" # idxs = (0,) # (l_foot, r_foot)\n",
" # O_JL = jax.vmap(\n",
" # lambda body: js.link.jacobian(\n",
" # model=model,\n",
" # data=data,\n",
" # link_index=body,\n",
" # # output_vel_repr=VelRepr.Inertial,\n",
" # )\n",
" # )(jnp.array(idxs))\n",
" O_JL = js.link.jacobian(\n",
" model=model,\n",
" data=data,\n",
" link_index=0,\n",
" output_vel_repr=VelRepr.Mixed,\n",
" )\n",
"\n",
" # O_JL = O_JL.reshape(6 * len(idxs), 10)\n",
"\n",
" # W_H_L = js.link.transform(model=model, data=data, link_index=body)\n",
" # W_X_L = jaxsim.math.Adjoint.from_transform(W_H_L).T\n",
" # F = -jnp.linalg.inv(O_JL @ M_inv @ O_JL.T) @ (\n",
" # J̇ν[l_foot:r_foot+1].ravel() + O_JL @ M_inv @ (S @ τ - h)\n",
" # )\n",
" F = -jnp.linalg.inv(O_JL.squeeze() @ M_inv @ O_JL.squeeze().T) @ (\n",
" J̇ν[0] + O_JL.squeeze() @ M_inv @ (S @ τ - h)\n",
" )\n",
"\n",
" # F = F.reshape(-1, 6)\n",
"\n",
" # link_forces = jnp.zeros((model.number_of_links(), 6)).at[l_foot:r_foot+1].set(jnp.array(F))\n",
" link_forces = jnp.zeros((model.number_of_links(), 6)).at[0].set(jnp.array(F))\n",
"\n",
" data, integrator_state = js.model.step(\n",
" dt=dt,\n",
" model=model,\n",
" data=data,\n",
" integrator=integrator,\n",
" integrator_state=integrator_state,\n",
" joint_forces=None,\n",
" link_forces=None,\n",
" link_forces=link_forces,\n",
" )\n",
"\n",
" x_t_i.append(data.base_position())\n",
" forces.append(F)\n",
"\n",
" return x_t_i"
" return x_t_i, forces"
]
},
{
Expand Down Expand Up @@ -265,7 +363,7 @@
"now = time.perf_counter()\n",
"\n",
"# x_t = simulate_vectorized(data, integrator_state, poses[:, 0]).\n",
"x_t = simulate(data, integrator_state, poses)\n",
"x_t, forces = simulate(data, integrator_state, poses[:, 0])\n",
"\n",
"comp_time = time.perf_counter() - now\n",
"\n",
Expand All @@ -289,8 +387,13 @@
" mj_helper.set_base_position(pose)\n",
" recorder.record_frame()\n",
"\n",
"import datetime\n",
"\n",
"import mediapy as media\n",
"\n",
"recorder.write_video(path=Path.cwd() / Path(\"sphere.mp4\"), exists_ok=True)"
"media.show_video(recorder.frames, fps=1 / dt)\n",
"\n",
"recorder.write_video(path=Path.cwd() / Path(f\"video_{datetime.datetime.now()}.mp4\"))"
]
},
{
Expand All @@ -309,13 +412,60 @@
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"plt.plot(np.arange(len(x_t)) * dt, np.array(x_t)[:, :, 2])\n",
"plt.plot(np.arange(len(x_t[:])) * dt, np.array(x_t)[:, 2])\n",
"plt.grid(True)\n",
"plt.xlabel(\"Time [s]\")\n",
"plt.ylabel(\"Height [m]\")\n",
"plt.title(\"Trajectory of the model's base\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"forces = np.array([force for force in forces])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"forces.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"plt.plot(\n",
" np.arange(len(forces[:600])) * dt,\n",
" forces[:600],\n",
" label=[\"X\", \"Y\", \"Z\", \"Rx\", \"Ry\", \"Rz\"],\n",
")\n",
"plt.grid(True)\n",
"plt.xlabel(\"Time [s]\")\n",
"plt.ylabel(\"Force [N]\")\n",
"plt.title(\"Contact forces\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down Expand Up @@ -347,7 +497,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.11.8"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 742c1c6

Please sign in to comment.