diff --git a/src/ReinforcementLearningCore/src/components/trajectories/trajectory.jl b/src/ReinforcementLearningCore/src/components/trajectories/trajectory.jl
index b99a6afc4..4531bcce0 100644
--- a/src/ReinforcementLearningCore/src/components/trajectories/trajectory.jl
+++ b/src/ReinforcementLearningCore/src/components/trajectories/trajectory.jl
@@ -185,29 +185,14 @@ isfull(t::CombinedTrajectory) = isfull(t.t1) && isfull(t.t2)
 #####
 
 const VectCompactSATrajectory = CombinedTrajectory{
-    <:SharedTrajectory{
-        <:Vector,
-        <:NamedTuple{(:state, :next_state, :full_state)},
-    },
-    <:SharedTrajectory{
-        <:Vector,
-        <:NamedTuple{(:action, :next_action, :full_action)},
-    }
+    <:SharedTrajectory{<:Vector,<:NamedTuple{(:state, :next_state, :full_state)}},
+    <:SharedTrajectory{<:Vector,<:NamedTuple{(:action, :next_action, :full_action)}},
 }
 
-function VectCompactSATrajectory(;
-    state_type = Int,
-    action_type = Int,
-    )
+function VectCompactSATrajectory(; state_type = Int, action_type = Int)
     CombinedTrajectory(
-        SharedTrajectory(
-            Vector{state_type}(),
-            :state,
-        ),
-        SharedTrajectory(
-            Vector{action_type}(),
-            :action,
-        ),
+        SharedTrajectory(Vector{state_type}(), :state),
+        SharedTrajectory(Vector{action_type}(), :action),
     )
 end
 
@@ -250,14 +235,8 @@ end
 #####
 
 const ElasticCompactSATrajectory = CombinedTrajectory{
-    <:SharedTrajectory{
-        <:ElasticArray,
-        <:NamedTuple{(:state, :next_state, :full_state)},
-    },
-    <:SharedTrajectory{
-        <:ElasticArray,
-        <:NamedTuple{(:action, :next_action, :full_action)},
-    },
+    <:SharedTrajectory{<:ElasticArray,<:NamedTuple{(:state, :next_state, :full_state)}},
+    <:SharedTrajectory{<:ElasticArray,<:NamedTuple{(:action, :next_action, :full_action)}},
 }
 
 function ElasticCompactSATrajectory(;
@@ -267,14 +246,8 @@ function ElasticCompactSATrajectory(;
     action_size = (),
 )
     CombinedTrajectory(
-        SharedTrajectory(
-            ElasticArray{state_type}(undef, state_size..., 0),
-            :state,
-        ),
-        SharedTrajectory(
-            ElasticArray{action_type}(undef, action_size..., 0),
-            :action,
-        ),
+        SharedTrajectory(ElasticArray{state_type}(undef, state_size..., 0), :state),
+        SharedTrajectory(ElasticArray{action_type}(undef, action_size..., 0), :action),
     )
 end
 
@@ -316,25 +289,13 @@ end
 #####
 
 const VectCompactSARTSATrajectory = CombinedTrajectory{
-    <:Trajectory{
-        <:NamedTuple{
-            (:reward, :terminal),
-            <:Tuple{<:Vector,<:Vector},
-        },
-    },
+    <:Trajectory{<:NamedTuple{(:reward, :terminal),<:Tuple{<:Vector,<:Vector}}},
     <:VectCompactSATrajectory,
 }
 
-function VectCompactSARTSATrajectory(;
-    reward_type = Float32,
-    terminal_type = Bool,
-    kw...,
-)
+function VectCompactSARTSATrajectory(; reward_type = Float32, terminal_type = Bool, kw...)
     CombinedTrajectory(
-        Trajectory(
-            reward = Vector{reward_type}(),
-            terminal = Vector{terminal_type}(),
-        ),
+        Trajectory(reward = Vector{reward_type}(), terminal = Vector{terminal_type}()),
         VectCompactSATrajectory(; kw...),
     )
 end
 
@@ -375,12 +336,7 @@ end
 #####
 
 const ElasticCompactSARTSATrajectory = CombinedTrajectory{
-    <:Trajectory{
-        <:NamedTuple{
-            (:reward, :terminal),
-            <:Tuple{<:ElasticArray,<:ElasticArray},
-        },
-    },
+    <:Trajectory{<:NamedTuple{(:reward, :terminal),<:Tuple{<:ElasticArray,<:ElasticArray}}},
     <:ElasticCompactSATrajectory,
 }
 
diff --git a/src/ReinforcementLearningCore/src/extensions/ElasticArrays.jl b/src/ReinforcementLearningCore/src/extensions/ElasticArrays.jl
index 5820b2bed..84f4d9b81 100644
--- a/src/ReinforcementLearningCore/src/extensions/ElasticArrays.jl
+++ b/src/ReinforcementLearningCore/src/extensions/ElasticArrays.jl
@@ -6,6 +6,6 @@ Base.empty!(a::ElasticArray) = ElasticArrays.resize_lastdim!(A, 0)
 function Base.pop!(a::ElasticArray)
     # ??? Is it safe to do so?
     last_frame = selectdim(a, ndims(a), size(a, ndims(a)))
-    ElasticArrays.resize_lastdim!(A, size(a, ndims(a))-1)
+    ElasticArrays.resize_lastdim!(a, size(a, ndims(a)) - 1)
     last_frame
-end
\ No newline at end of file
+end
diff --git a/src/ReinforcementLearningCore/test/components/trajectories.jl b/src/ReinforcementLearningCore/test/components/trajectories.jl
index 844a2b239..8c257a8fa 100644
--- a/src/ReinforcementLearningCore/test/components/trajectories.jl
+++ b/src/ReinforcementLearningCore/test/components/trajectories.jl
@@ -157,24 +157,35 @@ end
 
 @testset "VectCompactSARTSATrajectory" begin
-    t = VectCompactSARTSATrajectory(;state_type=Vector{Float32}, action_type=Int, reward_type=Float32, terminal_type=Bool)
-    push!(t; state=Float32[1,1], action=1)
-    push!(t; reward=1f0, terminal=false, state=Float32[2,2], action=2)
-    push!(t; reward=2f0, terminal=true, state=Float32[3,3], action=3)
-
-    @test t[:state] == [Float32[1,1], Float32[2,2]]
-    @test t[:action] == [1,2]
-    @test t[:reward] == [1f0,2f0]
-    @test t[:terminal] == [false,true]
-    @test t[:next_state] == [Float32[2,2], Float32[3,3]]
-    @test t[:next_action] == [2,3]
+    t = VectCompactSARTSATrajectory(;
+        state_type = Vector{Float32},
+        action_type = Int,
+        reward_type = Float32,
+        terminal_type = Bool,
+    )
+    push!(t; state = Float32[1, 1], action = 1)
+    push!(t; reward = 1f0, terminal = false, state = Float32[2, 2], action = 2)
+    push!(t; reward = 2f0, terminal = true, state = Float32[3, 3], action = 3)
+
+    @test t[:state] == [Float32[1, 1], Float32[2, 2]]
+    @test t[:action] == [1, 2]
+    @test t[:reward] == [1f0, 2f0]
+    @test t[:terminal] == [false, true]
+    @test t[:next_state] == [Float32[2, 2], Float32[3, 3]]
+    @test t[:next_action] == [2, 3]
 end
 
 @testset "ElasticCompactSARTSATrajectory" begin
-    t = ElasticCompactSARTSATrajectory(;state_type=Float32, state_size=(2,), action_type=Int, reward_type=Float32, terminal_type=Bool)
-    push!(t; state=Float32[1,1], action=1)
-    push!(t; reward=1f0, terminal=false, state=Float32[2,2], action=2)
-    push!(t; reward=2f0, terminal=true, state=Float32[3,3], action=3)
+    t = ElasticCompactSARTSATrajectory(;
+        state_type = Float32,
+        state_size = (2,),
+        action_type = Int,
+        reward_type = Float32,
+        terminal_type = Bool,
+    )
+    push!(t; state = Float32[1, 1], action = 1)
+    push!(t; reward = 1f0, terminal = false, state = Float32[2, 2], action = 2)
+    push!(t; reward = 2f0, terminal = true, state = Float32[3, 3], action = 3)
 
     @test t[:state] == Float32[1 2; 1 2]
     @test t[:action] == [1, 2]