diff --git a/src/actions.jl b/src/actions.jl index 33006bd..5f30c21 100644 --- a/src/actions.jl +++ b/src/actions.jl @@ -24,3 +24,9 @@ const TURN_LEFT = TurnLeft() (x::TurnLeft)(::Up) = LEFT (x::TurnLeft)(::Right) = UP (x::TurnLeft)(::Down) = RIGHT + +struct Pickup <: AbstractGridWorldAction end +const PICK_UP = Pickup() + +struct Drop <: AbstractGridWorldAction end +const DROP = Drop() diff --git a/src/envs/doorkey.jl b/src/envs/doorkey.jl index e3bbed6..8228c96 100644 --- a/src/envs/doorkey.jl +++ b/src/envs/doorkey.jl @@ -6,7 +6,6 @@ mutable struct DoorKey{W<:GridWorldBase} <: AbstractGridWorld world::W agent_pos::CartesianIndex{2} agent::Agent - has_key::Bool end function DoorKey(;n=8, agent_start_pos=CartesianIndex(2,2), rng=Random.GLOBAL_RNG) @@ -32,7 +31,7 @@ function DoorKey(;n=8, agent_start_pos=CartesianIndex(2,2), rng=Random.GLOBAL_RN world[EMPTY, key_pos] = false world[Key(:yellow), key_pos] = true - DoorKey(world, agent_start_pos, Agent(dir=RIGHT),false) + DoorKey(world, agent_start_pos, Agent(;dir=RIGHT)) end function (w::DoorKey)(::MoveForward) @@ -40,14 +39,15 @@ function (w::DoorKey)(::MoveForward) dest = dir(w.agent_pos) if w.world[Key(:yellow), dest] - w.has_key = true - w.world[Key(:yellow), dest] = false - w.world[EMPTY, dest] = true + if PICK_UP(w.agent, Key(:yellow)) + w.world[Key(:yellow), dest] = false + w.world[EMPTY, dest] = true + end w.agent_pos = dest - elseif w.world[Door(:yellow), dest] && w.has_key - w.agent_pos = dest - elseif w.world[Door(:yellow), dest] && !w.has_key + elseif w.world[Door(:yellow), dest] && w.agent.inventory !== Key(:yellow) nothing + elseif w.world[Door(:yellow), dest] && w.agent.inventory === Key(:yellow) + w.agent_pos = dest elseif dest ∈ CartesianIndices((size(w.world, 2), size(w.world, 3))) && !w.world[WALL,dest] w.agent_pos = dest end diff --git a/src/objects.jl b/src/objects.jl index bf7e6b9..1fb947e 100644 --- a/src/objects.jl +++ b/src/objects.jl @@ -50,17 +50,22 @@ const OBSTACLE = Obstacle() Base.convert(::Type{Char}, ::Obstacle) = '⊗' get_color(::Obstacle) = :blue +##### +# Agent +##### + Base.@kwdef mutable struct Agent <: AbstractObject color::Symbol=:red dir::LRUD + inventory::Union{Nothing, AbstractObject, Vector}=nothing end function Base.convert(::Type{Char}, a::Agent) - if a.dir === UP + if a.dir === UP '↑' - elseif a.dir === DOWN + elseif a.dir === DOWN '↓' - elseif a.dir === LEFT + elseif a.dir === LEFT '←' elseif a.dir === RIGHT '→' @@ -70,3 +75,51 @@ end get_color(a::Agent) = a.color get_dir(a::Agent) = a.dir set_dir!(a::Agent, d) = a.dir = d + +struct Transportable end +struct Nontransportable end +const TRANSPORTABLE = Transportable() +const NONTRANSPORTABLE = Nontransportable() +istransportable(::Type{<:Key}) = TRANSPORTABLE +istransportable(::Type{Gem}) = TRANSPORTABLE +istransportable(x::AbstractObject) = istransportable(typeof(x)) + +(x::Pickup)(a::Agent, o) = x(istransportable(o), a, o) + +function (::Pickup)(::Transportable, a::Agent, o::AbstractObject) + if isnothing(a.inventory) + a.inventory = o + true + elseif a.inventory isa Vector + i = findfirst(isnothing, a.v) + if isnothing(i) + false + else + a.inventory[i] = o + true + end + else + false + end +end + +function (::Drop)(a::Agent) + if isnothing(a.inventory) + nothing + elseif a.inventory isa AbstractObject + x = a.inventory + a.inventory = nothing + x + elseif a.inventory isa Vector + i = findlast(x -> x isa AbstractObject, a.inventory) + if isnothing(i) + nothing + else + x = a.inventory[i] + a.inventory[i] = nothing + x + end + else + @error "unknown inventory type $(a.inventory)" + end +end