
update atari env to use max-pooling (#22)
* update atari env

* add test cases for atari environments

* add doc
findmyway authored Nov 2, 2019
1 parent 033cf4c commit 8c21b67
Showing 5 changed files with 190 additions and 57 deletions.
15 changes: 15 additions & 0 deletions src/abstractenv.jl
@@ -20,6 +20,21 @@ function action_space end
function observation_space end
function render end

"""
Observation(;reward, terminal, state, meta...)
The observation of an environment from the perspective of an agent.
# Keywords & Fields
- `reward`: the reward of an agent
- `terminal`: indicates that if the environment is terminated or not.
- `state`: the current state of the environment from the perspective of an agent
- `meta`: some other information, like `legal_actions`...
!!! note
The `reward` and `terminal` of the first observation before interacting with an environment may not be valid.
"""
struct Observation{R,T,S,M<:NamedTuple}
reward::R
terminal::T
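
For orientation before the Atari changes below, a minimal sketch of how an agent loop might consume these fields. `observe`, `interact!`, `reset!`, and `action_space` follow the API used in this repository; the environment choice and the loop itself are illustrative:

    env = AtariEnv(;name="pong", seed=(123, 456))
    reset!(env)                                  # first reward/terminal may not be valid, per the note above
    total_reward = 0.0f0
    obs = observe(env)
    while !obs.terminal
        interact!(env, rand(action_space(env)))  # take a random legal action
        obs = observe(env)
        total_reward += obs.reward               # reward of the last interaction
    end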
153 changes: 98 additions & 55 deletions src/environments/atari.jl
@@ -1,93 +1,136 @@
-using ArcadeLearningEnvironment, GR
+using ArcadeLearningEnvironment, GR, Random

export AtariEnv

-mutable struct AtariEnv{To,F} <: AbstractEnv
+mutable struct AtariEnv{IsGrayScale, TerminalOnLifeLoss, N, S<:AbstractRNG} <: AbstractEnv
ale::Ptr{Nothing}
-    screen::Array{UInt8,1}
-    getscreen!::F
-    actions::Array{Int64,1}
+    screens::Tuple{Array{UInt8, N}, Array{UInt8, N}} # for max-pooling
+    actions::Vector{Int64}
action_space::DiscreteSpace{Int}
-    observation_space::To
+    observation_space::MultiDiscreteSpace{UInt8, N}
noopmax::Int
+    frame_skip::Int
+    reward::Float32
+    lives::Int
+    seed::S
end

"""
-    AtariEnv(name; colorspace = "Grayscale", frame_skip = 4, noopmax = 20,
-             color_averaging = true, repeat_action_probability = 0.)
-Returns an AtariEnv that can be used in an RLSetup of the
-[ReinforcementLearning](https://github.com/jbrea/ReinforcementLearning.jl)
-package. Check the deps/roms folder of the ArcadeLearningEnvironment package to
-see all available `name`s.
+    AtariEnv(;kwargs...)
+This implementation follows the guidelines in [Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents](https://arxiv.org/abs/1709.06009).
+TODO: support seed! in single/multi thread
+# Keywords
+- `name::String="pong"`: name of the Atari environment. Use `getROMList` to list all supported environments.
+- `grayscale_obs::Bool=true`: if `true`, a grayscale observation is returned; otherwise, an RGB observation is returned.
+- `noop_max::Int=30`: max number of no-ops.
+- `frame_skip::Int=4`: the frequency at which the agent experiences the game.
+- `terminal_on_life_loss::Bool=false`: if `true`, the game is over whenever a life is lost.
+- `repeat_action_probability::Float64=0.`
+- `color_averaging::Bool=false`: whether to perform phosphor averaging or not.
+- `max_num_frames_per_episode::Int=0`
+- `full_action_space::Bool=false`: by default, only the minimal action set is used. If `true`, one needs to call `legal_actions` to get the valid action set. TODO
+See also the [python implementation](https://github.com/openai/gym/blob/c072172d64bdcd74313d97395436c592dc836d5c/gym/wrappers/atari_preprocessing.py#L8-L36).
"""
function AtariEnv(
-    name;
-    colorspace = "Grayscale",
+    ;name = "pong",
+    grayscale_obs=true,
+    noop_max = 30,
    frame_skip = 4,
-    noopmax = 20,
-    color_averaging = true,
-    actionset = :minimal,
-    repeat_action_probability = 0.,
+    terminal_on_life_loss=false,
+    repeat_action_probability=0.,
+    color_averaging=false,
+    max_num_frames_per_episode=0,
+    full_action_space=false,
+    seed=nothing
)
+    frame_skip > 0 || throw(ArgumentError("frame_skip must be greater than 0!"))
+    name in getROMList() || throw(ArgumentError("unknown ROM name! Run `getROMList()` to see all the game names."))
+
+    if isnothing(seed)
+        seed = (MersenneTwister(), 0)
+    elseif seed isa Tuple{Int, Int}
+        seed = (MersenneTwister(seed[1]), seed[2])
+    else
+        @error "You must specify two seeds: one for the Julia wrapper, one for the internal C implementation" # ??? maybe auto-generate two seeds from one
+    end

ale = ALE_new()
-    setBool(ale, "color_averaging", color_averaging)
-    setInt(ale, "frame_skip", Int32(frame_skip))
+    setInt(ale, "random_seed", seed[2])
+    setInt(ale, "frame_skip", Int32(1)) # !!! do not use the ALE's internal frame_skip here: we need to max-pool the latest two frames, so the frame-skip mechanism is implemented manually below.
+    setInt(ale, "max_num_frames_per_episode", max_num_frames_per_episode)
    setFloat(ale, "repeat_action_probability", Float32(repeat_action_probability))
+    setBool(ale, "color_averaging", color_averaging)
loadROM(ale, name)
-    observation_length = getScreenWidth(ale) * getScreenHeight(ale)
-    if colorspace == "Grayscale"
-        screen = Array{Cuchar}(undef, observation_length)
-        getscreen! = ArcadeLearningEnvironment.getScreenGrayscale!
-        observation_space = MultiDiscreteSpace(
-            fill(typemax(Cuchar), observation_length),
-            fill(typemin(Cuchar), observation_length),
-        )
-    elseif colorspace == "RGB"
-        screen = Array{Cuchar}(undef, 3 * observation_length)
-        getscreen! = ArcadeLearningEnvironment.getScreenRGB!
-        observation_space = MultiDiscreteSpace(
-            fill(typemax(Cuchar), 3 * observation_length),
-            fill(typemin(Cuchar), 3 * observation_length),
-        )
-    elseif colorspace == "Raw"
-        screen = Array{Cuchar}(undef, observation_length)
-        getscreen! = ArcadeLearningEnvironment.getScreen!
-        observation_space = MultiDiscreteSpace(
-            fill(typemax(Cuchar), observation_length),
-            fill(typemin(Cuchar), observation_length),
-        )
-    end
-    actions = actionset == :minimal ? getMinimalActionSet(ale) : getLegalActionSet(ale)

+    observation_size = grayscale_obs ? (getScreenWidth(ale), getScreenHeight(ale)) : (3, getScreenWidth(ale), getScreenHeight(ale)) # !!! note the order
+    observation_space = MultiDiscreteSpace(
+        fill(typemax(Cuchar), observation_size),
+        fill(typemin(Cuchar), observation_size),
+    )
+
+    actions = full_action_space ? getLegalActionSet(ale) : getMinimalActionSet(ale)
action_space = DiscreteSpace(length(actions))
-    AtariEnv(
+    screens = (
+        fill(typemin(Cuchar), observation_size),
+        fill(typemin(Cuchar), observation_size),
+    )
+
+    AtariEnv{grayscale_obs, terminal_on_life_loss, grayscale_obs ? 2 : 3, typeof(seed[1])}(
ale,
-        screen,
-        getscreen!,
+        screens,
actions,
action_space,
observation_space,
-        noopmax,
+        noop_max,
+        frame_skip,
+        0.0f0,
+        lives(ale),
+        seed[1]
)
end
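
A hedged usage sketch of the new keyword-only constructor (game names and seed values are illustrative; `"pong"` and `"breakout"` are standard ALE ROMs):

    # grayscale (default) observations with deterministic seeding
    env = AtariEnv(;name="pong", seed=(123, 456))

    # RGB observations, and treat a life loss as episode termination
    env_rgb = AtariEnv(;name="breakout", grayscale_obs=false, terminal_on_life_loss=true, seed=(123, 456))

    size(observe(env).state)      # (width, height) for grayscale
    size(observe(env_rgb).state)  # (3, width, height); note the channel-first order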

-function interact!(env::AtariEnv, a)
-    env.reward = act(env.ale, env.actions[a])
-    env.getscreen!(env.ale, env.screen)
+update_screen!(env::AtariEnv{true}, screen) = ArcadeLearningEnvironment.getScreenGrayscale!(env.ale, vec(screen))
+update_screen!(env::AtariEnv{false}, screen) = ArcadeLearningEnvironment.getScreenRGB!(env.ale, vec(screen))
+
+function interact!(env::AtariEnv{is_gray_scale, is_terminal_on_life_loss}, a) where {is_gray_scale, is_terminal_on_life_loss}
+    r = 0.0f0
+
+    for i in 1:env.frame_skip
+        r += act(env.ale, env.actions[a])
+        if i == env.frame_skip
+            update_screen!(env, env.screens[1])
+        elseif i == env.frame_skip - 1
+            update_screen!(env, env.screens[2])
+        end
+    end
+
+    # max-pooling
+    if env.frame_skip > 1
+        env.screens[1] .= max.(env.screens[1], env.screens[2])
+    end
+
+    env.reward = r
+    nothing
end
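
Why max-pool at all? Some Atari games draw certain sprites only on alternate frames, so a single raw frame can miss objects; taking the element-wise `max` of the last two frames of a skip keeps them visible. The `frame_skip` test below verifies exactly this semantics; a condensed version of that check (seed and action choice are illustrative):

    env4 = AtariEnv(;name="pong", frame_skip=4, seed=(123, 456))
    env1 = AtariEnv(;name="pong", frame_skip=1, seed=(123, 456))
    a = rand(action_space(env4))
    interact!(env4, a)               # one step = 4 raw frames, max-pooled at the end
    interact!(env1, a); interact!(env1, a); interact!(env1, a)
    s3 = copy(observe(env1).state)   # 3rd raw frame
    interact!(env1, a)
    s4 = observe(env1).state         # 4th raw frame
    @assert observe(env4).state == max.(s3, s4)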

-observe(env::AtariEnv) =
-    Observation(reward = env.reward, terminal = game_over(env.ale), state = env.screen)
+is_terminal(env::AtariEnv{<:Any, true}) = game_over(env.ale) || (lives(env.ale) < env.lives)
+is_terminal(env::AtariEnv{<:Any, false}) = game_over(env.ale)
+
+observe(env::AtariEnv) = Observation(reward = env.reward, terminal = is_terminal(env), state = env.screens[1])

function reset!(env::AtariEnv)
reset_game(env.ale)
-    for _ = 1:rand(0:env.noopmax)
+    for _ = 1:rand(env.seed, 0:env.noopmax)
act(env.ale, Int32(0))
end
-    env.getscreen!(env.ale, env.screen)
+    update_screen!(env, env.screens[1]) # no need to update env.screens[2]
+    env.reward = 0.0f0 # dummy
+    env.lives = lives(env.ale)
nothing
end
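
Note that the number of start-of-episode no-ops is now drawn from the environment's own RNG (`env.seed`) rather than the global one, which is what makes seeded runs reproducible; the `seed` test below relies on this. A minimal sketch of the guarantee (values illustrative):

    env1 = AtariEnv(;name="pong", seed=(123, 456))
    env2 = AtariEnv(;name="pong", seed=(123, 456))
    reset!(env1); reset!(env2)                  # same no-op count, same ALE seed
    @assert observe(env1).state == observe(env2).state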

75 changes: 75 additions & 0 deletions test/atari.jl
@@ -0,0 +1,75 @@
@testset "atari" begin
@testset "seed" begin
env = AtariEnv(;name="pong", seed=(123,456))
old_states = []
actions = [rand(action_space(env)) for i in 1:10, j in 1:100]

for i in 1:10
for j in 1:100
interact!(env, actions[i, j])
push!(old_states, copy(observe(env).state))
end
reset!(env)
end

env = AtariEnv(;name="pong", seed=(123,456))
new_states = []
for i in 1:10
for j in 1:100
interact!(env, actions[i, j])
push!(new_states, copy(observe(env).state))
end
reset!(env)
end

@test old_states == new_states
end

@testset "frame_skip" begin
env = AtariEnv(;name="pong", frame_skip=4, seed=(123,456))
states = []
actions = [rand(action_space(env)) for _ in 1:100]

for i in 1:100
interact!(env, actions[i])
push!(states, copy(observe(env).state))
end

env = AtariEnv(;name="pong", frame_skip=1, seed=(123,456))
for i in 1:100
interact!(env, actions[i])
interact!(env, actions[i])
interact!(env, actions[i])
s1 = copy(observe(env).state)
interact!(env, actions[i])
s2 = copy(observe(env).state)
@test states[i] == max.(s1, s2)
end
end

@testset "repeat_action_probability" begin
env = AtariEnv(;name="pong", repeat_action_probability=1.0, seed=(123,456))
states = []
actions = [rand(action_space(env)) for _ in 1:100]
for i in 1:100
interact!(env, actions[i])
push!(states, copy(observe(env).state))
end

env = AtariEnv(;name="pong", repeat_action_probability=1.0, seed=(123,456))
for i in 1:100
interact!(env, actions[1])
@test states[i] == observe(env).state
end
end

@testset "max_num_frames_per_episode" begin
for i in 1:10
env = AtariEnv(;name="pong", max_num_frames_per_episode=i, seed=(123,456))
for _ in 1:i
interact!(env, 1)
end
@test true == observe(env).terminal
end
end
end
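
These cases are wired into the suite via the `runtests.jl` change below; assuming the package's standard `Pkg` setup (the package name here is inferred from the repository and is not shown in this diff), the whole suite runs with:

    using Pkg
    Pkg.test("ReinforcementLearningEnvironments")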
2 changes: 1 addition & 1 deletion test/environments.jl
@@ -62,7 +62,7 @@
:(deterministic_tree_MDP_with_rand_reward()),
:(deterministic_tree_MDP()),
:(deterministic_MDP()),
-(:(AtariEnv($x)) for x in atari_env_names)...,
+(:(AtariEnv(;name=$x)) for x in atari_env_names)...,
(:(GymEnv($x)) for x in gym_env_names)...,
]

2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -10,5 +10,5 @@ using Hanabi

include("spaces.jl")
include("environments.jl")
-
+include("atari.jl")
end
