Skip to content

Commit

Permalink
removed AbstractMarkovEnv and AbstractZeroSumEnv to fix #43
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Jan 2, 2021
1 parent c4d14e7 commit dcfffb6
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 66 deletions.
10 changes: 1 addition & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ This package is designed for two reasons:

## Required Interface

To accomplish this, there are two abstract environment types:
- `AbstractMarkovEnv`, which represents a (PO)MDP with a single player
- `AbstractZeroSumEnv`, which represents a two-player zero sum game

`AbstractEnv` is a base type for all environments.

The interface has five required functions for all `AbstractEnv`s:
Expand All @@ -26,11 +22,6 @@ act!(env, a) # steps the environment forward and returns a reward
terminated(env) # returns true or false indicating whether the environment has finished
```

For `AbstractZeroSumEnv`, there is an additional required function,
```julia
player(env) # returns the index of the current player
```

## Optional Interface

There are several additional functions that are currently optional:
Expand All @@ -41,6 +32,7 @@ There are several additional functions that are currently optional:
- `valid_actions`
- `valid_action_mask`
- `observations`
- `player`

To see documentation for one of these functions, use [Julia's built-in help system](https://docs.julialang.org/en/v1/manual/documentation/index.html#Accessing-Documentation-1).

Expand Down
2 changes: 1 addition & 1 deletion examples/gridworld.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import ColorSchemes

const RL = CommonRLInterface

mutable struct GridWorld <: AbstractMarkovEnv
mutable struct GridWorld <: AbstractEnv
size::SVector{2, Int}
rewards::Dict{SVector{2, Int}, Float64}
state::SVector{2, Int}
Expand Down
63 changes: 12 additions & 51 deletions src/CommonRLInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ using MacroTools

export
AbstractEnv,
AbstractMarkovEnv,
AbstractZeroSumEnv,
reset!,
actions,
observe,
Expand All @@ -15,35 +13,6 @@ export

abstract type AbstractEnv end

"""
An environment that represents an MDP or a POMPD. There is only one player in this environment.
The required interface consists of
```
reset!(env)
actions(env)
observe(env)
act!(env, a)
terminated(env)
```
"""
abstract type AbstractMarkovEnv <: AbstractEnv end

"""
An environment that represents a zero-sum two player game, where only one player plays at a time. Player 1 seeks to maximize the reward; player 2 seeks to minimize it.
The required interface consists of
```
reset!(env)
actions(env)
player(env)
observe(env)
act!(env, a)
terminated(env)
```
"""
abstract type AbstractZeroSumEnv <: AbstractEnv end

"""
reset!(env::AbstractEnv)
Expand All @@ -64,7 +33,7 @@ This function is a *static property* of the environment; the value it returns sh
---
actions(env::AbstractZeroSumEnv, i::Integer)
actions(env::AbstractEnv, i::Integer)
Return a collection of all the actions available to player i.
Expand All @@ -75,33 +44,25 @@ function actions end
"""
observe(env::AbstractEnv)
Return an observation from the environment.
Return an observation from the environment for the current player.
This is a *required function* that must be provided by every AbstractEnv.
---
observe(env::AbstractZeroSumEnv)
Return an observation from the environment for the current player.
"""
function observe end

"""
r = act!(env::AbstractEnv, a)
Take action `a` and advance AbstractEnv `env` forward one step.
Take action `a` for the current player, advance AbstractEnv `env` forward one step, and return rewards for all players.
This is a *required function* that must be provided by every AbstractEnv and should return a reward, Boolean done signal, and any extra information for debugging or human understanding (typically in a NamedTuple).
This is a *required function* that must be provided by every AbstractEnv.
"""
function act! end

"""
player(env::AbstractZeroSumEnv)
player(env::AbstractEnv)
Return the index of the player who should play next in the environment.
This is a *required function* for all `AbstractZeroSumEnvs`.
"""
function player end

Expand Down Expand Up @@ -144,7 +105,6 @@ provided(::typeof(reset!), ::Type{<:Tuple{AbstractEnv}}) = true
provided(::typeof(actions), ::Type{<:Tuple{AbstractEnv}}) = true
provided(::typeof(observe), ::Type{<:Tuple{AbstractEnv}}) = true
provided(::typeof(act!), ::Type{<:Tuple{AbstractEnv, Any}}) = true
provided(::typeof(player), ::Type{<:Tuple{AbstractZeroSumEnv}}) = true

"""
@provide f(x::X) = x^2
Expand All @@ -157,15 +117,16 @@ This will automatically implement the appropriate methods of `provided`. Both th
```jldoctest
using CommonRLInterface
struct MyEnv <: AbstractEnv end
struct MyEnv <: AbstractEnv
s::Int
end
@assert provided(clone, MyEnv()) == false
@assert provided(clone, MyEnv(1)) == false
@provide function CommonRLInterface.clone(env::MyEnv)
return deepcopy(env)
end
@provide CommonRLInterface.clone(env::MyEnv) = MyEnv(env.s)
@assert provided(clone, MyEnv()) == true
@assert provided(clone, MyEnv(1)) == true
@assert clone(MyEnv(1)) == MyEnv(1)
"""
macro provide(f)
def = splitdef(f) # TODO: probably give a better error message that mentions @provide if this fails
Expand Down
10 changes: 5 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ end
@show rsum
end

# a reference MarkovEnv
mutable struct MyEnv <: AbstractMarkovEnv
# a reference MDP Env
mutable struct MyEnv <: AbstractEnv
state::Int
end
MyEnv() = MyEnv(1)
Expand All @@ -50,8 +50,8 @@ function CommonRLInterface.act!(env::MyEnv, a)
end
env = MyEnv(1)

# a reference ZeroSumEnv
mutable struct MyGame <: AbstractZeroSumEnv
# a reference Game Env
mutable struct MyGame <: AbstractEnv
state::Int
end
MyGame() = MyGame(1)
Expand All @@ -65,7 +65,7 @@ function CommonRLInterface.act!(env::MyGame, a)
env.state = clamp(env.state + a, 1, 10)
return -o^2
end
CommonRLInterface.player(env::MyGame) = 1 + iseven(env.state)
@provide CommonRLInterface.player(env::MyGame) = 1 + iseven(env.state)
game = MyGame()

function f end
Expand Down

0 comments on commit dcfffb6

Please sign in to comment.