From dcfffb69d3a86872f813c2cd79cf926eecdd9db8 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Sat, 2 Jan 2021 15:32:57 -0700 Subject: [PATCH] removed AbstractMarkovEnv and AbstractZeroSumEnv to fix #43 --- README.md | 10 +------ examples/gridworld.jl | 2 +- src/CommonRLInterface.jl | 63 ++++++++-------------------------------- test/runtests.jl | 10 +++---- 4 files changed, 19 insertions(+), 66 deletions(-) diff --git a/README.md b/README.md index 6f8ee9e..3e045ab 100644 --- a/README.md +++ b/README.md @@ -11,10 +11,6 @@ This package is designed for two reasons: ## Required Interface -To accomplish this, there are two abstract environment types: -- `AbstractMarkovEnv`, which represents a (PO)MDP with a single player -- `AbstractZeroSumEnv`, which represents a two-player zero sum game - `AbstractEnv` is a base type for all environments. The interface has five required functions for all `AbstractEnv`s: @@ -26,11 +22,6 @@ act!(env, a) # steps the environment forward and returns a reward terminated(env) # returns true or false indicating whether the environment has finished ``` -For `AbstractZeroSumEnv`, there is an additional required function, -```julia -player(env) # returns the index of the current player -``` - ## Optional Interface There are several additional functions that are currently optional: @@ -41,6 +32,7 @@ There are several additional functions that are currently optional: - `valid_actions` - `valid_action_mask` - `observations` +- `player` To see documentation for one of these functions, use [Julia's built-in help system](https://docs.julialang.org/en/v1/manual/documentation/index.html#Accessing-Documentation-1). diff --git a/examples/gridworld.jl b/examples/gridworld.jl index 56b336f..59eafa2 100644 --- a/examples/gridworld.jl +++ b/examples/gridworld.jl @@ -8,7 +8,7 @@ import ColorSchemes const RL = CommonRLInterface -mutable struct GridWorld <: AbstractMarkovEnv +mutable struct GridWorld <: AbstractEnv size::SVector{2, Int} rewards::Dict{SVector{2, Int}, Float64} state::SVector{2, Int} diff --git a/src/CommonRLInterface.jl b/src/CommonRLInterface.jl index e68f8aa..965550c 100644 --- a/src/CommonRLInterface.jl +++ b/src/CommonRLInterface.jl @@ -4,8 +4,6 @@ using MacroTools export AbstractEnv, - AbstractMarkovEnv, - AbstractZeroSumEnv, reset!, actions, observe, @@ -15,35 +13,6 @@ export abstract type AbstractEnv end -""" -An environment that represents an MDP or a POMPD. There is only one player in this environment. - -The required interface consists of -``` -reset!(env) -actions(env) -observe(env) -act!(env, a) -terminated(env) -``` -""" -abstract type AbstractMarkovEnv <: AbstractEnv end - -""" -An environment that represents a zero-sum two player game, where only one player plays at a time. Player 1 seeks to maximize the reward; player 2 seeks to minimize it. - -The required interface consists of -``` -reset!(env) -actions(env) -player(env) -observe(env) -act!(env, a) -terminated(env) -``` -""" -abstract type AbstractZeroSumEnv <: AbstractEnv end - """ reset!(env::AbstractEnv) @@ -64,7 +33,7 @@ This function is a *static property* of the environment; the value it returns sh --- - actions(env::AbstractZeroSumEnv, i::Integer) + actions(env::AbstractEnv, i::Integer) Return a collection of all the actions available to player i. @@ -75,33 +44,25 @@ function actions end """ observe(env::AbstractEnv) -Return an observation from the environment. +Return an observation from the environment for the current player. This is a *required function* that must be provided by every AbstractEnv. - ---- - - observe(env::AbstractZeroSumEnv) - -Return an observation from the environment for the current player. """ function observe end """ r = act!(env::AbstractEnv, a) -Take action `a` and advance AbstractEnv `env` forward one step. +Take action `a` for the current player, advance AbstractEnv `env` forward one step, and return rewards for all players. -This is a *required function* that must be provided by every AbstractEnv and should return a reward, Boolean done signal, and any extra information for debugging or human understanding (typically in a NamedTuple). +This is a *required function* that must be provided by every AbstractEnv. """ function act! end """ - player(env::AbstractZeroSumEnv) + player(env::AbstractEnv) Return the index of the player who should play next in the environment. - -This is a *required function* for all `AbstractZeroSumEnvs`. """ function player end @@ -144,7 +105,6 @@ provided(::typeof(reset!), ::Type{<:Tuple{AbstractEnv}}) = true provided(::typeof(actions), ::Type{<:Tuple{AbstractEnv}}) = true provided(::typeof(observe), ::Type{<:Tuple{AbstractEnv}}) = true provided(::typeof(act!), ::Type{<:Tuple{AbstractEnv, Any}}) = true -provided(::typeof(player), ::Type{<:Tuple{AbstractZeroSumEnv}}) = true """ @provide f(x::X) = x^2 @@ -157,15 +117,16 @@ This will automatically implement the appropriate methods of `provided`. Both th ```jldoctest using CommonRLInterface -struct MyEnv <: AbstractEnv end +struct MyEnv <: AbstractEnv + s::Int +end -@assert provided(clone, MyEnv()) == false +@assert provided(clone, MyEnv(1)) == false -@provide function CommonRLInterface.clone(env::MyEnv) - return deepcopy(env) -end +@provide CommonRLInterface.clone(env::MyEnv) = MyEnv(env.s) -@assert provided(clone, MyEnv()) == true +@assert provided(clone, MyEnv(1)) == true +@assert clone(MyEnv(1)) == MyEnv(1) """ macro provide(f) def = splitdef(f) # TODO: probably give a better error message that mentions @provide if this fails diff --git a/test/runtests.jl b/test/runtests.jl index 84f1f0d..710852c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,8 +33,8 @@ end @show rsum end -# a reference MarkovEnv -mutable struct MyEnv <: AbstractMarkovEnv +# a reference MDP Env +mutable struct MyEnv <: AbstractEnv state::Int end MyEnv() = MyEnv(1) @@ -50,8 +50,8 @@ function CommonRLInterface.act!(env::MyEnv, a) end env = MyEnv(1) -# a reference ZeroSumEnv -mutable struct MyGame <: AbstractZeroSumEnv +# a reference Game Env +mutable struct MyGame <: AbstractEnv state::Int end MyGame() = MyGame(1) @@ -65,7 +65,7 @@ function CommonRLInterface.act!(env::MyGame, a) env.state = clamp(env.state + a, 1, 10) return -o^2 end -CommonRLInterface.player(env::MyGame) = 1 + iseven(env.state) +@provide CommonRLInterface.player(env::MyGame) = 1 + iseven(env.state) game = MyGame() function f end