Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add batch environment #146

Draft
wants to merge 28 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e7f4a8a
add StatsBase package
Sid-Bhatia-0 Jun 15, 2021
fd659a7
add SingleRoomUndirectedBatch
Sid-Bhatia-0 Jun 15, 2021
0b8c1cb
add playability to SingleRoomUndirectedBatch
Sid-Bhatia-0 Jun 16, 2021
1bb2641
add keyword force to RLBase.reset! method
Sid-Bhatia-0 Jun 16, 2021
0d115af
fix replay method
Sid-Bhatia-0 Jun 16, 2021
d39fae7
write characters to terminal out while playing
Sid-Bhatia-0 Jun 16, 2021
20d103f
ignore scratchpad.jl
Sid-Bhatia-0 Jun 16, 2021
4c16b20
add tests for SingleRoomUndirectedBatch
Sid-Bhatia-0 Jun 16, 2021
342deb0
set state style to internal state, copy reward & done
Sid-Bhatia-0 Jun 17, 2021
4926c99
add benchmark_multi_threaded.jl
Sid-Bhatia-0 Jun 17, 2021
7824551
print NUM_ENVS
Sid-Bhatia-0 Jun 17, 2021
d711bcc
move num_envs to the last dimension
Sid-Bhatia-0 Jun 21, 2021
2d68a0a
update tests for batch envs
Sid-Bhatia-0 Jun 21, 2021
3dbd497
don't copy tile_map, reward, and done in RLBase API
Sid-Bhatia-0 Jun 21, 2021
8ec31a0
remove unnecessary RLBase.DefaultPlayer
Sid-Bhatia-0 Jun 21, 2021
73406ea
rename benchmark_multi_threaded.jl to benchmark_batch.jl
Sid-Bhatia-0 Jun 21, 2021
45bf86a
fix and cleanup benchmark_batch
Sid-Bhatia-0 Jun 21, 2021
ed2b37b
make move function type stable (huge improvement in performance)
Sid-Bhatia-0 Jun 21, 2021
2ceedad
add function sample_two_positions_without_replacement
Sid-Bhatia-0 Jun 21, 2021
2426955
add DataStructures package in benchmarking code
Sid-Bhatia-0 Jun 24, 2021
cda573e
add ACTION_NAMES in ModuleSingleRoomUndirectedBatch
Sid-Bhatia-0 Jun 24, 2021
26d1231
refactor benchmark_batch.jl
Sid-Bhatia-0 Jun 24, 2021
bb048cc
rename benchmark_batch.jl to benchmark_utils.jl
Sid-Bhatia-0 Jun 24, 2021
e8854e2
add SingleRoomUndirected
Sid-Bhatia-0 Jun 24, 2021
9f3b1bd
ignore generated benchmark files
Sid-Bhatia-0 Jun 24, 2021
b730958
add benchmarking for non-batch envs
Sid-Bhatia-0 Jun 24, 2021
e2f313f
make SingleRoomUndirected mutable and improve performance
Sid-Bhatia-0 Jun 24, 2021
6d114c6
remove constants NUM_RESETS, STEPS_PER_EPISODE, NUM_ENVS
Sid-Bhatia-0 Jun 24, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ Manifest.toml
# vim temporary files
*~
*.swp

/src/scratchpad.jl

/benchmark/20*
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
Crayons = "4.0"
Expand Down
1 change: 1 addition & 0 deletions benchmark/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
GridWorlds = "e15a9946-cd7f-4d03-83e2-6c30bacb0043"
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
Expand Down
219 changes: 219 additions & 0 deletions benchmark/benchmark_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
import BenchmarkTools as BT
import DataStructures as DS
import Dates
import GridWorlds as GW
import ReinforcementLearningBase as RLBase
import Statistics

ENVS = [GW.ModuleSingleRoomUndirected.SingleRoomUndirected]
BATCH_ENVS = [GW.ModuleSingleRoomUndirectedBatch.SingleRoomUndirectedBatch]

function run_random_policy_env!(env, num_resets, steps_per_reset)
for _ in 1:num_resets
RLBase.reset!(env)
for _ in 1:steps_per_reset
state = RLBase.state(env)
action = rand(RLBase.action_space(env))
env(action)
is_terminated = RLBase.is_terminated(env)
reward = RLBase.reward(env)
end
end

return nothing
end

function run_random_policy_batch_env!(env, num_resets, steps_per_reset)
num_envs = size(env.tile_map, 4)
action = Array{eltype(RLBase.action_space(env))}(undef, num_envs)
for _ in 1:num_resets
RLBase.reset!(env, force = true)
for _ in 1:steps_per_reset
state = RLBase.state(env)
for i in 1:num_envs
action[i] = rand(RLBase.action_space(env))
end
env(action)
is_terminated = RLBase.is_terminated(env)
reward = RLBase.reward(env)
end
end

return nothing
end

# function compile_envs(Envs, num_resets, steps_per_reset)
# for Env in Envs
# env = Env()
# run_random_policy!(env, num_resets, steps_per_reset)
# end

# @info "Compiled and ran all environments"

# return nothing
# end

function benchmark_env(Env, num_resets, steps_per_reset)
benchmark = DS.OrderedDict()

parent_module = parentmodule(Env)

env = Env()

benchmark[:random_policy] = BT.@benchmark run_random_policy_env!($(Ref(env))[], $(Ref(num_resets))[], $(Ref(steps_per_reset))[])
benchmark[:reset] = BT.@benchmark RLBase.reset!($(Ref(env))[])
benchmark[:state] = BT.@benchmark RLBase.state($(Ref(env))[])

for action in RLBase.action_space(env)
action_name = parent_module.ACTION_NAMES[action]
benchmark[action_name] = BT.@benchmark $(Ref(env))[]($(Ref(action))[])
end

benchmark[:action_space] = BT.@benchmark RLBase.action_space($(Ref(env))[])
benchmark[:is_terminated] = BT.@benchmark RLBase.is_terminated($(Ref(env))[])
benchmark[:reward] = BT.@benchmark RLBase.reward($(Ref(env))[])

@info "$(nameof(Env)) benchmarked"

return benchmark
end

function benchmark_batch_env(Env, num_resets, steps_per_reset, num_envs)
benchmark = DS.OrderedDict()

parent_module = parentmodule(Env)

env = Env(num_envs = num_envs)

benchmark[:random_policy] = BT.@benchmark run_random_policy_batch_env!($(Ref(env))[], $(Ref(num_resets))[], $(Ref(steps_per_reset))[])
benchmark[:reset] = BT.@benchmark RLBase.reset!($(Ref(env))[], force = true)
benchmark[:state] = BT.@benchmark RLBase.state($(Ref(env))[])

for action in RLBase.action_space(env)
action_name = parent_module.ACTION_NAMES[action]
batch_action = fill(action, num_envs)
benchmark[action_name] = BT.@benchmark $(Ref(env))[]($(Ref(batch_action))[])
end

benchmark[:action_space] = BT.@benchmark RLBase.action_space($(Ref(env))[])
benchmark[:is_terminated] = BT.@benchmark RLBase.is_terminated($(Ref(env))[])
benchmark[:reward] = BT.@benchmark RLBase.reward($(Ref(env))[])

@info "$(nameof(Env)) benchmarked"

return benchmark
end

function benchmark_envs(Envs, num_resets, steps_per_reset)
benchmarks = DS.OrderedDict()

for Env in Envs
benchmarks[nameof(Env)] = benchmark_env(Env, num_resets, steps_per_reset)
end

@info "benchmark_envs complete"

return benchmarks
end

function benchmark_batch_envs(Envs, num_resets, steps_per_reset, num_envs)
benchmarks = DS.OrderedDict()

for Env in Envs
benchmarks[nameof(Env)] = benchmark_batch_env(Env, num_resets, steps_per_reset, num_envs)
end

@info "benchmark_batch_envs complete"

return benchmarks
end

function get_summary(trial::BT.Trial)
median_trial = BT.median(trial)
memory = BT.prettymemory(median_trial.memory)
median_time = BT.prettytime(median_trial.time)
return memory, median_time
end

function get_table(benchmark)
title = "|"
separator = "|"
data = "|"

for key in keys(benchmark)
title = title * String(key) * "|"
separator = separator * ":---:|"
memory, median_time = get_summary(benchmark[key])
data = data * "$(memory)<br>$(median_time)|"
end

return title, separator, data
end

function generate_benchmark_file(benchmarks; file_name = nothing)
date = Dates.format(Dates.now(), "yyyy_mm_dd_HH_MM_SS")

if isnothing(file_name)
file_name = date * ".md"
end

io = open(file_name, "w")

println(io, "Date: $(date)")
println(io, "## List of Environments")

for key in keys(benchmarks)
name_string = String(key)
println(io, " 1. [$(name_string)](#$(lowercase(name_string)))")
end

println(io)

for key in keys(benchmarks)
println(io, "### " * String(key))
title, separator, data = get_table(benchmarks[key])
println(io, title)
println(io, separator)
println(io, data)
println(io)
end

close(io)

return nothing
end

# function generate_benchmark_file_batch_envs(Envs, num_resets, steps_per_reset, num_envs; file_name = nothing)
# date = Dates.format(Dates.now(), "yyyy_mm_dd_HH_MM_SS")

# if isnothing(file_name)
# file_name = date * ".md"
# end

# io = open(file_name, "w")

# benchmarks = benchmark_batch_envs(Envs, num_resets, steps_per_reset, num_envs)

# println(io, "Date: $(date)")
# println(io, "## List of Environments")

# for Env in Envs
# name_string = String(nameof(Env))
# println(io, " 1. [$(name_string)](#$(lowercase(name_string)))")
# end

# println(io)

# for key in keys(benchmarks)
# println(io, "### " * String(key))
# title, separator, data = get_table(benchmarks[key])
# println(io, title)
# println(io, separator)
# println(io, data)
# println(io)
# end

# close(io)

# return nothing
# end
1 change: 1 addition & 0 deletions src/GridWorlds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ include("actions.jl")
include("objects.jl")
include("grid_world_base.jl")
include("abstract_grid_world.jl")
include("play.jl")
include("envs/envs.jl")
include("textual_rendering.jl")

Expand Down
2 changes: 2 additions & 0 deletions src/envs/envs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,5 @@ include("snake.jl")
include("catcher.jl")
include("transport.jl")
include("collect_gems_undirected_multi_agent.jl")
include("single_room_undirected_batch.jl")
include("single_room_undirected.jl")
Loading