Improve acados MPC formulation #182

Closed · wants to merge 15 commits
@@ -30,4 +30,6 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [1, 1, 1, 1] # Match LQR weights
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
2 changes: 2 additions & 0 deletions examples/lqr/config_overrides/cartpole/cartpole_tracking.yaml
@@ -33,4 +33,6 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [1, 0.1, 0.1, 0.1] # Match LQR weights
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
@@ -42,4 +42,6 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [1, 1, 1, 1, 1, 1] # Match LQR weights
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
@@ -45,4 +45,6 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [1, 0.1, 1, 0.1, 0.1, 0.1] # Match LQR weights
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
@@ -68,4 +68,7 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  # Match LQR weights
+  rew_state_weight: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
@@ -71,4 +71,7 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  # Match LQR weights
+  rew_state_weight: [1, 0.1, 1, 0.1, 1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
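These overrides only add weight vectors; for context, here is a minimal sketch (illustrative names, not the library's internals) of how the benchmark's quadratic cost presumably consumes them, which is why matching them to q_lqr/r_lqr makes the environment's reported cost agree with the LQR objective.

import numpy as np

# Minimal sketch, assuming the env turns the two YAML weight vectors into
# diagonal matrices and charges an LQR-style quadratic stage cost.
rew_state_weight = [1, 0.1, 0.1, 0.1]  # cartpole tracking override above
rew_act_weight = [0.1]

Q = np.diag(rew_state_weight)
R = np.diag(rew_act_weight)

def stage_cost(x, u, x_goal):
    '''Quadratic stage cost: e^T Q e + u^T R u.'''
    e = np.asarray(x) - np.asarray(x_goal)
    u = np.asarray(u)
    return float(e @ Q @ e + u @ R @ u)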
22 changes: 4 additions & 18 deletions examples/lqr/lqr_experiment.py
@@ -15,11 +15,12 @@
 from safe_control_gym.utils.registration import make
 
 
-def run(gui=True, n_episodes=1, n_steps=None, save_data=False):
+def run(gui=False, plot=True, n_episodes=1, n_steps=None, save_data=False):
     '''The main function running LQR and iLQR experiments.
 
     Args:
-        gui (bool): Whether to display the gui and plot graphs.
+        gui (bool): Whether to display the gui.
+        plot (bool): Whether to plot graphs.
         n_episodes (int): The number of episodes to execute.
         n_steps (int): The total number of steps to execute.
         save_data (bool): Whether to save the collected experiment data.
@@ -61,7 +62,7 @@ def run(gui=True, n_episodes=1, n_steps=None, save_data=False):
     else:
         trajs_data, _ = experiment.run_evaluation(training=True, n_steps=n_steps)
 
-    if gui:
+    if plot:
         post_analysis(trajs_data['obs'][0], trajs_data['action'][0], ctrl.env)
 
     # Close environments
@@ -132,20 +133,5 @@ def post_analysis(state_stack, input_stack, env):
     plt.show()
 
 
-def wrap2pi_vec(angle_vec):
-    '''Wraps a vector of angles between -pi and pi.
-
-    Args:
-        angle_vec (ndarray): A vector of angles.
-    '''
-    for k, angle in enumerate(angle_vec):
-        while angle > np.pi:
-            angle -= np.pi
-        while angle <= -np.pi:
-            angle += np.pi
-        angle_vec[k] = angle
-    return angle_vec
-
-
 if __name__ == '__main__':
     run()
@@ -30,6 +30,8 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [5.0, 0.1, 5.0, 0.1] # Match MPC weights
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
 
   constraints:
2 changes: 2 additions & 0 deletions examples/mpc/config_overrides/cartpole/cartpole_tracking.yaml
@@ -33,6 +33,8 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [5.0, 0.1, 5.0, 0.1] # Match MPC weights
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
 
   constraints:
@@ -5,9 +5,9 @@ algo_config:
   - 0.1
   - 0.1
   q_mpc:
-  - 1.0
+  - 5.0
   - 0.1
-  - 1.0
+  - 5.0
   - 0.1
   - 0.1
   - 0.1
@@ -42,6 +42,8 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [5.0, 0.1, 5.0, 0.1, 0.1, 0.1] # Match MPC weights
+  rew_act_weight: [0.1, 0.1]
   done_on_out_of_bound: True
 
   constraints:
@@ -45,6 +45,8 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [5.0, 0.1, 5.0, 0.1, 0.1, 0.1] # Match MPC weights
+  rew_act_weight: [0.1, 0.1]
   done_on_out_of_bound: True
 
   constraints:
@@ -7,11 +7,11 @@ algo_config:
   - 0.1
   - 0.1
   q_mpc:
-  - 1.0
+  - 5.0
   - 0.1
-  - 1.0
+  - 5.0
   - 0.1
-  - 1.0
+  - 5.0
   - 0.1
   - 0.1
   - 0.1
@@ -68,6 +68,9 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  # Match MPC weights
+  rew_state_weight: [5.0, 0.1, 5.0, 0.1, 5.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
+  rew_act_weight: [0.1, 0.1, 0.1, 0.1]
   done_on_out_of_bound: True
   constraints:
   - constraint_form: default_constraint
@@ -70,6 +70,9 @@ task_config:
   proj_normal: [0, 1, 1]
   episode_len_sec: 6
   cost: quadratic
+  # Match MPC weights
+  rew_state_weight: [5.0, 0.1, 5.0, 0.1, 5.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
+  rew_act_weight: [0.1, 0.1, 0.1, 0.1]
   done_on_out_of_bound: True
   constraints:
   - constraint_form: default_constraint
61 changes: 11 additions & 50 deletions examples/mpc/mpc_experiment.py
@@ -2,7 +2,6 @@
 
 import os
 import pickle
-from collections import defaultdict
 from functools import partial
 
 import matplotlib.pyplot as plt
@@ -15,11 +14,12 @@
 from safe_control_gym.utils.registration import make
 
 
-def run(gui=True, n_episodes=1, n_steps=None, save_data=False):
+def run(gui=False, plot=True, n_episodes=1, n_steps=None, save_data=False):
     '''The main function running MPC and Linear MPC experiments.
 
     Args:
-        gui (bool): Whether to display the gui and plot graphs.
+        gui (bool): Whether to display the gui.
+        plot (bool): Whether to plot graphs.
         n_episodes (int): The number of episodes to execute.
         n_steps (int): The total number of steps to execute.
         save_data (bool): Whether to save the collected experiment data.
@@ -34,51 +34,27 @@ def run(gui=True, n_episodes=1, n_steps=None, save_data=False):
                        config.task,
                        **config.task_config
                        )
-    random_env = env_func(gui=False)
+    env = env_func(gui=gui)
 
     # Create controller.
     ctrl = make(config.algo,
                 env_func,
                 **config.algo_config
                 )
 
-    all_trajs = defaultdict(list)
-    n_episodes = 1 if n_episodes is None else n_episodes
-
     # Run the experiment.
-    for _ in range(n_episodes):
-        # Get initial state and create environments
-        init_state, _ = random_env.reset()
-        static_env = env_func(gui=gui, randomized_init=False, init_state=init_state)
-        static_train_env = env_func(gui=False, randomized_init=False, init_state=init_state)
-
-        # Create experiment, train, and run evaluation
-        experiment = BaseExperiment(env=static_env, ctrl=ctrl, train_env=static_train_env)
-        experiment.launch_training()
-
-        if n_steps is None:
-            trajs_data, _ = experiment.run_evaluation(training=True, n_episodes=1)
-        else:
-            trajs_data, _ = experiment.run_evaluation(training=True, n_steps=n_steps)
-
-        if gui:
-            post_analysis(trajs_data['obs'][0], trajs_data['action'][0], ctrl.env)
-
-        # Close environments
-        static_env.close()
-        static_train_env.close()
+    experiment = BaseExperiment(env=env, ctrl=ctrl)
+    trajs_data, metrics = experiment.run_evaluation(training=True, n_episodes=n_episodes, n_steps=n_steps)
 
-    # Merge in new trajectory data
-    for key, value in trajs_data.items():
-        all_trajs[key] += value
+    if plot:
+        for i in range(len(trajs_data['obs'])):
+            post_analysis(trajs_data['obs'][i], trajs_data['action'][i], ctrl.env)
 
     ctrl.close()
-    random_env.close()
-    metrics = experiment.compute_metrics(all_trajs)
-    all_trajs = dict(all_trajs)
+    env.close()
 
     if save_data:
-        results = {'trajs_data': all_trajs, 'metrics': metrics}
+        results = {'trajs_data': trajs_data, 'metrics': metrics}
         path_dir = os.path.dirname('./temp-data/')
         os.makedirs(path_dir, exist_ok=True)
         with open(f'./temp-data/{config.algo}_data_{config.task}_{config.task_config.task}.pkl', 'wb') as file:
@@ -132,20 +108,5 @@ def post_analysis(state_stack, input_stack, env):
     plt.show()
 
 
-def wrap2pi_vec(angle_vec):
-    '''Wraps a vector of angles between -pi and pi.
-
-    Args:
-        angle_vec (ndarray): A vector of angles.
-    '''
-    for k, angle in enumerate(angle_vec):
-        while angle > np.pi:
-            angle -= np.pi
-        while angle <= -np.pi:
-            angle += np.pi
-        angle_vec[k] = angle
-    return angle_vec
-
-
 if __name__ == '__main__':
     run()
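With the per-episode loop of static environments gone, run_evaluation now handles episode batching itself. A hypothetical invocation of the refactored entry point (argument names come from the new signature above; the import path is illustrative):

# Hypothetical usage, assuming mpc_experiment.py is importable as a module.
from mpc_experiment import run

# Headless batch evaluation: no PyBullet GUI, keep the matplotlib
# post-analysis plots, run three episodes, pickle results to ./temp-data/.
run(gui=False, plot=True, n_episodes=3, save_data=True)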
1 change: 0 additions & 1 deletion safe_control_gym/controllers/lqr/ilqr.py
@@ -64,7 +64,6 @@ def __init__(
         self.model = self.get_prior(self.env)
         self.Q = get_cost_weight_matrix(self.q_lqr, self.model.nx)
         self.R = get_cost_weight_matrix(self.r_lqr, self.model.nu)
-        self.env.set_cost_function_param(self.Q, self.R)
 
         self.gain = compute_lqr_gain(self.model, self.model.X_EQ, self.model.U_EQ,
                                      self.Q, self.R, self.discrete_dynamics)
1 change: 0 additions & 1 deletion safe_control_gym/controllers/lqr/lqr.py
@@ -33,7 +33,6 @@ def __init__(
         self.discrete_dynamics = discrete_dynamics
         self.Q = get_cost_weight_matrix(q_lqr, self.model.nx)
         self.R = get_cost_weight_matrix(r_lqr, self.model.nu)
-        self.env.set_cost_function_param(self.Q, self.R)
 
         self.gain = compute_lqr_gain(self.model, self.model.X_EQ, self.model.U_EQ,
                                      self.Q, self.R, self.discrete_dynamics)
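Dropping self.env.set_cost_function_param(self.Q, self.R) from both constructors (and the matching call in mpc.py below) presumably moves cost alignment from runtime mutation into configuration: the environment keeps the quadratic weights set by rew_state_weight / rew_act_weight in the task_config overrides above, so the reported evaluation cost no longer depends on which controller was instantiated. A hypothetical consistency check under that assumption:

import numpy as np

# Hypothetical check (not repo code): with the mutation removed, agreement
# between controller and environment cost is now purely a config property.
q_lqr = [1, 0.1, 0.1, 0.1]             # algo_config weights
rew_state_weight = [1, 0.1, 0.1, 0.1]  # task_config weights (overrides above)
assert np.allclose(np.diag(q_lqr), np.diag(rew_state_weight)), \
    'controller and environment quadratic costs would disagree'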
5 changes: 1 addition & 4 deletions safe_control_gym/controllers/mpc/mpc.py
@@ -454,10 +454,7 @@ def run(self,
 
         self.x_prev = None
         self.u_prev = None
-        if not env.initial_reset:
-            env.set_cost_function_param(self.Q, self.R)
-        # obs, info = env.reset()
-        obs = env.reset()
+        obs, info = env.reset()
         print('Init State:')
         print(obs)
         ep_returns, ep_lengths = [], []
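The reset change tracks the benchmark environment's two-value reset. A minimal sketch of the contract the new line assumes (illustrative stub, not the benchmark env):

class StubEnv:
    '''Stub with a gym-style two-value reset, as the new code expects.'''
    def reset(self):
        obs = [0.0, 0.0, 0.0, 0.0]  # initial state
        info = {}                   # auxiliary reset information
        return obs, info

obs, info = StubEnv().reset()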