Skip to content

Commit 4c935e7

Browse files
Simplify Experiment code after dropping RLExperiment (#1044)
1 parent 5bde0ac commit 4c935e7

File tree

3 files changed

+35
-59
lines changed

3 files changed

+35
-59
lines changed

src/ReinforcementLearningCore/Project.toml

-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1414
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1515
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
1616
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
17-
Parsers = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
1817
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1918
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2019
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -37,7 +36,6 @@ Flux = "0.14"
3736
Functors = "0.1, 0.2, 0.3, 0.4"
3837
GPUArrays = "8, 9, 10"
3938
Metal = "1.0"
40-
Parsers = "2"
4139
ProgressMeter = "1"
4240
Reexport = "1"
4341
ReinforcementLearningBase = "0.12"

src/ReinforcementLearningCore/src/core/run.jl

+7-57
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,13 @@
1-
export @E_cmd, Experiment
1+
export Experiment
22

3-
4-
import Parsers
5-
6-
macro E_cmd(s)
7-
Experiment(s)
8-
end
9-
10-
function try_parse(s, TS=(Bool, Int, Float32, Float64))
11-
if s == "nothing"
12-
nothing
13-
else
14-
for T in TS
15-
res = Parsers.tryparse(T, s)
16-
if !isnothing(res)
17-
return res
18-
end
19-
end
20-
s
21-
end
22-
end
23-
24-
function try_parse_kw(s)
25-
kw = []
26-
# !!! obviously, it's not correct when a value is string and contains ","
27-
for part in split(s, ",")
28-
kv = split(part, "=")
29-
@assert length(kv) == 2
30-
k, v = kv
31-
push!(kw, Symbol(strip(k)) => try_parse(strip(v)))
32-
end
33-
NamedTuple(kw)
34-
end
35-
36-
struct Experiment{S}
37-
policy::Any
38-
env::Any
39-
stop_condition::Any
40-
hook::Any
41-
end
42-
43-
Experiment(args...) = Experiment{Symbol()}(args...)
44-
45-
function Experiment(s::String)
46-
m = match(r"(?<source>\w+)_(?<method>\w+)_(?<env>\w+)(\((?<game>.*)\))?", s)
47-
isnothing(m) && throw(
48-
ArgumentError(
49-
"invalid format, got $s, expected format is JuliaRL_DQN_Atari(game=\"pong\")`",
50-
),
51-
)
52-
source = m[:source]
53-
method = m[:method]
54-
env = m[:env]
55-
kw_args = isnothing(m[:game]) ? (;) : try_parse_kw(m[:game])
56-
ex = Experiment(Val(Symbol(source)), Val(Symbol(method)), Val(Symbol(env)); kw_args...)
57-
Experiment{Symbol(s)}(ex.policy, ex.env, ex.stop_condition, ex.hook)
3+
struct Experiment
4+
policy::AbstractPolicy
5+
env::AbstractEnv
6+
stop_condition::AbstractStopCondition
7+
hook::AbstractHook
588
end
599

60-
Base.show(io::IO, m::MIME"text/plain", t::Experiment{S}) where {S} = show(io, m, convert(AnnotatedStructTree, t; description=string(S)))
10+
Base.show(io::IO, m::MIME"text/plain", t::Experiment) = show(io, m, convert(AnnotatedStructTree, t))
6111

6212
function Base.run(ex::Experiment)
6313
run(ex.policy, ex.env, ex.stop_condition, ex.hook)

src/ReinforcementLearningCore/test/core/base.jl

+28
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,32 @@ using TimerOutputs
5555
run(agent, env, stop_condition, hook)
5656
@test RLCore.timer isa TimerOutputs.TimerOutput
5757
end
58+
59+
@testset "Experiment" begin
60+
# Create an instance of Experiment
61+
policy = Agent(
62+
RandomPolicy(),
63+
Trajectory(
64+
CircularArraySARTSTraces(; capacity = 1_000),
65+
BatchSampler(1),
66+
InsertSampleRatioController(n_inserted = -1),
67+
),
68+
)
69+
env = RandomWalk1D()
70+
stop_condition = StopAfterEpisode(10)
71+
hook = StepsPerEpisode()
72+
73+
exp = Experiment(policy, env, stop_condition, hook)
74+
75+
# Test that the fields are correctly assigned
76+
@test exp.policy === policy
77+
@test exp.env === env
78+
@test exp.stop_condition === stop_condition
79+
@test exp.hook === hook
80+
81+
# Test that the Experiment is callable
82+
run(exp)
83+
@test length(hook[]) == 10
84+
end
85+
5886
end

0 commit comments

Comments
 (0)