
Commit 6ad681f: make non-sep works

LovelyBuggies committed Aug 30, 2022
1 parent 62488ce

Showing 3 changed files with 6 additions and 7 deletions.
MFG_VI.py (2 changes: 1 addition & 1 deletion)

@@ -10,7 +10,7 @@


 if __name__ == '__main__':
-    data = pd.read_csv('data_rho_lwr_new.csv')
+    data = pd.read_csv('data_rho_non_sep.csv')
     rho = data.iloc[:, 1:len(data.iloc[0, :])]
     d = np.array(data['0.1'])
     rho = np.array(rho)
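The swap above just points the driver script at the non-separable density data. For reference, the indexing implies the following CSV layout; this is our inference from the code, not documented in the repository:

import numpy as np
import pandas as pd

# Assumed layout of data_rho_non_sep.csv (inferred, hypothetical):
#   column 0          -> row index written out by pandas
#   columns 1 onward  -> density values rho over the space-time grid
#   the column '0.1'  -> demand vector d
data = pd.read_csv('data_rho_non_sep.csv')
rho = np.array(data.iloc[:, 1:len(data.iloc[0, :])])  # drop the index column
d = np.array(data['0.1'])                             # demand vector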
value_iteration.py (2 changes: 1 addition & 1 deletion)

@@ -13,7 +13,7 @@ def value_iteration(n_cell, T_terminal, rho, fake=False):
             rho_i_t = rho[i, t]
             u[i, t] = 1 - rho_i_t
             u[i, t] = 0.5 if fake else min(max(u[i, t], 0), 1)
-            V[i, t] = delta_T * 0.5 * (1 - u[i, t] - rho_i_t) ** 2 + (1 - u[i, t]) * V[i, t + 1] + u[i, t] * V[i + 1, t + 1]
+            V[i, t] = delta_T * (0.5 * u[i, t] ** 2 + rho_i_t * u[i, t] - u[i, t]) + (1 - u[i, t]) * V[i, t + 1] + u[i, t] * V[i + 1, t + 1]

     V[-1, :] = V[0, :].copy()
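This hunk is the substance of the commit: the running cost changes from the separable LWR form delta_T * 0.5 * (1 - u - rho) ** 2 to the non-separable form delta_T * (0.5 * u ** 2 + rho * u - u). Expanding the old form gives 0.5 * u ** 2 + rho * u - u + 0.5 * (1 - rho) ** 2, so the new cost drops only the control-independent term 0.5 * (1 - rho) ** 2; that is why the greedy control u = 1 - rho, clipped to [0, 1], on the unchanged lines above still applies. A minimal standalone sketch of the backup, assuming a zero terminal condition and a periodic road; the function names and loop structure here are ours, not the repository's:

import numpy as np

def non_sep_cost(u, rho, delta_T):
    # Running cost introduced in this commit: delta_T * (0.5*u^2 + rho*u - u)
    return delta_T * (0.5 * u ** 2 + rho * u - u)

def value_iteration_sketch(n_cell, T_terminal, rho):
    delta_T = 1 / n_cell
    T = int(T_terminal / delta_T)
    u = np.zeros((n_cell, T))
    V = np.zeros((n_cell + 1, T + 1))  # terminal column stays zero
    for t in range(T - 1, -1, -1):     # sweep backward in time
        for i in range(n_cell):
            # Greedy control, clipped to [0, 1]
            u[i, t] = min(max(1 - rho[i, t], 0), 1)
            # Bellman backup: running cost plus a convex combination of
            # staying in cell i and advancing to cell i + 1
            V[i, t] = (non_sep_cost(u[i, t], rho[i, t], delta_T)
                       + (1 - u[i, t]) * V[i, t + 1]
                       + u[i, t] * V[i + 1, t + 1])
        # Periodic boundary, as in V[-1, :] = V[0, :].copy() above; applied
        # per time step here so the wrapped value feeds the next sweep
        V[-1, t] = V[0, t]
    return u, V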
value_iteration_ddpg.py (9 changes: 4 additions & 5 deletions)

@@ -86,7 +86,6 @@ def train_ddpg(rho, d, iterations):
     T_terminal = int(rho.shape[1] / rho.shape[0])
     delta_T = 1 / n_cell
     T = int(T_terminal / delta_T)
-    V = np.ones((n_cell + 1, T + 1))
     states = list()
     rhos = list()
     rho_hist = list()
@@ -118,10 +117,10 @@ def train_ddpg(rho, d, iterations):
             speed = float(actor.forward(np.array([i, t])))
             if t != T:
                 if i != n_cell:
-                    truths.append(delta_T * 0.5 * (1 - rho[i, t] - speed) ** 2 + fake_critic(
+                    truths.append(delta_T * (0.5 * speed ** 2 + rho[i, t] * speed - speed) + fake_critic(
                         np.array([i + speed, t + 1])))
                 else:
-                    truths.append(delta_T * 0.5 * (1 - rho[0, t] - speed) ** 2 + fake_critic(
+                    truths.append(delta_T * (0.5 * speed ** 2 + rho[0, t] * speed - speed) + fake_critic(
                         np.array([speed, t + 1])))
             else:
                 truths.append(0)
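The same non-separable cost now feeds the bootstrapped critic targets: each target is the one-step cost at the actor's speed plus the frozen fake_critic's value at the next state, and the i == n_cell branch wraps the position back onto the ring (the next x is speed rather than i + speed). A condensed sketch of one target, reusing the hypothetical non_sep_cost helper from above:

import numpy as np

# One bootstrapped target for grid point (i, t); fake_critic is the frozen
# copy of the critic, as in the diff. Hypothetical helper, for illustration.
def target_sketch(i, t, T, n_cell, rho, speed, delta_T, fake_critic):
    if t == T:
        return 0.0                        # terminal states carry no cost-to-go
    wrap = (i == n_cell)                  # right edge wraps onto the ring
    r = rho[0, t] if wrap else rho[i, t]
    next_x = speed if wrap else i + speed
    return non_sep_cost(speed, r, delta_T) + float(fake_critic(np.array([next_x, t + 1])))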
@@ -151,7 +150,7 @@ def train_ddpg(rho, d, iterations):
         next_states = np.append(next_xs, next_ts, axis=1)
         interp_V_next_state = (torch.ones((n_cell * T, 1)) - speeds) * critic.forward(
             next_states_1) + speeds * critic.forward(next_states_2)
-        advantages = delta_T * 0.5 * (1 - rhos - speeds) ** 2 + critic.forward(next_states) - critic(states)
+        advantages = delta_T * (0.5 * speeds ** 2 + rhos * speeds - speeds) + interp_V_next_state - critic(states)
         policy_loss = advantages.mean()
         if a_it % 5 == 0:
             # print(max(critic.forward(next_states) - interp_V_next_state))
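Two things change on the advantages line. First, the running cost is the non-separable one, matching the value-iteration backup. Second, the bootstrap now uses interp_V_next_state, the critic linearly interpolated between the two integer-anchored next states (weight 1 - speed on staying at cell i, weight speed on cell i + 1), instead of a direct critic.forward(next_states) query at a fractional position; this keeps the actor's objective consistent with the convex-combination form of the grid backup. Schematically, with the batched running cost as the first term:

advantage(i, t) = delta_T * (0.5*u^2 + rho*u - u)
                  + [(1 - u) * V(i, t + 1) + u * V(i + 1, t + 1)]
                  - V(i, t)

and policy_loss = advantages.mean() is what the actor is trained to drive down.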
@@ -168,7 +167,7 @@ def train_ddpg(rho, d, iterations):
             if i < n_cell and t < T:
                 u_new[i, t] = actor(np.array([i, t]))

-            V_new[i, t] = V[i, t]
+            V_new[i, t] = critic(np.array([i, t]))

         rho_hist.append(get_rho_from_u(u_new, d))
         if a_it % 50 == 0 and a_it != 0:
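This pairs with the deletion of the V table in the first hunk: the reported value surface is now read directly off the trained critic rather than a stale ones-initialized array. A sketch of how the grids are presumably filled after training; the names follow the diff, and the float casts are ours:

# Fill the control and value grids from the trained networks; u_new exists
# only on interior points, V_new covers the full (n_cell+1) x (T+1) grid.
for i in range(n_cell + 1):
    for t in range(T + 1):
        if i < n_cell and t < T:
            u_new[i, t] = float(actor(np.array([i, t])))
        V_new[i, t] = float(critic(np.array([i, t])))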
