diff --git a/.gitignore b/.gitignore index b067edd..55d0f3f 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ /Manifest.toml + +example.tfrecord \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index f3bd2ac..3570473 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,7 +3,7 @@ language: julia notifications: email: false julia: - - 1.0 + - 1.3 - 1.5 - nightly os: diff --git a/Project.toml b/Project.toml index 3f987e8..3d09c53 100644 --- a/Project.toml +++ b/Project.toml @@ -3,8 +3,21 @@ uuid = "841416d8-1a6a-485a-b0fc-1328d0f53d5e" authors = ["Jun Tian and contributors"] version = "0.1.0" +[deps] +BufferedStreams = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" +CRC32c = "8bf52ea8-c179-5cab-976a-9e18b702a9bc" +CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" +Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" +MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +ProtoBuf = "3349acd9-ac6a-5e09-bcdb-63829b23a429" + [compat] -julia = "1" +julia = "1.3" +BufferedStreams = "1.0" +CodecZlib = "0.7" +Glob = "1.3" +MacroTools = "0.5" +ProtoBuf = "0.8" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/README.md b/README.md index e2d1635..488b09a 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,55 @@ # TFRecord [![Build Status](https://travis-ci.com/JuliaReinforcementLearning/TFRecord.jl.svg?branch=master)](https://travis-ci.com/JuliaReinforcementLearning/TFRecord.jl) + +## Usage + +### Install + +```julia +julia> ] add TFRecord +``` + +### Write TFRecord + +```julia +using TFRecord + +writer = TFRecordWriter("example.tfrecord") + +for i in 1:10 + write(writer, Dict( + "feature1" => rand(Bool), + "feature2" => rand(1:5), + "feature3" => rand(("cat", "dog", "chicken", "horse", "goat")), + "feature4" => randn(Float32), + )) +end + +close(writer) +``` + +Here we write `10` observations into the file `example.tfrecord`. Internally each dictionary is converted into a `TFRecord.Example` first, which is a known prototype by TensorFlow. Note that the type of key must be `AbstractString` and the type of value can be one of the following types: + +- `Bool`, `Int64`, `Float32`, `AbstractString` +- `Vector` of the above types + +For customized data types, you need to convert it into `TFRecord.Example` first. + +### Read TFRecord + +```julia +reader = TFRecordReader("example.tfrecord") + +for example in reader + println(example) +end +``` + +For more fine-grained control, please read the doc: + +```julia +julia> ? TFRecordReader + +julia> ? TFRecordWriter +``` diff --git a/src/TFRecord.jl b/src/TFRecord.jl index 10dceb5..66d3dee 100644 --- a/src/TFRecord.jl +++ b/src/TFRecord.jl @@ -1,5 +1,6 @@ module TFRecord -# Write your package code here. +include("jlout/example_pb.jl") +include("core.jl") end diff --git a/src/core.jl b/src/core.jl new file mode 100644 index 0000000..46f5fc2 --- /dev/null +++ b/src/core.jl @@ -0,0 +1,171 @@ +export TFRecordReader, TFRecordWriter + +using CRC32c +using Glob +using Base.Threads +using CodecZlib +using BufferedStreams +using MacroTools: @forward + +# Ref: https://github.com/tensorflow/tensorflow/blob/295ad2781683835be974faba0a191528d8079768/tensorflow/core/lib/hash/crc32c.h#L50-L59 + +const MASK_DELTA = 0xa282ead8 + +mask(crc::UInt32) = ((crc >> 15) | (crc << 17)) + MASK_DELTA + +function unmask(masked_crc::UInt32) + rot = masked_crc - MASK_DELTA + ((rot >> 17) | (rot << 15)) +end + +""" + +Ref: https://github.com/tensorflow/tensorflow/blob/295ad2781683835be974faba0a191528d8079768/tensorflow/core/lib/io/record_reader.cc#L164-L199 + +Each record is stored in the following format: + +``` +uint64 n +uint32 masked_crc32_of_n +byte data[n] +uint32 masked_crc32_of_data +``` +""" +function read_record(io::IO) + n = read(io, sizeof(UInt64)) + masked_crc32_n = read(io, UInt32) + crc32c(n) == unmask(masked_crc32_n) || error("record corrupted, did you set the correct compression?") + + data = read(io, Int(reinterpret(UInt64, n)[])) # !!! watch https://github.com/JuliaIO/TranscodingStreams.jl/pull/104 + masked_crc32_data = read(io, UInt32) + crc32c(data) == unmask(masked_crc32_data) || error("record corrupted, did you set the correct compression?") + data +end + +##### +# TFRecordReader +##### + +""" + TFRecordReader(s;kwargs...) + +# Keyword Arguments + +- `compression=nothing`. No compression by default. Optional values are `:zlib` and `:gzip`. +- `bufsize=1024*1024`. Set the buffer size of internal `BufferedOutputStream`. The default value is `1M`. Suggested value is between `1M`~`100M`. +- `channel_size=1000`. The number of pre-fetched elements. + +!!!note + + To enable reading records from multiple files concurrently, remember to set the number of threads correctly (See [JULIA_NUM_THREADS](https://docs.julialang.org/en/v1/manual/environment-variables/#JULIA_NUM_THREADS)). Unfortunately, the feature is currently broken. Please watch https://github.com/JuliaIO/ProtoBuf.jl/issues/140 . + +""" +struct TFRecordReader{T} + ch::Channel +end + +@forward TFRecordReader.ch Base.close, Base.iterate, Base.isopen, Base.take! + +TFRecordReader(s::String;kwargs...) = TFRecordReader{Example}(identity, s;kwargs...) + +function TFRecordReader{T}(f, s;compression=nothing,bufsize=1024*1024, channel_size=1000) where T + files = glob(s) + length(files) > 0 || error("can not find any files under: $s") + chnl = Channel{T}(channel_size) do ch + # watch https://github.com/JuliaIO/ProtoBuf.jl/issues/140 + #= @threads =# for file_name in files + open(file_name, "r") do io + + io = BufferedInputStream(io, bufsize) + if compression == :gzip + io = GzipDecompressorStream(io) + elseif compression == :zlib + io = ZlibDecompressorStream(io) + else + isnothing(compression) || throw(ArgumentError("unsupported compression method: $compression")) + end + + while !eof(io) + example = readproto(IOBuffer(read_record(io)), Example()) + put!(ch, f(example)) + end + end + end + end + TFRecordReader{T}(chnl) +end + +##### +# TFRecordWriter +##### + +struct TFRecordWriter{X<:IO} + io::X +end + +""" + TFRecordWriter(s;compression=nothing, bufsize=1024*1024) + +Supported `compression` methods are: `:gzip` or `:zlib`. +Default value is `nothing`, which means do not do compression. +`bufsize` is used to set the size of buffer used by an internal +`BufferedOutputStream`, the default value is `1M` (1024*1024). +You may want to change it to a larger value when writing large datasets, +for example `100M`. +""" +function TFRecordWriter(s::AbstractString;compression=nothing, bufsize=1024*1024) + io = BufferedOutputStream(open(s, "w"), bufsize) + if compression == :gzip + io = GzipCompressorStream(io) + elseif compression == :zlib + io = ZlibCompressorStream(io) + else + isnothing(compression) || throw(ArgumentError("unsupported compression method: $compression")) + end + TFRecordWriter(io) +end + +Base.close(w::TFRecordWriter) = close(w.io) + +Base.write(w::TFRecordWriter, x) = write(w, convert(Example, x)) + +function Base.write(w::TFRecordWriter, x::Example) + buff = IOBuffer() + writeproto(buff, x) + + data_crc = mask(crc32c(seekstart(buff))) + data = take!(seekstart(buff)) + n = length(data) + + buff = IOBuffer() + write(buff, n) + n_crc = mask(crc32c(seekstart(buff))) + + write(w.io, n) + write(w.io, n_crc) + write(w.io, data) + write(w.io, data_crc) +end + +##### +# convert +##### + +Base.convert(::Type{Feature}, x::Int) = Feature(;int64_list=Int64List(value=[x])) +Base.convert(::Type{Feature}, x::Bool) = Feature(;int64_list=Int64List(value=[Int(x)])) +Base.convert(::Type{Feature}, x::Float32) = Feature(;float_list=FloatList(value=[x])) +Base.convert(::Type{Feature}, x::AbstractString) = Feature(;bytes_list=BytesList(value=[unsafe_wrap(Vector{UInt8}, x)])) + +Base.convert(::Type{Feature}, x::Vector{Int}) = Feature(;int64_list=Int64List(value=x)) +Base.convert(::Type{Feature}, x::Vector{Bool}) = Feature(;int64_list=Int64List(value=convert(Vector{Int}, x))) +Base.convert(::Type{Feature}, x::Vector{Float32}) = Feature(;float_list=FloatList(value=x)) +Base.convert(::Type{Feature}, x::Vector{<:AbstractString}) = Feature(;bytes_list=BytesList(value=[unsafe_wrap(Vector{UInt8}, s) for s in x])) +Base.convert(::Type{Feature}, x::Vector{Array{UInt8,1}}) = Feature(;bytes_list=BytesList(value=x)) + +Base.convert(::Type{Features}, x::Dict) = Features(;feature=Dict(k=>convert(Feature, v) for (k, v) in x)) + +function Base.convert(::Type{Example}, x::Dict) + d = Example() + d.features = convert(Features, x) + d +end diff --git a/src/jlout/example_pb.jl b/src/jlout/example_pb.jl new file mode 100644 index 0000000..f599f6b --- /dev/null +++ b/src/jlout/example_pb.jl @@ -0,0 +1,73 @@ +# syntax: proto3 +using ProtoBuf +import ProtoBuf.meta + +mutable struct BytesList <: ProtoType + value::Base.Vector{Array{UInt8,1}} + BytesList(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) +end #mutable struct BytesList + +mutable struct FloatList <: ProtoType + value::Base.Vector{Float32} + FloatList(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) +end #mutable struct FloatList +const __pack_FloatList = Symbol[:value] +meta(t::Type{FloatList}) = meta(t, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, true, __pack_FloatList, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES, ProtoBuf.DEF_FIELD_TYPES) + +mutable struct Int64List <: ProtoType + value::Base.Vector{Int64} + Int64List(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) +end #mutable struct Int64List +const __pack_Int64List = Symbol[:value] +meta(t::Type{Int64List}) = meta(t, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, true, __pack_Int64List, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES, ProtoBuf.DEF_FIELD_TYPES) + +mutable struct Feature <: ProtoType + bytes_list::BytesList + float_list::FloatList + int64_list::Int64List + Feature(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) +end #mutable struct Feature +const __oneofs_Feature = Int[1,1,1] +const __oneof_names_Feature = [Symbol("kind")] +meta(t::Type{Feature}) = meta(t, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, true, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, __oneofs_Feature, __oneof_names_Feature, ProtoBuf.DEF_FIELD_TYPES) + +mutable struct Features_FeatureEntry <: ProtoType + key::AbstractString + value::Feature + Features_FeatureEntry(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) +end #mutable struct Features_FeatureEntry (mapentry) + +mutable struct Features <: ProtoType + feature::Base.Dict{AbstractString,Feature} # map entry + Features(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) +end #mutable struct Features + +mutable struct Example <: ProtoType + features::Features + Example(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) +end #mutable struct Example + +mutable struct FeatureList <: ProtoType + feature::Base.Vector{Feature} + FeatureList(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) +end #mutable struct FeatureList + +mutable struct FeatureLists_FeatureListEntry <: ProtoType + key::AbstractString + value::FeatureList + FeatureLists_FeatureListEntry(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) +end #mutable struct FeatureLists_FeatureListEntry (mapentry) + +mutable struct FeatureLists <: ProtoType + feature_list::Base.Dict{AbstractString,FeatureList} # map entry + FeatureLists(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) +end #mutable struct FeatureLists + +mutable struct SequenceExample <: ProtoType + context::Features + feature_lists::FeatureLists + SequenceExample(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) +end #mutable struct SequenceExample + +export Example, SequenceExample, BytesList, FloatList, Int64List, Feature, Features_FeatureEntry, Features, FeatureList, FeatureLists_FeatureListEntry, FeatureLists +# mapentries: "FeatureLists_FeatureListEntry" => ("AbstractString", "FeatureList"), "Features_FeatureEntry" => ("AbstractString", "Feature") diff --git a/src/proto/example.proto b/src/proto/example.proto new file mode 100644 index 0000000..b770e68 --- /dev/null +++ b/src/proto/example.proto @@ -0,0 +1,55 @@ +// This .proto file is based on those that are in https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/example + +syntax = "proto3"; +option cc_enable_arenas = true; + +message Example +{ + Features features = 1; +} + +message SequenceExample +{ + Features context = 1; + FeatureLists feature_lists = 2; +} + +message BytesList +{ + repeated bytes value = 1; +} + +message FloatList +{ + repeated float value = 1 [packed = true]; +} + +message Int64List +{ + repeated int64 value = 1 [packed = true]; +} + +message Feature +{ + oneof kind + { + BytesList bytes_list = 1; + FloatList float_list = 2; + Int64List int64_list = 3; + } +} + +message Features +{ + map feature = 1; +} + +message FeatureList +{ + repeated Feature feature = 1; +} + +message FeatureLists +{ + map feature_list = 1; +} diff --git a/test/runtests.jl b/test/runtests.jl index 733185f..918f212 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,5 +2,32 @@ using TFRecord using Test @testset "TFRecord.jl" begin - # Write your tests here. + + writer = TFRecordWriter("example.tfrecord") + + n = 10 + f1 = rand(Bool, n) + f2 = rand(1:5, n) + f3 = rand(("cat", "dog", "chicken", "horse", "goat"), n) + f4 = rand(Float32, n) + + for i in 1:n + write(writer, Dict( + "feature1" => f1[i], + "feature2" => f2[i], + "feature3" => f3[i], + "feature4" => f4[i], + )) + end + + close(writer) + + reader = TFRecordReader("example.tfrecord") + + for (i, example) in enumerate(reader) + @test example.features.feature["feature1"].int64_list.value[] == Int(f1[i]) + @test example.features.feature["feature2"].int64_list.value[] == f2[i] + @test String(example.features.feature["feature3"].bytes_list.value[]) == f3[i] + @test example.features.feature["feature4"].float_list.value[] == f4[i] + end end