diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index a3a154268..0b255dbd3 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -338,6 +338,7 @@ def reduce( reduced_model = JaxSimModel.build( model_description=reduced_intermediate_description, model_name=model.name(), + terrain=model.terrain, ) # Store the origin of the model, in case downstream logic needs it diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 88ea31d41..f903a9f64 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -93,6 +93,9 @@ def test_model_creation_and_reduction( # Check that all non-fixed joints are in the reduced model. assert set(reduced_joints) == set(model_reduced.joint_names()) + # Check that the reduce model maintain the same terrain of the full model. + assert model_full.terrain == model_reduced.terrain + # Build the data of the reduced model. data_reduced = js.data.JaxSimModelData.build( model=model_reduced,