Skip to content
This repository has been archived by the owner on May 6, 2021. It is now read-only.

Commit

Permalink
Fix SLART type (#224)
Browse files Browse the repository at this point in the history
* Fix SLART type

* Add test case for SLART

* Fix order of CircularArraySLARTTrajectory arguments

* Remove dumb line
  • Loading branch information
ilancoulon authored Apr 1, 2021
1 parent e2c5152 commit 706e534
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/policies/agents/trajectories/abstract_trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ end
const SART = (:state, :action, :reward, :terminal)
const SARTS = (:state, :action, :reward, :terminal, :next_state)
const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action)
const SLART = (:state, :legal_actions_mask, :action, :reward, :terminal, :next_state)
const SLART = (:state, :legal_actions_mask, :action, :reward, :terminal)
const SLARTSL = (
:state,
:legal_actions_mask,
Expand Down
4 changes: 2 additions & 2 deletions src/policies/agents/trajectories/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,16 @@ const CircularArraySLARTTrajectory = Trajectory{
CircularArraySLARTTrajectory(;
capacity::Int,
state = Int => (),
action = Int => (),
legal_actions_mask,
action = Int => (),
reward = Float32 => (),
terminal = Bool => (),
) = merge(
CircularArrayTrajectory(;
capacity = capacity + 1,
state = state,
action = action,
legal_actions_mask = legal_actions_mask,
action = action,
),
CircularArrayTrajectory(; capacity = capacity, reward = reward, terminal = terminal),
)
Expand Down
11 changes: 11 additions & 0 deletions test/components/trajectories.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,17 @@
@test t[:reward] == [2, 3, 4]
end

@testset "CircularArraySLARTTrajectory" begin
t = CircularArraySLARTTrajectory(
capacity = 3,
state = Matrix{Float32} => (2,2),
legal_actions_mask = Vector{Bool} => (4, ),
)

# test instance type is same as type
@test isa(t, CircularArraySLARTTrajectory)
end

@testset "ReservoirTrajectory" begin
# test length
t = ReservoirTrajectory(3; a = Array{Float64,2}, b = Bool)
Expand Down

0 comments on commit 706e534

Please sign in to comment.