Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Version of Frozen Lake based on the implementation #186

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ This package is inspired by [gym-minigrid](https://github.com/maximecb/gym-minig
1. [Catcher](#catcher)
1. [TransportUndirected](#transportundirected)
1. [TransportDirected](#transportdirected)
1. [FrozenLakeUndirected](#frozenlakeundirected)

## Getting Started

Expand Down Expand Up @@ -355,3 +356,10 @@ In `ReinforcementLearning.jl`, you can create a [hook](https://juliareinforcemen

<img src="https://user-images.githubusercontent.com/32610387/126910050-723e100c-c5c7-4703-8eab-5ab86a15e41f.png">
<img src="https://user-images.githubusercontent.com/32610387/126909921-fdb3c853-4cac-4e6a-b20c-604caf5632e0.gif">

1. ### FrozenLakeUndirected

The objective of the agent is to navigate its way to the goal while avoiding falling into the holes in the lake. When the agent reaches the goal, it receives a reward of 1 and the environment terminates. If the agent collides falls into a hole, the agent receives a reward of -1 and the environment terminates. The probablility of moving in the direction given by an agent is 1/3 while there is 1/3 chance to move in either perpendicular direction (for example: 1/3 chance to move up, 1/3 chance to move left and 1/3 chance to move right if the agent chose up). The scenario is based on the [Frozen Lake environment](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) in Python's gymnasium. In the Python version there are two preset maps: "4x4" and "8x8". The GridWorlds implementation includes the walls as part of the dimensions, so the equivalent maps in GridWorlds is "6x6" and "10x10" respectively. The start, goal, and holes are located in the same positions in the lake as the Python version. If specifying custom height and widths keep in mind it is going to add walls all around the map so the actual surface of the lake is (height - 2, width - 2).

<img src="https://user-images.githubusercontent.com/32610387/126910030-d93a714d-10b7-4117-887c-773afe78c625.png">
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The screenshot and gif are from DynamicObstaclesUndirected. We need to create new ones for FrozenLakeUndirected.

We can do that in the end.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you are correct.

<img src="https://user-images.githubusercontent.com/32610387/126909888-8fa8473f-deb6-4562-9004-419fa8080693.gif">
2 changes: 2 additions & 0 deletions src/envs/envs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ include("snake.jl")
include("catcher.jl")
include("transport_undirected.jl")
include("transport_directed.jl")
include("frozen_lake_undirected.jl")

const ENVS = [
SingleRoomUndirectedModule.SingleRoomUndirected,
Expand All @@ -46,4 +47,5 @@ const ENVS = [
CatcherModule.Catcher,
TransportUndirectedModule.TransportUndirected,
TransportDirectedModule.TransportDirected,
FrozenLakeUndirectedModule.FrozenLakeUndirected
]
241 changes: 241 additions & 0 deletions src/envs/frozen_lake_undirected.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
module FrozenLakeUndirectedModule

import ..GridWorlds as GW
import Random
import ReinforcementLearningBase as RLBase
import AStarSearch.astar
LooseTerrifyingSpaceMonkey marked this conversation as resolved.
Show resolved Hide resolved

#####
##### game logic
#####

const NUM_OBJECTS = 4
const AGENT = 1
const WALL = 2
const GOAL = 3
const HOLE = 4
const NUM_ACTIONS = 4

mutable struct FrozenLakeUndirected{R, RNG} <: GW.AbstractGridWorld
tile_map::BitArray{3}
map_name::String
agent_position::CartesianIndex{2}
reward::R
rng::RNG
done::Bool
terminal_reward::R
terminal_penalty::R
goal_position::CartesianIndex{2}
num_holes::Int
hole_positions::Vector{CartesianIndex{2}}
is_slippery::Bool
randomize_start_end::Bool
end

function FrozenLakeUndirected(; map_name::String = "", R::Type = Float32, height::Int = 8, width::Int = 8, num_holes::Int = floor(Int, sqrt(height * width) / 2), rng = Random.GLOBAL_RNG, is_slippery::Bool = true, randomize_start_end::Bool = false)
hole_positions = Array{CartesianIndex{2}}(undef, num_holes)
if map_name == "6x6"
height = 6
width = 6
num_holes = 4
hole_positions = Array{CartesianIndex{2}}(undef, num_holes)
hole_positions = [CartesianIndex(3, 3), CartesianIndex(3, 5), CartesianIndex(4, 5), CartesianIndex(5, 2)]
elseif map_name == "10x10"
height = 10
width = 10
num_holes = 10
hole_positions = Array{CartesianIndex{2}}(undef, num_holes)
hole_positions = [CartesianIndex(4, 5), CartesianIndex(5, 7), CartesianIndex(6, 5), CartesianIndex(7, 3), CartesianIndex(7, 4), CartesianIndex(7, 8), CartesianIndex(8, 3), CartesianIndex(8, 6), CartesianIndex(8, 8), CartesianIndex(9, 5)]
elseif map_name != ""
throw(ArgumentError("Unsupported map_name value: '$(map_name)'. Please use '6x6', '10x10', or undefined."))
end

tile_map = falses(NUM_OBJECTS, height, width)

tile_map[WALL, 1, :] .= true
tile_map[WALL, height, :] .= true
tile_map[WALL, :, 1] .= true
tile_map[WALL, :, width] .= true

if randomize_start_end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This randomization should also be done when resetting the environment.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a significant change from the gymnasium model.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't used the gymnasium one. Do you mean for every episode, the map remains the same?
I think it is valuable to have an option to be able to randomize the map on every episode to facilitate generalization. Most other environments in this package are like that. We could add a boolean to toggle it off if a user wants to keep the same map and adhere to the gym specification. Does that sound reasonable?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the map remains the same. From what I have seen it is used for q-learning and dynamic programming examples. If the map changed after every episode the agent wouldn't learn. I will set a boolean so the user has the option to turn on the resetting.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you provide some guidance on unit testing with all these different flags? I don't understand how the settings are passed in runtests.jl

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding a boolean works.

Unfortunately, automatic unit testing all flag combinations for these games is a bit hard, and isn't implemented in runtests.jl yet. So for now, you can just test the default settings for an environment in runtests.jl:

env = GW.RLBaseEnv(Env(R = R))
and test the rest of the combinations manually by directly playing these games directly in the terminal by doing GW.play!(env).

GW.sample_two_positions_without_replacement(rng, tile_map[2:height - 1, 2:width - 1])
else
agent_position = CartesianIndex(2, 2)
goal_position = CartesianIndex(height - 1, width - 1)
end

tile_map[AGENT, agent_position] = true
tile_map[GOAL, goal_position] = true

if map_name == ""
function get_neighbors(state::CartesianIndex)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reasons of having these functions defined inside the constructor? If not, we can move them outside. Also, we will need to run A* when resetting the environment too. So we will be needing them there.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, because they are never used elsewhere. They won't be used when resetting the environment.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reopening it as per above

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, just saying, I haven't seen a lot of code where functions are put inside other functions, even if they are not used elsewhere. In closures it could be useful when you want to capture other variable and such. But other than that, I haven't seen it happening much unless there is some specific reason. I may be wrong, but I think it is clearer to have them outside, unless there is a specific need, in my opinion.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually see this quite a bit within Python code written by academics. I don't see it by people in industry. I don't know why. I can change it though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Please go ahead with the change. Thanks.

return_list = []
for pos in (GW.move_up(state), GW.move_down(state), GW.move_left(state), GW.move_right(state))
if !tile_map[WALL, pos] && !tile_map[HOLE, pos]
push!(return_list, pos)
end
end
return return_list
end

manhattan(a::CartesianIndex, b::CartesianIndex) = sum(abs.((b-a).I))
is_goal(state::CartesianIndex, end_state::CartesianIndex) = state == end_state

distance_heuristic(state::CartesianIndex, end_state::CartesianIndex) = manhattan(state, end_state)

path_exists = false
hole_positions = Array{CartesianIndex{2}}(undef, num_holes)
while !path_exists
for i in 1:num_holes
hole_position = GW.sample_empty_position(rng, tile_map)
hole_positions[i] = hole_position
end
tile_map = update_holes_on_map(tile_map, hole_positions)

path_exists = astar(get_neighbors, agent_position, goal_position; heuristic = distance_heuristic, isgoal = is_goal).status == :success
@debug "path_exists: $(path_exists)"
end
end

tile_map = update_holes_on_map(tile_map, hole_positions)

reward = zero(R)
done = false
terminal_reward = one(R)
terminal_penalty = -one(R)

env = FrozenLakeUndirected(tile_map, map_name, agent_position, reward, rng, done, terminal_reward, terminal_penalty, goal_position, num_holes, hole_positions, is_slippery, randomize_start_end)

return env
end

function update_holes_on_map(tile_map, hole_positions)
for position in hole_positions
tile_map[HOLE, position] = true
end
return tile_map
end

function GW.reset!(env::FrozenLakeUndirected)
tile_map = env.tile_map

tile_map[AGENT, env.agent_position] = false

agent_position = CartesianIndex(2, 2)
env.agent_position = agent_position
tile_map[AGENT, agent_position] = true

env.reward = zero(env.reward)
env.done = false

return nothing
end

function GW.act!(env::FrozenLakeUndirected, action)
@assert action in Base.OneTo(NUM_ACTIONS) "Invalid action $(action). Action must be in Base.OneTo($(NUM_ACTIONS))"

tile_map = env.tile_map

agent_position = env.agent_position

is_slippery = env.is_slippery

if action == 1
new_agent_position = is_slippery ? rand((GW.move_up(agent_position), GW.move_left(agent_position), GW.move_right(agent_position))) : GW.move_up(agent_position)
elseif action == 2
new_agent_position = is_slippery ? rand((GW.move_down(agent_position), GW.move_left(agent_position), GW.move_right(agent_position))) : GW.move_down(agent_position)
elseif action == 3
new_agent_position = is_slippery ? rand((GW.move_left(agent_position), GW.move_up(agent_position), GW.move_down(agent_position))) : GW.move_left(agent_position)
else
new_agent_position = is_slippery ? rand((GW.move_right(agent_position), GW.move_up(agent_position), GW.move_down(agent_position))) : GW.move_right(agent_position)
end

if !tile_map[WALL, new_agent_position]
tile_map[AGENT, agent_position] = false
env.agent_position = new_agent_position
tile_map[AGENT, new_agent_position] = true
end

if tile_map[GOAL, env.agent_position]
env.reward = env.terminal_reward
env.done = true
elseif tile_map[HOLE, env.agent_position]
env.reward = env.terminal_penalty
env.done = true
else
env.reward = zero(env.reward)
env.done = false
end

return nothing
end

#####
##### miscellaneous
#####

GW.get_height(env::FrozenLakeUndirected) = size(env.tile_map, 2)
GW.get_width(env::FrozenLakeUndirected) = size(env.tile_map, 3)

GW.get_action_names(env::FrozenLakeUndirected) = (:MOVE_UP, :MOVE_DOWN, :MOVE_LEFT, :MOVE_RIGHT)
GW.get_object_names(env::FrozenLakeUndirected) = (:AGENT, :WALL, :GOAL, :HOLE)

function GW.get_pretty_tile_map(env::FrozenLakeUndirected, position::CartesianIndex{2})
characters = ('☻', '█', '♥', '○', '⋅')

object = findfirst(@view env.tile_map[:, position])
if isnothing(object)
return characters[end]
else
return characters[object]
end
end

function GW.get_pretty_sub_tile_map(env::FrozenLakeUndirected, window_size, position::CartesianIndex{2})
tile_map = env.tile_map
agent_position = env.agent_position

characters = ('☻', '█', '♥', '○', '⋅')

sub_tile_map = GW.get_sub_tile_map(tile_map, agent_position, window_size)

object = findfirst(@view sub_tile_map[:, position])
if isnothing(object)
return characters[end]
else
return characters[object]
end
end

function Base.show(io::IO, ::MIME"text/plain", env::FrozenLakeUndirected)
str = "tile_map:\n"
str = str * GW.get_pretty_tile_map(env)
str = str * "\nsub_tile_map:\n"
str = str * GW.get_pretty_sub_tile_map(env, GW.get_window_size(env))
str = str * "\nreward: $(env.reward)"
str = str * "\ndone: $(env.done)"
str = str * "\naction_names: $(GW.get_action_names(env))"
str = str * "\nobject_names: $(GW.get_object_names(env))"
print(io, str)
return nothing
end

GW.get_action_keys(env::FrozenLakeUndirected) = ('w', 's', 'a', 'd')

#####
##### FrozenLakeUndirected
#####

RLBase.StateStyle(env::GW.RLBaseEnv{E}) where {E <: FrozenLakeUndirected} = RLBase.InternalState{Any}()
RLBase.state_space(env::GW.RLBaseEnv{E}, ::RLBase.InternalState) where {E <: FrozenLakeUndirected} = nothing
RLBase.state(env::GW.RLBaseEnv{E}, ::RLBase.InternalState) where {E <: FrozenLakeUndirected} = env.env.tile_map

RLBase.reset!(env::GW.RLBaseEnv{E}) where {E <: FrozenLakeUndirected} = GW.reset!(env.env)

RLBase.action_space(env::GW.RLBaseEnv{E}) where {E <: FrozenLakeUndirected} = Base.OneTo(NUM_ACTIONS)
(env::GW.RLBaseEnv{E})(action) where {E <: FrozenLakeUndirected} = GW.act!(env.env, action)

RLBase.reward(env::GW.RLBaseEnv{E}) where {E <: FrozenLakeUndirected} = env.env.reward
RLBase.is_terminated(env::GW.RLBaseEnv{E}) where {E <: FrozenLakeUndirected} = env.env.done

end # module
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ get_terminal_returns(env::GW.RLBaseEnv{E}) where {E <: GW.CatcherModule.Catcher}
get_terminal_returns(env::GW.RLBaseEnv{E}) where {E <: GW.TransportUndirectedModule.TransportUndirected} = (env.env.terminal_reward,)
get_terminal_returns(env::GW.RLBaseEnv{E}) where {E <: GW.TransportDirectedModule.TransportDirected} = (env.env.env.terminal_reward,)

get_terminal_returns(env::GW.RLBaseEnv{E}) where {E <: GW.FrozenLakeUndirectedModule.FrozenLakeUndirected} = (env.env.terminal_reward, env.env.terminal_penalty)

function is_valid_terminal_return(env::GW.RLBaseEnv{E}, terminal_return) where {E <: GW.SnakeModule.Snake}
terminal_reward = env.env.terminal_reward
terminal_penalty = env.env.terminal_penalty
Expand Down