Skip to content

Commit b88f12b

Browse files
authored
add ElasticArray as container (#121)
1 parent 7c28ede commit b88f12b

File tree

5 files changed

+171
-0
lines changed

5 files changed

+171
-0
lines changed

src/ReinforcementLearningCore/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
99
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
1010
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1111
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
12+
ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4"
1213
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1314
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1415
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"

src/ReinforcementLearningCore/src/components/trajectories/trajectory.jl

+130
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,18 @@ export Trajectory,
33
EpisodicTrajectory,
44
CombinedTrajectory,
55
CircularCompactSATrajectory,
6+
VectCompactSATrajectory,
7+
ElasticCompactSATrajectory,
68
CircularCompactSALTrajectory,
79
CircularCompactSARTSATrajectory,
10+
VectCompactSARTSATrajectory,
11+
ElasticCompactSARTSATrajectory,
812
CircularCompactPSARTSATrajectory,
913
CircularCompactSALRTSALTrajectory,
1014
CircularCompactPSALRTSALTrajectory
1115

1216
using MacroTools: @forward
17+
using ElasticArrays
1318

1419
#####
1520
# Trajectory
@@ -175,6 +180,37 @@ end
175180

176181
isfull(t::CombinedTrajectory) = isfull(t.t1) && isfull(t.t2)
177182

183+
#####
184+
# VectCompactSATrajectory
185+
#####
186+
187+
# Alias for a `CombinedTrajectory` made of two `SharedTrajectory`s that are both
# parameterised by `Vector`: the first carries the names
# `(:state, :next_state, :full_state)`, the second
# `(:action, :next_action, :full_action)`.
const VectCompactSATrajectory = CombinedTrajectory{
    <:SharedTrajectory{<:Vector,<:NamedTuple{(:state, :next_state, :full_state)}},
    <:SharedTrajectory{<:Vector,<:NamedTuple{(:action, :next_action, :full_action)}},
}
197+
198+
# Keyword constructor: combines a `SharedTrajectory` wrapping an empty
# `Vector{state_type}` (named `:state`) with one wrapping an empty
# `Vector{action_type}` (named `:action`).
function VectCompactSATrajectory(; state_type = Int, action_type = Int)
    states = SharedTrajectory(Vector{state_type}(), :state)
    actions = SharedTrajectory(Vector{action_type}(), :action)
    return CombinedTrajectory(states, actions)
end
213+
178214
#####
179215
# CircularCompactSATrajectory
180216
#####
@@ -209,6 +245,40 @@ function CircularCompactSATrajectory(;
209245
)
210246
end
211247

248+
#####
249+
# ElasticCompactSATrajectory
250+
#####
251+
252+
# Alias for a `CombinedTrajectory` made of two `SharedTrajectory`s that are both
# parameterised by `ElasticArray`: the first carries the names
# `(:state, :next_state, :full_state)`, the second
# `(:action, :next_action, :full_action)`.
const ElasticCompactSATrajectory = CombinedTrajectory{
    <:SharedTrajectory{<:ElasticArray,<:NamedTuple{(:state, :next_state, :full_state)}},
    <:SharedTrajectory{
        <:ElasticArray,
        <:NamedTuple{(:action, :next_action, :full_action)},
    },
}
262+
263+
# Keyword constructor: allocates one `ElasticArray` per stream, with dimensions
# `(state_size..., 0)` and `(action_size..., 0)` respectively, wraps them in
# `SharedTrajectory`s named `:state` and `:action`, and combines the two.
function ElasticCompactSATrajectory(;
    state_type = Int,
    state_size = (),
    action_type = Int,
    action_size = (),
)
    state_buffer = ElasticArray{state_type}(undef, state_size..., 0)
    states = SharedTrajectory(state_buffer, :state)

    action_buffer = ElasticArray{action_type}(undef, action_size..., 0)
    actions = SharedTrajectory(action_buffer, :action)

    return CombinedTrajectory(states, actions)
end
280+
281+
212282
#####
213283
# CircularCompactSALTrajectory
214284
#####
@@ -240,6 +310,35 @@ function CircularCompactSALTrajectory(;
240310
CircularCompactSATrajectory(; capacity = capacity, kw...),
241311
)
242312
end
313+
314+
#####
315+
# VectCompactSARTSATrajectory
316+
#####
317+
318+
# Alias for a `CombinedTrajectory` whose first component is a `Trajectory` over a
# `(:reward, :terminal)` named tuple of two `Vector`s and whose second component
# is a `VectCompactSATrajectory`.
const VectCompactSARTSATrajectory = CombinedTrajectory{
    <:Trajectory{<:NamedTuple{(:reward, :terminal),<:Tuple{<:Vector,<:Vector}}},
    <:VectCompactSATrajectory,
}
327+
328+
# Keyword constructor: builds a `Trajectory` with empty `reward` and `terminal`
# vectors of the requested element types and combines it with a
# `VectCompactSATrajectory` that receives all remaining keywords (`kw...`).
function VectCompactSARTSATrajectory(; reward_type = Float32, terminal_type = Bool, kw...)
    reward_terminal = Trajectory(;
        reward = Vector{reward_type}(),
        terminal = Vector{terminal_type}(),
    )
    state_action = VectCompactSATrajectory(; kw...)
    return CombinedTrajectory(reward_terminal, state_action)
end
341+
243342
#####
244343
# CircularCompactSARTSATrajectory
245344
#####
@@ -271,6 +370,37 @@ function CircularCompactSARTSATrajectory(;
271370
)
272371
end
273372

373+
#####
374+
# ElasticCompactSARTSATrajectory
375+
#####
376+
377+
# Alias for a `CombinedTrajectory` whose first component is a `Trajectory` over a
# `(:reward, :terminal)` named tuple of two `ElasticArray`s and whose second
# component is an `ElasticCompactSATrajectory`.
const ElasticCompactSARTSATrajectory = CombinedTrajectory{
    <:Trajectory{
        <:NamedTuple{(:reward, :terminal),<:Tuple{<:ElasticArray,<:ElasticArray}},
    },
    <:ElasticCompactSATrajectory,
}
386+
387+
# Keyword constructor: allocates `ElasticArray`s with dimensions
# `(reward_size..., 0)` and `(terminal_size..., 0)` for the `reward` and
# `terminal` streams, and combines that `Trajectory` with an
# `ElasticCompactSATrajectory` that receives all remaining keywords (`kw...`).
function ElasticCompactSARTSATrajectory(;
    reward_type = Float32,
    reward_size = (),
    terminal_type = Bool,
    terminal_size = (),
    kw...,
)
    reward_buffer = ElasticArray{reward_type}(undef, reward_size..., 0)
    terminal_buffer = ElasticArray{terminal_type}(undef, terminal_size..., 0)
    reward_terminal = Trajectory(; reward = reward_buffer, terminal = terminal_buffer)

    state_action = ElasticCompactSATrajectory(; kw...)
    return CombinedTrajectory(reward_terminal, state_action)
end
402+
403+
274404
#####
275405
# CircularCompactSALRTSALTrajectory
276406
#####
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using ElasticArrays
2+
3+
# `push!` on an `ElasticArray` simply forwards to `append!(a, x)` and returns
# whatever `append!` returns.
function Base.push!(a::ElasticArray, x)
    return append!(a, x)
end
4+
# Drop every frame of `a` by resizing its last dimension to length zero.
# The argument passed on must be the parameter `a`; the previous spelling `A`
# referred to an undefined name and raised `UndefVarError` on every call.
Base.empty!(a::ElasticArray) = ElasticArrays.resize_lastdim!(a, 0)
5+
6+
# Remove the final slice along the last dimension of `a` and return it.
#
# The returned frame is an independent copy: `selectdim` only produces a view
# into `a`, and that view would point past the end of the array once the last
# dimension has been shrunk, so it must be materialised before resizing.
function Base.pop!(a::ElasticArray)
    last_dim = ndims(a)
    n_frames = size(a, last_dim)
    last_frame = copy(selectdim(a, last_dim, n_frames))
    # Shrink `a` itself (the parameter, not an unrelated global `A`).
    ElasticArrays.resize_lastdim!(a, n_frames - 1)
    return last_frame
end

src/ReinforcementLearningCore/src/extensions/extensions.jl

+1
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ include("Flux.jl")
22
include("CUDA.jl")
33
include("Zygote.jl")
44
include("ReinforcementLearningBase.jl")
5+
include("ElasticArrays.jl")

src/ReinforcementLearningCore/test/components/trajectories.jl

+28
Original file line numberDiff line numberDiff line change
@@ -155,4 +155,32 @@
155155
@test t[:full_state] == []
156156
@test t[:full_action] == []
157157
end
158+
159+
@testset "VectCompactSARTSATrajectory" begin
    traj = VectCompactSARTSATrajectory(;
        state_type = Vector{Float32},
        action_type = Int,
        reward_type = Float32,
        terminal_type = Bool,
    )

    # The first push carries only a state/action pair; the next two also carry
    # the reward and terminal flag.
    push!(traj; state = Float32[1, 1], action = 1)
    push!(traj; reward = 1.0f0, terminal = false, state = Float32[2, 2], action = 2)
    push!(traj; reward = 2.0f0, terminal = true, state = Float32[3, 3], action = 3)

    @test traj[:state] == [Float32[1, 1], Float32[2, 2]]
    @test traj[:action] == [1, 2]
    @test traj[:reward] == [1.0f0, 2.0f0]
    @test traj[:terminal] == [false, true]
    @test traj[:next_state] == [Float32[2, 2], Float32[3, 3]]
    @test traj[:next_action] == [2, 3]
end
172+
173+
@testset "ElasticCompactSARTSATrajectory" begin
    traj = ElasticCompactSARTSATrajectory(;
        state_type = Float32,
        state_size = (2,),
        action_type = Int,
        reward_type = Float32,
        terminal_type = Bool,
    )

    # The first push carries only a state/action pair; the next two also carry
    # the reward and terminal flag.
    push!(traj; state = Float32[1, 1], action = 1)
    push!(traj; reward = 1.0f0, terminal = false, state = Float32[2, 2], action = 2)
    push!(traj; reward = 2.0f0, terminal = true, state = Float32[3, 3], action = 3)

    # States are compared against 2-row matrices, one column per step.
    @test traj[:state] == Float32[1 2; 1 2]
    @test traj[:action] == [1, 2]
    @test traj[:reward] == [1.0f0, 2.0f0]
    @test traj[:terminal] == [false, true]
    @test traj[:next_state] == Float32[2 3; 2 3]
    @test traj[:next_action] == [2, 3]
end
158186
end

0 commit comments

Comments
 (0)