Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Version of Frozen Lake based on the implementation #186

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/envs/envs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ include("snake.jl")
include("catcher.jl")
include("transport_undirected.jl")
include("transport_directed.jl")
include("frozen_lake_undirected.jl")

const ENVS = [
SingleRoomUndirectedModule.SingleRoomUndirected,
Expand All @@ -46,4 +47,5 @@ const ENVS = [
CatcherModule.Catcher,
TransportUndirectedModule.TransportUndirected,
TransportDirectedModule.TransportDirected,
FrozenLakeUndirectedModule.FrozenLakeUndirected
]
214 changes: 214 additions & 0 deletions src/envs/frozen_lake_undirected.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
module FrozenLakeUndirectedModule

import ..GridWorlds as GW
import Random
import ReinforcementLearningBase as RLBase

#####
##### game logic
#####

# Index of each object's boolean layer along the first dimension of `tile_map`.
const AGENT, WALL, GOAL, OBSTACLE = 1, 2, 3, 4

# Number of object layers allocated in `tile_map` (one per index above).
const NUM_OBJECTS = 4
LooseTerrifyingSpaceMonkey marked this conversation as resolved.
Show resolved Hide resolved
# Number of discrete actions; valid actions are the integers `1:NUM_ACTIONS`.
const NUM_ACTIONS = 4

"""
    FrozenLakeUndirected{R, RNG} <: GW.AbstractGridWorld

Grid world in which the agent starts at tile `(2, 2)` and has to reach the goal at
the opposite interior corner. Stepping onto the goal or onto an obstacle ends the
episode.

# Fields
- `tile_map`: `BitArray{3}` indexed as `[object, row, column]`; it is two cells larger
  than the playable area in each direction because of the surrounding wall.
- `map_name`: name of the preset layout (`"4x4"` or `"8x8"`), or `nothing` when the
  obstacles were sampled randomly.
- `agent_position`: current tile of the agent.
- `reward`: reward produced by the most recent action.
- `rng`: random number generator supplied at construction.
- `done`: whether the episode has terminated.
- `terminal_reward`: reward assigned when the agent reaches the goal.
- `terminal_penalty`: reward assigned when the agent steps onto an obstacle.
- `goal_position`: tile of the goal.
- `num_obstacles`: number of obstacle tiles.
- `obstacle_positions`: tiles occupied by obstacles.
- `is_slippery`: when `true`, an action may result in a move perpendicular to the
  requested one.
"""
mutable struct FrozenLakeUndirected{R, RNG} <: GW.AbstractGridWorld
    tile_map::BitArray{3}
    map_name::Union{Nothing, String}
    agent_position::CartesianIndex{2}
    reward::R
    rng::RNG
    done::Bool
    terminal_reward::R
    terminal_penalty::R
    goal_position::CartesianIndex{2}
    num_obstacles::Int
    obstacle_positions::Vector{CartesianIndex{2}}
    is_slippery::Bool
end

# Return `(height, width, obstacle_positions)` for a preset map name.
# Positions are tile-map coordinates, i.e. already shifted by one for the wall border.
function _preset_layout(map_name)
    if map_name == "4x4"
        return 4, 4, [
            CartesianIndex(3, 3), CartesianIndex(3, 5),
            CartesianIndex(4, 5), CartesianIndex(5, 2),
        ]
    elseif map_name == "8x8"
        return 8, 8, [
            CartesianIndex(4, 5), CartesianIndex(5, 7), CartesianIndex(6, 5),
            CartesianIndex(7, 3), CartesianIndex(7, 4), CartesianIndex(7, 8),
            CartesianIndex(8, 3), CartesianIndex(8, 6), CartesianIndex(8, 8),
            CartesianIndex(9, 5),
        ]
    else
        throw(ArgumentError(
            "unknown map_name $(repr(map_name)); expected \"4x4\", \"8x8\" or nothing",
        ))
    end
end

"""
    FrozenLakeUndirected(; map_name = nothing, R = Float32, height = 8, width = 8,
                         num_obstacles = floor(Int, sqrt(height * width) / 2),
                         rng = Random.GLOBAL_RNG, is_slippery = true)

Construct a frozen-lake environment.

# Keywords
- `map_name`: `"4x4"` or `"8x8"` selects a fixed layout; in that case `height`, `width`
  and `num_obstacles` are taken from the preset and the supplied values are ignored.
  `nothing` (the default) samples `num_obstacles` obstacle tiles using `rng`.
- `R`: numeric type of the rewards.
- `height`, `width`: size of the playable area, excluding the wall border.
- `num_obstacles`: number of obstacles to sample when no preset is selected.
- `rng`: random number generator stored in the environment and used for sampling.
- `is_slippery`: whether moves may slip to a perpendicular direction.

# Throws
- `ArgumentError` if `map_name` is neither `nothing`, `"4x4"` nor `"8x8"`.
"""
function FrozenLakeUndirected(;
    map_name = nothing,
    R = Float32,
    height = 8,
    width = 8,
    num_obstacles = floor(Int, sqrt(height * width) / 2),
    rng = Random.GLOBAL_RNG,
    is_slippery = true,
)
    use_preset = map_name !== nothing
    if use_preset
        height, width, obstacle_positions = _preset_layout(map_name)
        num_obstacles = length(obstacle_positions)
    else
        # The struct field is an `Int`; accept any integer-valued input.
        num_obstacles = Int(num_obstacles)
    end

    tile_map = falses(NUM_OBJECTS, height + 2, width + 2)

    # One-cell wall around the playable area.
    tile_map[WALL, 1, :] .= true
    tile_map[WALL, height + 2, :] .= true
    tile_map[WALL, :, 1] .= true
    tile_map[WALL, :, width + 2] .= true

    agent_position = CartesianIndex(2, 2)
    tile_map[AGENT, agent_position] = true

    goal_position = CartesianIndex(height + 1, width + 1)
    tile_map[GOAL, goal_position] = true

    if use_preset
        update_obstacles_on_map(tile_map, obstacle_positions)
    else
        obstacle_positions = Vector{CartesianIndex{2}}(undef, num_obstacles)
        for i in eachindex(obstacle_positions)
            position = GW.sample_empty_position(rng, tile_map)
            # Mark the tile right away so that later draws see it as occupied and
            # cannot pick the same tile twice.
            # NOTE(review): assumes `GW.sample_empty_position` skips tiles that
            # already hold an object — confirm against its implementation.
            tile_map[OBSTACLE, position] = true
            obstacle_positions[i] = position
        end
    end

    @debug "FrozenLakeUndirected layout" obstacle_positions height width

    reward = zero(R)
    done = false
    terminal_reward = one(R)
    terminal_penalty = -one(R)

    stored_map_name = use_preset ? String(map_name) : nothing

    env = FrozenLakeUndirected(
        tile_map,
        stored_map_name,
        agent_position,
        reward,
        rng,
        done,
        terminal_reward,
        terminal_penalty,
        goal_position,
        num_obstacles,
        obstacle_positions,
        is_slippery,
    )

    return env
end

# Set the OBSTACLE layer of `tile_map` to `true` at every position in
# `obstacle_positions`. The map is modified in place and also returned.
function update_obstacles_on_map(tile_map, obstacle_positions)
    foreach(obstacle_positions) do cell
        tile_map[OBSTACLE, cell] = true
    end
    return tile_map
end

# Move the agent back to the start tile `(2, 2)` and clear the reward and the
# `done` flag. Walls, goal and obstacles are left as they are.
function GW.reset!(env::FrozenLakeUndirected)
    start_position = CartesianIndex(2, 2)

    # Clear the old agent tile before setting the new one so that the agent bit
    # stays set when the agent is already on the start tile.
    env.tile_map[AGENT, env.agent_position] = false
    env.agent_position = start_position
    env.tile_map[AGENT, start_position] = true

    env.done = false
    env.reward = zero(env.reward)

    return nothing
end

# Apply `action` to `env` and update the agent position, `env.reward` and `env.done`.
#
# Actions are `1` = up, `2` = down, `3` = left, `4` = right. When `env.is_slippery`
# is true, the executed move is drawn uniformly from the requested move and the two
# moves perpendicular to it, using `env.rng` so that a seeded environment is
# reproducible. A move into a wall leaves the agent where it is.
#
# After the move, standing on the goal yields `env.terminal_reward` and standing on
# an obstacle yields `env.terminal_penalty`; both end the episode. Any other tile
# yields a zero reward. Returns `nothing`.
function GW.act!(env::FrozenLakeUndirected, action)
    @assert action in Base.OneTo(NUM_ACTIONS) "Invalid action $(action). Action must be in Base.OneTo($(NUM_ACTIONS))"

    tile_map = env.tile_map
    agent_position = env.agent_position

    # Requested move first, followed by the two perpendicular "slip" moves.
    moves = if action == 1
        (GW.move_up, GW.move_left, GW.move_right)
    elseif action == 2
        (GW.move_down, GW.move_left, GW.move_right)
    elseif action == 3
        (GW.move_left, GW.move_up, GW.move_down)
    else
        (GW.move_right, GW.move_up, GW.move_down)
    end

    move = env.is_slippery ? rand(env.rng, moves) : first(moves)
    new_agent_position = move(agent_position)

    if !tile_map[WALL, new_agent_position]
        tile_map[AGENT, agent_position] = false
        env.agent_position = new_agent_position
        tile_map[AGENT, new_agent_position] = true
    end

    if tile_map[GOAL, env.agent_position]
        env.reward = env.terminal_reward
        env.done = true
    elseif tile_map[OBSTACLE, env.agent_position]
        env.reward = env.terminal_penalty
        env.done = true
    else
        env.reward = zero(env.reward)
        env.done = false
    end

    return nothing
end

#####
##### miscellaneous
#####

# Number of rows of the tile map (second dimension of `tile_map`).
function GW.get_height(env::FrozenLakeUndirected)
    return size(env.tile_map, 2)
end

# Number of columns of the tile map (third dimension of `tile_map`).
function GW.get_width(env::FrozenLakeUndirected)
    return size(env.tile_map, 3)
end

# Names of the four actions, in action-index order.
function GW.get_action_names(env::FrozenLakeUndirected)
    return (:MOVE_UP, :MOVE_DOWN, :MOVE_LEFT, :MOVE_RIGHT)
end

# Names of the object layers, in layer-index order.
function GW.get_object_names(env::FrozenLakeUndirected)
    return (:AGENT, :WALL, :GOAL, :OBSTACLE)
end

# Character used to draw the tile at `position` of the full tile map: the glyph of
# the first object layer that is set there, or the last glyph when the tile is empty.
function GW.get_pretty_tile_map(env::FrozenLakeUndirected, position::CartesianIndex{2})
    glyphs = ('☻', '█', '♥', '⊗', '⋅')

    layer = findfirst(@view env.tile_map[:, position])
    return isnothing(layer) ? glyphs[end] : glyphs[layer]
end

# Character used to draw the tile at `position` of the sub tile map that
# `GW.get_sub_tile_map` extracts for the agent's position and `window_size`: the
# glyph of the first object layer set there, or the last glyph when the tile is empty.
function GW.get_pretty_sub_tile_map(env::FrozenLakeUndirected, window_size, position::CartesianIndex{2})
    glyphs = ('☻', '█', '♥', '⊗', '⋅')

    local_view = GW.get_sub_tile_map(env.tile_map, env.agent_position, window_size)

    layer = findfirst(@view local_view[:, position])
    return isnothing(layer) ? glyphs[end] : glyphs[layer]
end

# Multi-line text display: the full tile map, the agent-centred sub tile map, the
# current reward and done flag, and the action and object names.
function Base.show(io::IO, ::MIME"text/plain", env::FrozenLakeUndirected)
    full_map = GW.get_pretty_tile_map(env)
    sub_map = GW.get_pretty_sub_tile_map(env, GW.get_window_size(env))

    text = string(
        "tile_map:\n",
        full_map,
        "\nsub_tile_map:\n",
        sub_map,
        "\nreward: $(env.reward)",
        "\ndone: $(env.done)",
        "\naction_names: $(GW.get_action_names(env))",
        "\nobject_names: $(GW.get_object_names(env))",
    )

    print(io, text)
    return nothing
end

# Keyboard characters associated with the four actions.
# NOTE(review): presumably listed in action-index order (up, down, left, right) —
# confirm against the code that consumes these keys.
function GW.get_action_keys(env::FrozenLakeUndirected)
    return ('w', 's', 'a', 'd')
end

#####
##### FrozenLakeUndirected
#####

# The environment exposes a single state style: its internal state.
function RLBase.StateStyle(env::GW.RLBaseEnv{E}) where {E <: FrozenLakeUndirected}
    return RLBase.InternalState{Any}()
end

# No state space is declared for the internal state.
function RLBase.state_space(env::GW.RLBaseEnv{E}, ::RLBase.InternalState) where {E <: FrozenLakeUndirected}
    return nothing
end

# The internal state is the wrapped environment's `tile_map` itself (not a copy).
function RLBase.state(env::GW.RLBaseEnv{E}, ::RLBase.InternalState) where {E <: FrozenLakeUndirected}
    return env.env.tile_map
end

# Resetting forwards to `GW.reset!` on the wrapped environment.
function RLBase.reset!(env::GW.RLBaseEnv{E}) where {E <: FrozenLakeUndirected}
    return GW.reset!(env.env)
end

# Valid actions are the integers `1:NUM_ACTIONS`.
function RLBase.action_space(env::GW.RLBaseEnv{E}) where {E <: FrozenLakeUndirected}
    return Base.OneTo(NUM_ACTIONS)
end

# Calling the wrapper with an action forwards to `GW.act!` on the wrapped environment.
function (env::GW.RLBaseEnv{E})(action) where {E <: FrozenLakeUndirected}
    return GW.act!(env.env, action)
end

# Reward produced by the most recent action.
function RLBase.reward(env::GW.RLBaseEnv{E}) where {E <: FrozenLakeUndirected}
    return env.env.reward
end

# Whether the episode has ended.
function RLBase.is_terminated(env::GW.RLBaseEnv{E}) where {E <: FrozenLakeUndirected}
    return env.env.done
end

end # module
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ get_terminal_returns(env::GW.RLBaseEnv{E}) where {E <: GW.CatcherModule.Catcher}
# Per-environment tuples of the values `get_terminal_returns` reports for a wrapped env.
# NOTE(review): presumably the set of returns an episode may end with — confirm
# against the code that consumes these tuples.
get_terminal_returns(env::GW.RLBaseEnv{E}) where {E <: GW.TransportUndirectedModule.TransportUndirected} = (env.env.terminal_reward,)
# The directed variant reads the reward one wrapper level deeper (`env.env.env`).
get_terminal_returns(env::GW.RLBaseEnv{E}) where {E <: GW.TransportDirectedModule.TransportDirected} = (env.env.env.terminal_reward,)

# Frozen lake can end on the goal (terminal_reward) or on an obstacle (terminal_penalty).
get_terminal_returns(env::GW.RLBaseEnv{E}) where {E <: GW.FrozenLakeUndirectedModule.FrozenLakeUndirected} = (env.env.terminal_reward, env.env.terminal_penalty)

function is_valid_terminal_return(env::GW.RLBaseEnv{E}, terminal_return) where {E <: GW.SnakeModule.Snake}
terminal_reward = env.env.terminal_reward
terminal_penalty = env.env.terminal_penalty
Expand Down