diff --git a/src/abstractenv.jl b/src/abstractenv.jl
index 97d8c7b..44ac680 100644
--- a/src/abstractenv.jl
+++ b/src/abstractenv.jl
@@ -20,6 +20,21 @@ function action_space end
 function observation_space end
 function render end
 
+"""
+    Observation(;reward, terminal, state, meta...)
+
+The observation of an environment from the perspective of an agent.
+
+# Keywords & Fields
+
+- `reward`: the reward of the agent at the current step.
+- `terminal`: whether the environment has terminated.
+- `state`: the current state of the environment from the perspective of the agent.
+- `meta`: any extra information, such as `legal_actions`.
+
+!!! note
+    The `reward` and `terminal` of the first observation, taken before any interaction with the environment, may not be valid.
+"""
 struct Observation{R,T,S,M<:NamedTuple}
     reward::R
     terminal::T
diff --git a/src/environments/atari.jl b/src/environments/atari.jl
index 16827cf..9279588 100644
--- a/src/environments/atari.jl
+++ b/src/environments/atari.jl
@@ -1,93 +1,136 @@
-using ArcadeLearningEnvironment, GR
+using ArcadeLearningEnvironment, GR, Random
 
 export AtariEnv
 
-mutable struct AtariEnv{To,F} <: AbstractEnv
+mutable struct AtariEnv{IsGrayScale, TerminalOnLifeLoss, N, S<:AbstractRNG} <: AbstractEnv
     ale::Ptr{Nothing}
-    screen::Array{UInt8,1}
-    getscreen!::F
-    actions::Array{Int64,1}
+    screens::Tuple{Array{UInt8, N}, Array{UInt8, N}} # two buffers, so the last two frames can be max-pooled
+    actions::Vector{Int64}
     action_space::DiscreteSpace{Int}
-    observation_space::To
+    observation_space::MultiDiscreteSpace{UInt8, N}
     noopmax::Int
+    frame_skip::Int
     reward::Float32
+    lives::Int
+    seed::S # Julia-side RNG; the ALE keeps its own internal seed
 end
 
 """
-    AtariEnv(name; colorspace = "Grayscale", frame_skip = 4, noopmax = 20,
-                   color_averaging = true, repeat_action_probability = 0.)
-Returns an AtariEnv that can be used in an RLSetup of the
-[ReinforcementLearning](https://github.com/jbrea/ReinforcementLearning.jl)
-package. Check the deps/roms folder of the ArcadeLearningEnvironment package to
-see all available `name`s.
+    AtariEnv(;kwargs...)
+
+This implementation follows the guidelines in [Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents](https://arxiv.org/abs/1709.06009).
+
+TODO: support `seed!` in single/multi-threaded settings.
+
+# Keywords
+
+- `name::String="pong"`: name of the Atari environment. Use `getROMList` to list all supported environments.
+- `grayscale_obs::Bool=true`: if `true`, a grayscale observation is returned; otherwise, an RGB observation is returned.
+- `noop_max::Int=30`: maximum number of no-op actions performed after a reset.
+- `frame_skip::Int=4`: the number of frames each action is repeated for; the agent only observes every `frame_skip`-th frame, with the last two frames max-pooled.
+- `terminal_on_life_loss::Bool=false`: if `true`, the episode terminates whenever a life is lost.
+- `repeat_action_probability::Float64=0.`: probability of repeating the previous action instead of the chosen one (sticky actions).
+- `color_averaging::Bool=false`: whether to perform phosphor averaging.
+- `max_num_frames_per_episode::Int=0`: maximum episode length in frames; `0` means no limit.
+- `full_action_space::Bool=false`: by default, only the minimal action set is used. If `true`, the full legal action set is used; call `legal_actions` to retrieve it. (TODO)
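+
+# Example
+
+A minimal usage sketch. The seed pair is arbitrary: its first entry seeds the
+Julia-side RNG, the second one the underlying ALE.
+
+```julia
+env = AtariEnv(; name = "pong", seed = (123, 456))
+reset!(env)
+interact!(env, rand(action_space(env)))
+obs = observe(env)
+obs.reward, obs.terminal
+```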
+
+See also the [python implementation](https://github.com/openai/gym/blob/c072172d64bdcd74313d97395436c592dc836d5c/gym/wrappers/atari_preprocessing.py#L8-L36).
 """
 function AtariEnv(
-    name;
-    colorspace = "Grayscale",
+    ;name = "pong",
+    grayscale_obs=true,
+    noop_max = 30,
     frame_skip = 4,
-    noopmax = 20,
-    color_averaging = true,
-    actionset = :minimal,
-    repeat_action_probability = 0.,
+    terminal_on_life_loss=false,
+    repeat_action_probability=0.,
+    color_averaging=false,
+    max_num_frames_per_episode=0,
+    full_action_space=false,
+    seed=nothing
 )
+    frame_skip > 0 || throw(ArgumentError("frame_skip must be greater than 0!"))
+    name in getROMList() || throw(ArgumentError("unknown ROM name! Run `getROMList()` to see all the game names."))
+
+    if isnothing(seed)
+        seed = (MersenneTwister(), 0)
+    elseif seed isa Tuple{Int, Int}
+        seed = (MersenneTwister(seed[1]), seed[2])
+    else
+        throw(ArgumentError("seed must be a tuple of two integers, one for the Julia wrapper and one for the internal ALE implementation")) # TODO: maybe auto-generate the two seeds from one
+    end
+
     ale = ALE_new()
-    setBool(ale, "color_averaging", color_averaging)
-    setInt(ale, "frame_skip", Int32(frame_skip))
+    setInt(ale, "random_seed", seed[2])
+    setInt(ale, "frame_skip", Int32(1)) # !!! do not use the ALE's internal frame_skip: the last two frames must be max-pooled, so frame skipping is implemented manually in `interact!`
+    setInt(ale, "max_num_frames_per_episode", max_num_frames_per_episode)
     setFloat(ale, "repeat_action_probability", Float32(repeat_action_probability))
+    setBool(ale, "color_averaging", color_averaging)
     loadROM(ale, name)
-    observation_length = getScreenWidth(ale) * getScreenHeight(ale)
-    if colorspace == "Grayscale"
-        screen = Array{Cuchar}(undef, observation_length)
-        getscreen! = ArcadeLearningEnvironment.getScreenGrayscale!
-        observation_space = MultiDiscreteSpace(
-            fill(typemax(Cuchar), observation_length),
-            fill(typemin(Cuchar), observation_length),
-        )
-    elseif colorspace == "RGB"
-        screen = Array{Cuchar}(undef, 3 * observation_length)
-        getscreen! = ArcadeLearningEnvironment.getScreenRGB!
-        observation_space = MultiDiscreteSpace(
-            fill(typemax(Cuchar), 3 * observation_length),
-            fill(typemin(Cuchar), 3 * observation_length),
-        )
-    elseif colorspace == "Raw"
-        screen = Array{Cuchar}(undef, observation_length)
-        getscreen! = ArcadeLearningEnvironment.getScreen!
-        observation_space = MultiDiscreteSpace(
-            fill(typemax(Cuchar), observation_length),
-            fill(typemin(Cuchar), observation_length),
-        )
-    end
-    actions = actionset == :minimal ? getMinimalActionSet(ale) : getLegalActionSet(ale)
+
+    observation_size = grayscale_obs ? (getScreenWidth(ale), getScreenHeight(ale)) : (3, getScreenWidth(ale), getScreenHeight(ale)) # !!! note the order: the channel dimension comes first for RGB
+    observation_space = MultiDiscreteSpace(
+        fill(typemax(Cuchar), observation_size),
+        fill(typemin(Cuchar), observation_size),
+    )
+
+    actions = full_action_space ? getLegalActionSet(ale) : getMinimalActionSet(ale)
     action_space = DiscreteSpace(length(actions))
-    AtariEnv(
+    screens = (
+        fill(typemin(Cuchar), observation_size),
+        fill(typemin(Cuchar), observation_size),
+    )
+
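+    # The flags are lifted into type parameters so that `update_screen!` and
+    # `is_terminal` can dispatch on them instead of branching at run time:
+    # IsGrayScale selects the screen grabber, TerminalOnLifeLoss selects the
+    # termination rule, and N is the number of observation dimensions
+    # (2 for grayscale, 3 for RGB).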
+    AtariEnv{grayscale_obs, terminal_on_life_loss, grayscale_obs ? 2 : 3, typeof(seed[1])}(
         ale,
-        screen,
-        getscreen!,
+        screens,
         actions,
         action_space,
         observation_space,
-        noopmax,
+        noop_max,
+        frame_skip,
         0.0f0,
+        lives(ale),
+        seed[1]
     )
 end
 
-function interact!(env::AtariEnv, a)
-    env.reward = act(env.ale, env.actions[a])
-    env.getscreen!(env.ale, env.screen)
+update_screen!(env::AtariEnv{true}, screen) = ArcadeLearningEnvironment.getScreenGrayscale!(env.ale, vec(screen))
+update_screen!(env::AtariEnv{false}, screen) = ArcadeLearningEnvironment.getScreenRGB!(env.ale, vec(screen))
+
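+# With `frame_skip = n`, a single `interact!` call repeats the action for `n`
+# raw frames and max-pools the last two of them; the "frame_skip" testset in
+# test/atari.jl checks exactly this identity. Max-pooling compensates for
+# flickering in games that draw some sprites only on every other frame.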
+function interact!(env::AtariEnv{is_gray_scale, is_terminal_on_life_loss}, a) where {is_gray_scale, is_terminal_on_life_loss}
+    r = 0.0f0
+
+    for i in 1:env.frame_skip
+        r += act(env.ale, env.actions[a])
+        if i == env.frame_skip
+            update_screen!(env, env.screens[1])
+        elseif i == env.frame_skip - 1
+            update_screen!(env, env.screens[2])
+        end
+    end
+
+    # max-pool the two most recent frames
+    if env.frame_skip > 1
+        env.screens[1] .= max.(env.screens[1], env.screens[2])
+    end
+
+    env.reward = r
     nothing
 end
 
-observe(env::AtariEnv) =
-    Observation(reward = env.reward, terminal = game_over(env.ale), state = env.screen)
+is_terminal(env::AtariEnv{<:Any, true}) = game_over(env.ale) || (lives(env.ale) < env.lives)
+is_terminal(env::AtariEnv{<:Any, false}) = game_over(env.ale)
+
+observe(env::AtariEnv) = Observation(reward = env.reward, terminal = is_terminal(env), state = env.screens[1])
 
 function reset!(env::AtariEnv)
     reset_game(env.ale)
-    for _ = 1:rand(0:env.noopmax)
+    for _ = 1:rand(env.seed, 0:env.noopmax) # env.seed is the Julia-side RNG
         act(env.ale, Int32(0))
     end
-    env.getscreen!(env.ale, env.screen)
+    update_screen!(env, env.screens[1]) # no need to update env.screens[2]
     env.reward = 0.0f0 # dummy
+    env.lives = lives(env.ale)
     nothing
 end
diff --git a/test/atari.jl b/test/atari.jl
new file mode 100644
index 0000000..310309c
--- /dev/null
+++ b/test/atari.jl
@@ -0,0 +1,75 @@
+@testset "atari" begin
+    @testset "seed" begin
+        env = AtariEnv(;name="pong", seed=(123,456))
+        old_states = []
+        actions = [rand(action_space(env)) for i in 1:10, j in 1:100]
+
+        for i in 1:10
+            for j in 1:100
+                interact!(env, actions[i, j])
+                push!(old_states, copy(observe(env).state))
+            end
+            reset!(env)
+        end
+
+        env = AtariEnv(;name="pong", seed=(123,456))
+        new_states = []
+        for i in 1:10
+            for j in 1:100
+                interact!(env, actions[i, j])
+                push!(new_states, copy(observe(env).state))
+            end
+            reset!(env)
+        end
+
+        @test old_states == new_states
+    end
+
+    @testset "frame_skip" begin
+        env = AtariEnv(;name="pong", frame_skip=4, seed=(123,456))
+        states = []
+        actions = [rand(action_space(env)) for _ in 1:100]
+
+        for i in 1:100
+            interact!(env, actions[i])
+            push!(states, copy(observe(env).state))
+        end
+
+        env = AtariEnv(;name="pong", frame_skip=1, seed=(123,456))
+        for i in 1:100
+            interact!(env, actions[i])
+            interact!(env, actions[i])
+            interact!(env, actions[i])
+            s1 = copy(observe(env).state)
+            interact!(env, actions[i])
+            s2 = copy(observe(env).state)
+            @test states[i] == max.(s1, s2)
+        end
+    end
+
+    @testset "repeat_action_probability" begin
+        env = AtariEnv(;name="pong", repeat_action_probability=1.0, seed=(123,456))
+        states = []
+        actions = [rand(action_space(env)) for _ in 1:100]
+        for i in 1:100
+            interact!(env, actions[i])
+            push!(states, copy(observe(env).state))
+        end
+
+        env = AtariEnv(;name="pong", repeat_action_probability=1.0, seed=(123,456))
+        for i in 1:100
+            interact!(env, actions[1])
+            @test states[i] == observe(env).state
+        end
+    end
+
+    @testset "max_num_frames_per_episode" begin
+        for i in 1:10
+            env = AtariEnv(;name="pong", max_num_frames_per_episode=i, seed=(123,456))
+            for _ in 1:i
+                interact!(env, 1)
+            end
+            @test observe(env).terminal
+        end
+    end
+end
\ No newline at end of file
diff --git a/test/environments.jl b/test/environments.jl
index 914fca3..83330cf 100644
--- a/test/environments.jl
+++ b/test/environments.jl
@@ -62,7 +62,7 @@
     :(deterministic_tree_MDP_with_rand_reward()),
     :(deterministic_tree_MDP()),
     :(deterministic_MDP()),
-    (:(AtariEnv($x)) for x in atari_env_names)...,
+    (:(AtariEnv(;name=$x)) for x in atari_env_names)...,
     (:(GymEnv($x)) for x in gym_env_names)...,
 ]
 
diff --git a/test/runtests.jl b/test/runtests.jl
index 8687c0d..b18c533 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -10,5 +10,5 @@ using Hanabi
 
     include("spaces.jl")
     include("environments.jl")
-
+    include("atari.jl")
 end
\ No newline at end of file