From 34d9bb7dbbb33cce286dd25ac4d2150577d55362 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Mon, 26 Apr 2021 09:34:43 +0800 Subject: [PATCH] remove datasets --- Project.toml | 6 - src/TFRecord.jl | 1 - src/datasets/datasets.jl | 92 --------------- src/datasets/rl_unplugged/atari.jl | 130 ---------------------- src/datasets/rl_unplugged/rl_unplugged.jl | 1 - 5 files changed, 230 deletions(-) delete mode 100644 src/datasets/datasets.jl delete mode 100644 src/datasets/rl_unplugged/atari.jl delete mode 100644 src/datasets/rl_unplugged/rl_unplugged.jl diff --git a/Project.toml b/Project.toml index 6f3f5d2..b37e4f7 100644 --- a/Project.toml +++ b/Project.toml @@ -7,21 +7,15 @@ version = "0.2.2" BufferedStreams = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" CRC32c = "8bf52ea8-c179-5cab-976a-9e18b702a9bc" CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" -ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -PNGFiles = "f57f5aa1-a3ce-4bc8-8ab9-96f992907883" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" -ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" ProtoBuf = "3349acd9-ac6a-5e09-bcdb-63829b23a429" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] BufferedStreams = "1.0" CodecZlib = "0.7" -ImageCore = "0.8" MacroTools = "0.5" -PNGFiles = "0.3" -ProgressMeter = "1" ProtoBuf = "0.10" julia = "1.5" diff --git a/src/TFRecord.jl b/src/TFRecord.jl index f6b96f8..66d3dee 100644 --- a/src/TFRecord.jl +++ b/src/TFRecord.jl @@ -2,6 +2,5 @@ module TFRecord include("jlout/example_pb.jl") include("core.jl") -include("datasets/datasets.jl") end diff --git a/src/datasets/datasets.jl b/src/datasets/datasets.jl deleted file mode 100644 index 28278a2..0000000 --- a/src/datasets/datasets.jl +++ /dev/null @@ -1,92 +0,0 @@ -module Datasets - -export buffered_shuffle - -using Random -using Base.Iterators -using ProgressMeter - -##### -# BufferedShuffle -##### - -struct BufferedShuffle{T, R<:AbstractRNG} <: AbstractChannel{T} - src::Channel{T} - buffer::Vector{T} - rng::R -end - -function buffered_shuffle(src::Channel{T}, buffer_size;rng=Random.GLOBAL_RNG) where T - buffer = Array{T}(undef, buffer_size) - p = Progress(buffer_size) - Threads.@threads for i in 1:buffer_size - buffer[i] = take!(src) - next!(p) - end - BufferedShuffle(src, buffer, rng) -end - -Base.close(b::BufferedShuffle) = close(b.src) - -function Base.take!(b::BufferedShuffle) - if length(b.buffer) == 0 - throw(InvalidStateException("buffer is empty", :empty)) - else - i = rand(b.rng, 1:length(b.buffer)) - res = b.buffer[i] - if isopen(b.src) - b.buffer[i] = popfirst!(b.src) - else - deleteat!(b.buffer, i) - end - res - end -end - -function Base.iterate(b::BufferedShuffle, state=nothing) - try - return (popfirst!(b), nothing) - catch e - if isa(e, InvalidStateException) && e.state === :empty - return nothing - else - rethrow() - end - end -end - -##### -# RingBuffer -##### - -mutable struct RingBuffer{T} <: AbstractChannel{T} - buffers::Channel{T} - current::T - results::Channel{T} -end - -Base.close(b::RingBuffer) = close(b.buffers) # will propergate to b.results - -function RingBuffer(f!, buffer::T; sz = Threads.nthreads(), taskref = nothing) where {T} - buffers = Channel{T}(sz) - for _ in 1:sz - put!(buffers, deepcopy(buffer)) - end - results = Channel{T}(sz, spawn = true, taskref = taskref) do ch - Threads.@threads :static for x in buffers - f!(x) - put!(ch, x) - end - end - RingBuffer(buffers, buffer, results) -end - -function Base.take!(b::RingBuffer) - put!(b.buffers, b.current) - b.current = take!(b.results) - b.current -end - -include("rl_unplugged/rl_unplugged.jl") - -end diff --git a/src/datasets/rl_unplugged/atari.jl b/src/datasets/rl_unplugged/atari.jl deleted file mode 100644 index 54b6189..0000000 --- a/src/datasets/rl_unplugged/atari.jl +++ /dev/null @@ -1,130 +0,0 @@ -export atari_dataset - -using Base.Threads -using Printf:@sprintf -using Base.Iterators -using TFRecord -using ImageCore -using PNGFiles - -""" -f = example.features.feature -o.bytes_list.value -""" -struct RLTransition - state - action - reward - terminal - next_state - next_action - episode_id - episode_return -end - -function batch!(dest::RLTransition, src) - for (i, src) in enumerate(src) - batch!(dest, src, i) - end -end - -function batch!(dest::RLTransition, src::RLTransition, i::Int) - for fn in fieldnames(RLTransition) - xs = getfield(dest, fn) - x = getfield(src, fn) - selectdim(xs, ndims(xs), i) .= x - end -end - -function decode_frame(bytes) - bytes |> IOBuffer |> PNGFiles.load |> channelview |> rawview -end - -function decode_state(bytes) - PermutedDimsArray(StackedView((decode_frame(x) for x in bytes)...), (2,3,1)) -end - -function RLTransition(example::TFRecord.Example) - f = example.features.feature - s = decode_state(f["o_t"].bytes_list.value) - s′ = decode_state(f["o_tp1"].bytes_list.value) - a = f["a_t"].int64_list.value[] - a′ = f["a_tp1"].int64_list.value[] - r = f["r_t"].float_list.value[] - t = f["d_t"].float_list.value[] != 1.0 - episode_id = f["episode_id"].int64_list.value[] - episode_return = f["episode_return"].float_list.value[] - RLTransition(s, a, r, t, s′, a′, episode_id, episode_return) -end - -function atari_dataset(; - dir, - game = "Pong", - run = 1, - num_shards = 100, - shuffle_buffer_size = 100_000, - tf_reader_bufsize = 10*1024*1024, - tf_reader_sz = 10_000, - batch_size = 256, - n_preallocations = nthreads() * 8 -) - n = nthreads() - @info "Loading the $run run of atari game ($game) from dir: $dir with $(n) threads" - - files = [ - joinpath(dir, game, @sprintf("run_%i-%05i-of-%05i", run, i, num_shards)) - for i in 0:num_shards-1 - ] - - ch_files = Channel{String}(length(files)) do ch - for f in cycle(files) - put!(ch, f) - end - end - - shuffled_files = buffered_shuffle(ch_files, length(files)) - - ch_src = Channel{RLTransition}(n * tf_reader_sz) do ch - for fs in partition(shuffled_files, n) - Threads.foreach( - TFRecord.read( - RLTransition, - fs; - compression=:gzip, - bufsize=tf_reader_bufsize, - channel_size=tf_reader_sz, - record_type=RLTransition - ); - schedule=Threads.StaticSchedule() - ) do x - put!(ch, x) - end - end - end - - transitions = buffered_shuffle( - ch_src, - shuffle_buffer_size - ) - - buffer = RLTransition( - Array{UInt8, 4}(undef, 84, 84, 4, batch_size), - Array{Int, 1}(undef, batch_size), - Array{Float32, 1}(undef, batch_size), - Array{Bool, 1}(undef, batch_size), - Array{UInt8, 4}(undef, 84, 84, 4, batch_size), - Array{Int, 1}(undef, batch_size), - Array{Int, 1}(undef, batch_size), - Array{Float32, 1}(undef, batch_size), - ) - - taskref = Ref{Task}() - res = RingBuffer(buffer;taskref=taskref, sz=n_preallocations) do buff - Threads.@threads for i in 1:batch_size - batch!(buff, popfirst!(transitions), i) - end - end - bind(ch_src, taskref[]) - bind(ch_files, taskref[]) - res -end \ No newline at end of file diff --git a/src/datasets/rl_unplugged/rl_unplugged.jl b/src/datasets/rl_unplugged/rl_unplugged.jl deleted file mode 100644 index 648710a..0000000 --- a/src/datasets/rl_unplugged/rl_unplugged.jl +++ /dev/null @@ -1 +0,0 @@ -include("atari.jl") \ No newline at end of file