diff --git a/README.md b/README.md
index a68696d..7bc87e6 100644
--- a/README.md
+++ b/README.md
@@ -33,6 +33,7 @@ This package is inspired by [gym-minigrid](https://github.com/maximecb/gym-minig
1. [Catcher](#catcher)
1. [TransportUndirected](#transportundirected)
1. [TransportDirected](#transportdirected)
+1. [FrozenLakeUndirected](#frozenlakeundirected)
## Getting Started
@@ -355,3 +356,10 @@ In `ReinforcementLearning.jl`, you can create a [hook](https://juliareinforcemen
+
+1. ### FrozenLakeUndirected
+
+ The objective of the agent is to navigate its way to the goal while avoiding falling into the holes in the lake. When the agent reaches the goal, it receives a reward of 1 and the environment terminates. If the agent collides falls into a hole, the agent receives a reward of -1 and the environment terminates. The probablility of moving in the direction given by an agent is 1/3 while there is 1/3 chance to move in either perpendicular direction (for example: 1/3 chance to move up, 1/3 chance to move left and 1/3 chance to move right if the agent chose up). The scenario is based on the [Frozen Lake environment](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) in Python's gymnasium. In the Python version there are two preset maps: "4x4" and "8x8". The GridWorlds implementation includes the walls as part of the dimensions, so the equivalent maps in GridWorlds is "6x6" and "10x10" respectively. The start, goal, and holes are located in the same positions in the lake as the Python version. If specifying custom height and widths keep in mind it is going to add walls all around the map so the actual surface of the lake is (height - 2, width - 2).
+
+
+
diff --git a/src/envs/envs.jl b/src/envs/envs.jl
index 0c90883..ea20f14 100644
--- a/src/envs/envs.jl
+++ b/src/envs/envs.jl
@@ -21,6 +21,7 @@ include("snake.jl")
include("catcher.jl")
include("transport_undirected.jl")
include("transport_directed.jl")
+include("frozen_lake_undirected.jl")
const ENVS = [
SingleRoomUndirectedModule.SingleRoomUndirected,
@@ -46,4 +47,5 @@ const ENVS = [
CatcherModule.Catcher,
TransportUndirectedModule.TransportUndirected,
TransportDirectedModule.TransportDirected,
+ FrozenLakeUndirectedModule.FrozenLakeUndirected
]
diff --git a/src/envs/frozen_lake_undirected.jl b/src/envs/frozen_lake_undirected.jl
new file mode 100644
index 0000000..89161fa
--- /dev/null
+++ b/src/envs/frozen_lake_undirected.jl
@@ -0,0 +1,241 @@
+module FrozenLakeUndirectedModule
+
+import ..GridWorlds as GW
+import Random
+import ReinforcementLearningBase as RLBase
+import AStarSearch.astar
+
+#####
+##### game logic
+#####
+
+const NUM_OBJECTS = 4
+const AGENT = 1
+const WALL = 2
+const GOAL = 3
+const HOLE = 4
+const NUM_ACTIONS = 4
+
+mutable struct FrozenLakeUndirected{R, RNG} <: GW.AbstractGridWorld
+ tile_map::BitArray{3}
+ map_name::String
+ agent_position::CartesianIndex{2}
+ reward::R
+ rng::RNG
+ done::Bool
+ terminal_reward::R
+ terminal_penalty::R
+ goal_position::CartesianIndex{2}
+ num_holes::Int
+ hole_positions::Vector{CartesianIndex{2}}
+ is_slippery::Bool
+ randomize_start_end::Bool
+end
+
+function FrozenLakeUndirected(; map_name::String = "", R::Type = Float32, height::Int = 8, width::Int = 8, num_holes::Int = floor(Int, sqrt(height * width) / 2), rng = Random.GLOBAL_RNG, is_slippery::Bool = true, randomize_start_end::Bool = false)
+ hole_positions = Array{CartesianIndex{2}}(undef, num_holes)
+ if map_name == "6x6"
+ height = 6
+ width = 6
+ num_holes = 4
+ hole_positions = Array{CartesianIndex{2}}(undef, num_holes)
+ hole_positions = [CartesianIndex(3, 3), CartesianIndex(3, 5), CartesianIndex(4, 5), CartesianIndex(5, 2)]
+ elseif map_name == "10x10"
+ height = 10
+ width = 10
+ num_holes = 10
+ hole_positions = Array{CartesianIndex{2}}(undef, num_holes)
+ hole_positions = [CartesianIndex(4, 5), CartesianIndex(5, 7), CartesianIndex(6, 5), CartesianIndex(7, 3), CartesianIndex(7, 4), CartesianIndex(7, 8), CartesianIndex(8, 3), CartesianIndex(8, 6), CartesianIndex(8, 8), CartesianIndex(9, 5)]
+ elseif map_name != ""
+ throw(ArgumentError("Unsupported map_name value: '$(map_name)'. Please use '6x6', '10x10', or undefined."))
+ end
+
+ tile_map = falses(NUM_OBJECTS, height, width)
+
+ tile_map[WALL, 1, :] .= true
+ tile_map[WALL, height, :] .= true
+ tile_map[WALL, :, 1] .= true
+ tile_map[WALL, :, width] .= true
+
+ if randomize_start_end
+ GW.sample_two_positions_without_replacement(rng, tile_map[2:height - 1, 2:width - 1])
+ else
+ agent_position = CartesianIndex(2, 2)
+ goal_position = CartesianIndex(height - 1, width - 1)
+ end
+
+ tile_map[AGENT, agent_position] = true
+ tile_map[GOAL, goal_position] = true
+
+ if map_name == ""
+ function get_neighbors(state::CartesianIndex)
+ return_list = []
+ for pos in (GW.move_up(state), GW.move_down(state), GW.move_left(state), GW.move_right(state))
+ if !tile_map[WALL, pos] && !tile_map[HOLE, pos]
+ push!(return_list, pos)
+ end
+ end
+ return return_list
+ end
+
+ manhattan(a::CartesianIndex, b::CartesianIndex) = sum(abs.((b-a).I))
+ is_goal(state::CartesianIndex, end_state::CartesianIndex) = state == end_state
+
+ distance_heuristic(state::CartesianIndex, end_state::CartesianIndex) = manhattan(state, end_state)
+
+ path_exists = false
+ hole_positions = Array{CartesianIndex{2}}(undef, num_holes)
+ while !path_exists
+ for i in 1:num_holes
+ hole_position = GW.sample_empty_position(rng, tile_map)
+ hole_positions[i] = hole_position
+ end
+ tile_map = update_holes_on_map(tile_map, hole_positions)
+
+ path_exists = astar(get_neighbors, agent_position, goal_position; heuristic = distance_heuristic, isgoal = is_goal).status == :success
+ @debug "path_exists: $(path_exists)"
+ end
+ end
+
+ tile_map = update_holes_on_map(tile_map, hole_positions)
+
+ reward = zero(R)
+ done = false
+ terminal_reward = one(R)
+ terminal_penalty = -one(R)
+
+ env = FrozenLakeUndirected(tile_map, map_name, agent_position, reward, rng, done, terminal_reward, terminal_penalty, goal_position, num_holes, hole_positions, is_slippery, randomize_start_end)
+
+ return env
+end
+
+function update_holes_on_map(tile_map, hole_positions)
+ for position in hole_positions
+ tile_map[HOLE, position] = true
+ end
+ return tile_map
+end
+
+function GW.reset!(env::FrozenLakeUndirected)
+ tile_map = env.tile_map
+
+ tile_map[AGENT, env.agent_position] = false
+
+ agent_position = CartesianIndex(2, 2)
+ env.agent_position = agent_position
+ tile_map[AGENT, agent_position] = true
+
+ env.reward = zero(env.reward)
+ env.done = false
+
+ return nothing
+end
+
+function GW.act!(env::FrozenLakeUndirected, action)
+ @assert action in Base.OneTo(NUM_ACTIONS) "Invalid action $(action). Action must be in Base.OneTo($(NUM_ACTIONS))"
+
+ tile_map = env.tile_map
+
+ agent_position = env.agent_position
+
+ is_slippery = env.is_slippery
+
+ if action == 1
+ new_agent_position = is_slippery ? rand((GW.move_up(agent_position), GW.move_left(agent_position), GW.move_right(agent_position))) : GW.move_up(agent_position)
+ elseif action == 2
+ new_agent_position = is_slippery ? rand((GW.move_down(agent_position), GW.move_left(agent_position), GW.move_right(agent_position))) : GW.move_down(agent_position)
+ elseif action == 3
+ new_agent_position = is_slippery ? rand((GW.move_left(agent_position), GW.move_up(agent_position), GW.move_down(agent_position))) : GW.move_left(agent_position)
+ else
+ new_agent_position = is_slippery ? rand((GW.move_right(agent_position), GW.move_up(agent_position), GW.move_down(agent_position))) : GW.move_right(agent_position)
+ end
+
+ if !tile_map[WALL, new_agent_position]
+ tile_map[AGENT, agent_position] = false
+ env.agent_position = new_agent_position
+ tile_map[AGENT, new_agent_position] = true
+ end
+
+ if tile_map[GOAL, env.agent_position]
+ env.reward = env.terminal_reward
+ env.done = true
+ elseif tile_map[HOLE, env.agent_position]
+ env.reward = env.terminal_penalty
+ env.done = true
+ else
+ env.reward = zero(env.reward)
+ env.done = false
+ end
+
+ return nothing
+end
+
+#####
+##### miscellaneous
+#####
+
+GW.get_height(env::FrozenLakeUndirected) = size(env.tile_map, 2)
+GW.get_width(env::FrozenLakeUndirected) = size(env.tile_map, 3)
+
+GW.get_action_names(env::FrozenLakeUndirected) = (:MOVE_UP, :MOVE_DOWN, :MOVE_LEFT, :MOVE_RIGHT)
+GW.get_object_names(env::FrozenLakeUndirected) = (:AGENT, :WALL, :GOAL, :HOLE)
+
+function GW.get_pretty_tile_map(env::FrozenLakeUndirected, position::CartesianIndex{2})
+ characters = ('☻', '█', '♥', '○', '⋅')
+
+ object = findfirst(@view env.tile_map[:, position])
+ if isnothing(object)
+ return characters[end]
+ else
+ return characters[object]
+ end
+end
+
+function GW.get_pretty_sub_tile_map(env::FrozenLakeUndirected, window_size, position::CartesianIndex{2})
+ tile_map = env.tile_map
+ agent_position = env.agent_position
+
+ characters = ('☻', '█', '♥', '○', '⋅')
+
+ sub_tile_map = GW.get_sub_tile_map(tile_map, agent_position, window_size)
+
+ object = findfirst(@view sub_tile_map[:, position])
+ if isnothing(object)
+ return characters[end]
+ else
+ return characters[object]
+ end
+end
+
+function Base.show(io::IO, ::MIME"text/plain", env::FrozenLakeUndirected)
+ str = "tile_map:\n"
+ str = str * GW.get_pretty_tile_map(env)
+ str = str * "\nsub_tile_map:\n"
+ str = str * GW.get_pretty_sub_tile_map(env, GW.get_window_size(env))
+ str = str * "\nreward: $(env.reward)"
+ str = str * "\ndone: $(env.done)"
+ str = str * "\naction_names: $(GW.get_action_names(env))"
+ str = str * "\nobject_names: $(GW.get_object_names(env))"
+ print(io, str)
+ return nothing
+end
+
+GW.get_action_keys(env::FrozenLakeUndirected) = ('w', 's', 'a', 'd')
+
+#####
+##### FrozenLakeUndirected
+#####
+
+RLBase.StateStyle(env::GW.RLBaseEnv{E}) where {E <: FrozenLakeUndirected} = RLBase.InternalState{Any}()
+RLBase.state_space(env::GW.RLBaseEnv{E}, ::RLBase.InternalState) where {E <: FrozenLakeUndirected} = nothing
+RLBase.state(env::GW.RLBaseEnv{E}, ::RLBase.InternalState) where {E <: FrozenLakeUndirected} = env.env.tile_map
+
+RLBase.reset!(env::GW.RLBaseEnv{E}) where {E <: FrozenLakeUndirected} = GW.reset!(env.env)
+
+RLBase.action_space(env::GW.RLBaseEnv{E}) where {E <: FrozenLakeUndirected} = Base.OneTo(NUM_ACTIONS)
+(env::GW.RLBaseEnv{E})(action) where {E <: FrozenLakeUndirected} = GW.act!(env.env, action)
+
+RLBase.reward(env::GW.RLBaseEnv{E}) where {E <: FrozenLakeUndirected} = env.env.reward
+RLBase.is_terminated(env::GW.RLBaseEnv{E}) where {E <: FrozenLakeUndirected} = env.env.done
+
+end # module
diff --git a/test/runtests.jl b/test/runtests.jl
index 2596d53..5168afe 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -39,6 +39,8 @@ get_terminal_returns(env::GW.RLBaseEnv{E}) where {E <: GW.CatcherModule.Catcher}
get_terminal_returns(env::GW.RLBaseEnv{E}) where {E <: GW.TransportUndirectedModule.TransportUndirected} = (env.env.terminal_reward,)
get_terminal_returns(env::GW.RLBaseEnv{E}) where {E <: GW.TransportDirectedModule.TransportDirected} = (env.env.env.terminal_reward,)
+get_terminal_returns(env::GW.RLBaseEnv{E}) where {E <: GW.FrozenLakeUndirectedModule.FrozenLakeUndirected} = (env.env.terminal_reward, env.env.terminal_penalty)
+
function is_valid_terminal_return(env::GW.RLBaseEnv{E}, terminal_return) where {E <: GW.SnakeModule.Snake}
terminal_reward = env.env.terminal_reward
terminal_penalty = env.env.terminal_penalty