Skip to content
This repository has been archived by the owner on May 6, 2021. It is now read-only.

Commit

Permalink
Fixes to Pendulum.jl and some standardizing of classical control envs (
Browse files Browse the repository at this point in the history
…#65)

* Add discrete, fix angle_normalize and reset! in Pendulum

* set actions and rename _interact! to _step! in mountain_car

* rename tau to dt in cartpole

* Fix docstring in Pendulum

* Remove duplicate clamp in Pendulum

* Properly set action and time in mountain_car
  • Loading branch information
AlexLewandowski authored Jun 4, 2020
1 parent e7cf316 commit 5dcada4
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 20 deletions.
15 changes: 9 additions & 6 deletions src/environments/classic_control/cartpole.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct CartPoleEnvParams{T}
halflength::T
polemasslength::T
forcemag::T
tau::T
dt::T
thetathreshold::T
xthreshold::T
max_steps::Int
Expand Down Expand Up @@ -46,6 +46,7 @@ Base.show(io::IO, env::CartPoleEnv{T}) where {T} =
- `halflength = T(0.5)`
- `forcemag = T(10.0)`
- `max_steps = 200`
- `dt = 0.02`
- `seed = nothing`
"""
function CartPoleEnv(;
Expand All @@ -56,6 +57,7 @@ function CartPoleEnv(;
halflength = 0.5,
forcemag = 10.0,
max_steps = 200,
dt = 0.02,
seed = nothing,
)
params = CartPoleEnvParams{T}(
Expand All @@ -66,7 +68,7 @@ function CartPoleEnv(;
halflength,
masspole * halflength,
forcemag,
0.02,
dt,
2 * 12 * π / 360,
2.4,
max_steps,
Expand Down Expand Up @@ -100,6 +102,7 @@ RLBase.observe(env::CartPoleEnv{T}) where {T} =
(reward = env.done ? zero(T) : one(T), terminal = env.done, state = env.state)

function (env::CartPoleEnv)(a)
@assert a in (1, 2)
env.action = a
env.t += 1
force = a == 2 ? env.params.forcemag : -env.params.forcemag
Expand All @@ -113,10 +116,10 @@ function (env::CartPoleEnv)(a)
(4 / 3 - env.params.masspole * costheta^2 / env.params.totalmass)
)
xacc = tmp - env.params.polemasslength * thetaacc * costheta / env.params.totalmass
env.state[1] += env.params.tau * xdot
env.state[2] += env.params.tau * xacc
env.state[3] += env.params.tau * thetadot
env.state[4] += env.params.tau * thetaacc
env.state[1] += env.params.dt * xdot
env.state[2] += env.params.dt * xacc
env.state[3] += env.params.dt * thetadot
env.state[4] += env.params.dt * thetaacc
env.done =
abs(env.state[1]) > env.params.xthreshold ||
abs(env.state[3]) > env.params.thetathreshold ||
Expand Down
21 changes: 15 additions & 6 deletions src/environments/classic_control/mountain_car.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ mutable struct MountainCarEnv{A,T,R<:AbstractRNG} <: AbstractEnv
action_space::A
observation_space::MultiContinuousSpace{Vector{T}}
state::Vector{T}
action::Int
action::Union{Int,AbstractFloat}
done::Bool
t::Int
rng::R
Expand Down Expand Up @@ -70,15 +70,16 @@ function MountainCarEnv(; T = Float64, continuous = false, seed = nothing, kwarg
else
params = MountainCarEnvParams(; T = T, kwargs...)
end
action_space = continuous ? ContinuousSpace(-T(1.0), T(1.0)) : DiscreteSpace(3)
env = MountainCarEnv(
params,
continuous ? ContinuousSpace(-T(1.0), T(1.0)) : DiscreteSpace(3),
action_space,
MultiContinuousSpace(
[params.min_pos, -params.max_speed],
[params.max_pos, params.max_speed],
),
zeros(T, 2),
1,
rand(action_space),
false,
0,
MersenneTwister(seed),
Expand All @@ -102,11 +103,19 @@ function RLBase.reset!(env::MountainCarEnv{A,T}) where {A,T}
nothing
end

(env::MountainCarEnv{<:ContinuousSpace})(a) = _interact!(env, min(max(a, -1, 1)))
# Step a continuous-action mountain car: the action must belong to the
# environment's action space, is recorded on the environment, and is then
# forwarded unchanged to `_step!` as the applied force.
function (env::MountainCarEnv{<:ContinuousSpace})(a::AbstractFloat)
    @assert a in env.action_space
    env.action = a
    force = a
    return _step!(env, force)
end

(env::MountainCarEnv{<:DiscreteSpace})(a) = _interact!(env, a - 2)
# Step a discrete-action mountain car: the action index must belong to the
# environment's action space and is recorded on the environment. The index is
# shifted by 2 before being handed to `_step!` as the force.
function (env::MountainCarEnv{<:DiscreteSpace})(a::Int)
    @assert a in env.action_space
    env.action = a
    force = a - 2
    return _step!(env, force)
end

function _interact!(env::MountainCarEnv, force)
function _step!(env::MountainCarEnv, force)
env.t += 1
x, v = env.state
v += force * env.params.power + cos(3 * x) * (-env.params.gravity)
Expand Down
39 changes: 31 additions & 8 deletions src/environments/classic_control/pendulum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,21 @@ struct PendulumEnvParams{T}
max_steps::Int
end

mutable struct PendulumEnv{T,R<:AbstractRNG} <: AbstractEnv
mutable struct PendulumEnv{A,T,R<:AbstractRNG} <: AbstractEnv
params::PendulumEnvParams{T}
action_space::ContinuousSpace
action_space::A
observation_space::MultiContinuousSpace{Vector{T}}
state::Vector{T}
done::Bool
t::Int
rng::R
reward::T
n_actions::Int
action::Union{Int,AbstractFloat}
end

"""
PwendulumEnv(;kwargs...)
PendulumEnv(;kwargs...)
# Keyword arguments
Expand All @@ -36,6 +38,8 @@ end
- `l = T(1)`
- `dt = T(0.05)`
- `max_steps = 200`
- `continuous::Bool = true`
- `n_actions::Int = 3`
- `seed = nothing`
"""
function PendulumEnv(;
Expand All @@ -47,18 +51,23 @@ function PendulumEnv(;
l = T(1),
dt = T(0.05),
max_steps = 200,
continuous::Bool = true,
n_actions::Int = 3,
seed = nothing,
)
high = T.([1, 1, max_speed])
action_space = continuous ? ContinuousSpace(-2.0, 2.0) : DiscreteSpace(n_actions)
env = PendulumEnv(
PendulumEnvParams(max_speed, max_torque, g, m, l, dt, max_steps),
ContinuousSpace(-2.0, 2.0),
action_space,
MultiContinuousSpace(-high, high),
zeros(T, 2),
false,
0,
MersenneTwister(seed),
zero(T),
n_actions,
rand(action_space),
)
reset!(env)
env
Expand All @@ -67,20 +76,34 @@ end
# Re-seed the random number generator stored in `env.rng` with `seed`.
function Random.seed!(env::PendulumEnv, seed)
    return Random.seed!(env.rng, seed)
end

"""
    pendulum_observation(s)

Build the observation vector `[cos(s[1]), sin(s[1]), s[2]]` from the state `s`,
whose first entry is used as an angle and whose second entry is passed through.
"""
function pendulum_observation(s)
    th = s[1]
    thdot = s[2]
    return [cos(th), sin(th), thdot]
end
angle_normalize(x) = ((x + pi) % (2 * pi)) - pi
"""
    angle_normalize(x)

Wrap the angle `x` (in radians) into the interval starting at `-π` with length `2π`.

The period is built in the floating-point type of `x + π`, so a `Float32` angle
produces a `Float32` result instead of being promoted to `Float64`. Results for
`Float64` and integer inputs are computed exactly as before.
"""
function angle_normalize(x)
    shifted = x + Base.π
    # `2 * π` alone is always Float64; converting π to the type of `shifted`
    # first keeps the whole computation in the caller's precision.
    period = 2 * oftype(shifted, Base.π)
    return Base.mod(shifted, period) - Base.π
end

# Snapshot of the environment as a named tuple: the stored reward, the
# observation derived from the raw state via `pendulum_observation`, and the
# termination flag. Field order is (reward, state, terminal).
function RLBase.observe(env::PendulumEnv)
    obs = pendulum_observation(env.state)
    return (reward = env.reward, state = obs, terminal = env.done)
end

function RLBase.reset!(env::PendulumEnv{T}) where {T}
env.state[:] = 2 * rand(env.rng, T, 2) .- 1
# Start a new episode: draw a fresh random state from `env.rng`, then clear the
# step counter, termination flag and reward. Returns `nothing`.
function RLBase.reset!(env::PendulumEnv{A,T}) where {A,T}
    # `2 * rand - 1` is uniform on [-1, 1), i.e. symmetric about zero. The earlier
    # `2 * (rand - 1)` form was confined to [-2, 0), so the pendulum always started
    # with a negative angular velocity.
    env.state[1] = π * (2 * rand(env.rng, T) - 1)  # angle in [-π, π)
    env.state[2] = 2 * rand(env.rng, T) - 1        # angular velocity in [-1, 1)
    env.t = 0
    env.done = false
    env.reward = zero(T)
    return nothing
end

function (env::PendulumEnv)(a)
# Step a continuous-action pendulum: the action must belong to the
# environment's action space, is recorded on the environment, and is then
# forwarded unchanged to `_step!`.
function (env::PendulumEnv{<:ContinuousSpace})(a::AbstractFloat)
    @assert a in env.action_space
    env.action = a
    torque = a
    return _step!(env, torque)
end

# Step a discrete-action pendulum: the action index must belong to the
# environment's action space and is recorded on the environment. The index is
# then mapped linearly to a float (for `n_actions = 3`, indices 1, 2, 3 become
# -2, 0, 2) which is passed to `_step!`.
function (env::PendulumEnv{<:DiscreteSpace})(a::Int)
    @assert a in env.action_space
    env.action = a
    span = env.n_actions - 1
    torque = (4 / span) * (a - span / 2 - 1)
    return _step!(env, torque)
end

function _step!(env::PendulumEnv, a)
env.t += 1
th, thdot = env.state
a = clamp(a, -env.params.max_torque, env.params.max_torque)
Expand Down

0 comments on commit 5dcada4

Please sign in to comment.