Skip to content

Commit

Permalink
Merge pull request #1 from findmyway/add_basic
Browse files Browse the repository at this point in the history
finish core part
  • Loading branch information
findmyway authored Oct 15, 2020
2 parents ae94aa5 + 999c1ab commit 92f72ea
Show file tree
Hide file tree
Showing 9 changed files with 398 additions and 4 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
/Manifest.toml

example.tfrecord
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ language: julia
notifications:
email: false
julia:
- 1.0
- 1.3
- 1.5
- nightly
os:
Expand Down
15 changes: 14 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,21 @@ uuid = "841416d8-1a6a-485a-b0fc-1328d0f53d5e"
authors = ["Jun Tian <[email protected]> and contributors"]
version = "0.1.0"

[deps]
BufferedStreams = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d"
CRC32c = "8bf52ea8-c179-5cab-976a-9e18b702a9bc"
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
ProtoBuf = "3349acd9-ac6a-5e09-bcdb-63829b23a429"

[compat]
julia = "1"
julia = "1.3"
BufferedStreams = "1.0"
CodecZlib = "0.7"
Glob = "1.3"
MacroTools = "0.5"
ProtoBuf = "0.8"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
52 changes: 52 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,55 @@
# TFRecord

[![Build Status](https://travis-ci.com/JuliaReinforcementLearning/TFRecord.jl.svg?branch=master)](https://travis-ci.com/JuliaReinforcementLearning/TFRecord.jl)

## Usage

### Install

```julia
julia> ] add TFRecord
```

### Write TFRecord

```julia
using TFRecord

writer = TFRecordWriter("example.tfrecord")

for i in 1:10
write(writer, Dict(
"feature1" => rand(Bool),
"feature2" => rand(1:5),
"feature3" => rand(("cat", "dog", "chicken", "horse", "goat")),
"feature4" => randn(Float32),
))
end

close(writer)
```

Here we write `10` observations into the file `example.tfrecord`. Internally, each dictionary is first converted into a `TFRecord.Example`, the protobuf message type that TensorFlow uses for records. Note that every key must be an `AbstractString` and every value must be one of the following types:

- `Bool`, `Int64`, `Float32`, `AbstractString`
- `Vector` of the above types

For customized data types, you need to convert them into `TFRecord.Example` first.

### Read TFRecord

```julia
reader = TFRecordReader("example.tfrecord")

for example in reader
println(example)
end
```

For more fine-grained control, please read the doc:

```julia
julia> ? TFRecordReader

julia> ? TFRecordWriter
```
3 changes: 2 additions & 1 deletion src/TFRecord.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Package entry point: the module is assembled from the two included files.
module TFRecord

# Protobuf message types (`Example`, `Feature`, ...) used by the reader/writer.
include("jlout/example_pb.jl")
# Record reader/writer and the `convert` methods; included second because it
# refers to the message types defined by the file above.
include("core.jl")

end
171 changes: 171 additions & 0 deletions src/core.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
export TFRecordReader, TFRecordWriter

using CRC32c
using Glob
using Base.Threads
using CodecZlib
using BufferedStreams
using MacroTools: @forward

# Ref: https://github.com/tensorflow/tensorflow/blob/295ad2781683835be974faba0a191528d8079768/tensorflow/core/lib/hash/crc32c.h#L50-L59

# Offset added to the rotated checksum when masking a CRC.
const MASK_DELTA = 0xa282ead8

"""
    mask(crc::UInt32)

Return the masked form of `crc`: the value rotated right by 15 bits plus
`MASK_DELTA`, using wrapping `UInt32` arithmetic.
"""
function mask(crc::UInt32)
    rotated = (crc << 17) | (crc >> 15)
    return rotated + MASK_DELTA
end

"""
    unmask(masked_crc::UInt32)

Undo `mask`: subtract `MASK_DELTA` (wrapping) and rotate the result left by
15 bits, yielding the original checksum.
"""
function unmask(masked_crc::UInt32)
    shifted = masked_crc - MASK_DELTA
    return (shifted << 15) | (shifted >> 17)
end

"""
    read_record(io::IO)

Read one record from `io` and return its payload as a `Vector{UInt8}`.

Ref: https://github.com/tensorflow/tensorflow/blob/295ad2781683835be974faba0a191528d8079768/tensorflow/core/lib/io/record_reader.cc#L164-L199

Each record is stored in the following format (integers are little-endian):

```
uint64 n
uint32 masked_crc32_of_n
byte   data[n]
uint32 masked_crc32_of_data
```

Both checksums are verified. An `ErrorException` is thrown when the length
header or the payload is shorter than announced (truncated input) or when a
checksum does not match. An `EOFError` from `read` propagates if the stream
ends in the middle of a checksum field.
"""
function read_record(io::IO)
    header = read(io, sizeof(UInt64))
    # `read(io, nb)` returns fewer bytes at end of stream; report that directly
    # instead of failing later inside `reinterpret`.
    length(header) == sizeof(UInt64) ||
        error("record truncated: expected $(sizeof(UInt64)) length bytes, got $(length(header))")
    masked_crc32_n = ltoh(read(io, UInt32))
    crc32c(header) == unmask(masked_crc32_n) ||
        error("record corrupted, did you set the correct compression?")

    # The length is stored little-endian; `ltoh` is a no-op on little-endian hosts.
    n = Int(ltoh(reinterpret(UInt64, header)[1]))
    data = read(io, n) # !!! watch https://github.com/JuliaIO/TranscodingStreams.jl/pull/104
    length(data) == n ||
        error("record truncated: expected $n data bytes, got $(length(data))")
    masked_crc32_data = ltoh(read(io, UInt32))
    crc32c(data) == unmask(masked_crc32_data) ||
        error("record corrupted, did you set the correct compression?")
    return data
end

#####
# TFRecordReader
#####

"""
    TFRecordReader(s; kwargs...)
    TFRecordReader{T}(f, s; kwargs...)

Iterate over the records stored in the file(s) matching the glob pattern `s`.
Records are produced into an internal `Channel{T}`; the reader forwards
`close`, `iterate`, `isopen` and `take!` to that channel. In the second form,
`f` is applied to every decoded `Example` before it is put into the channel.

# Keyword Arguments

- `compression=nothing`. No compression by default. Optional values are `:zlib` and `:gzip`.
- `bufsize=1024*1024`. Set the buffer size of the internal `BufferedInputStream`. The default value is `1M`. Suggested value is between `1M`~`100M`.
- `channel_size=1000`. The number of pre-fetched elements.

!!! note
    To enable reading records from multiple files concurrently, remember to set
    the number of threads correctly (See
    [JULIA_NUM_THREADS](https://docs.julialang.org/en/v1/manual/environment-variables/#JULIA_NUM_THREADS)).
    Unfortunately, the feature is currently broken. Please watch
    https://github.com/JuliaIO/ProtoBuf.jl/issues/140 .
"""
struct TFRecordReader{T}
    # Typed with the element type so the field is concrete for a given `T`.
    ch::Channel{T}
end

# Delegate the listed `Base` functions to the wrapped channel field `ch`, so a
# reader can be closed, iterated, queried with `isopen` and consumed with `take!`.
@forward TFRecordReader.ch Base.close, Base.iterate, Base.isopen, Base.take!

# Convenience constructor: read `Example`s unchanged (`identity`) from the files
# matching `s`. Accepts any `AbstractString` (e.g. a `SubString`), matching
# `TFRecordWriter(s::AbstractString; ...)`.
TFRecordReader(s::AbstractString; kwargs...) = TFRecordReader{Example}(identity, s; kwargs...)

# Build a reader whose channel is filled by a producer task.
#
# Arguments:
# - `f`: applied to every decoded `Example`; its result is what gets `put!` into
#   the `Channel{T}`.
# - `s`: glob pattern selecting the input files.
# - `compression`: `nothing`, `:gzip` or `:zlib`.
# - `bufsize`: buffer size handed to `BufferedInputStream`.
# - `channel_size`: capacity of the channel (number of pre-fetched elements).
#
# Throws `ArgumentError` for an unsupported `compression` and an `ErrorException`
# when `s` matches no file. Both are raised here, at construction time, rather
# than later inside the producer task.
function TFRecordReader{T}(f, s; compression=nothing, bufsize=1024*1024, channel_size=1000) where T
    (isnothing(compression) || compression == :gzip || compression == :zlib) ||
        throw(ArgumentError("unsupported compression method: $compression"))
    files = glob(s)
    length(files) > 0 || error("can not find any files under: $s")
    chnl = Channel{T}(channel_size) do ch
        # watch https://github.com/JuliaIO/ProtoBuf.jl/issues/140
        #= @threads =# for file_name in files
            open(file_name, "r") do file
                stream = BufferedInputStream(file, bufsize)
                if compression == :gzip
                    stream = GzipDecompressorStream(stream)
                elseif compression == :zlib
                    stream = ZlibDecompressorStream(stream)
                end

                try
                    while !eof(stream)
                        example = readproto(IOBuffer(read_record(stream)), Example())
                        put!(ch, f(example))
                    end
                finally
                    # Release the decompressor/buffer state even when reading
                    # fails or the consumer closes the channel early; the file
                    # itself is still owned by the enclosing `open` block.
                    close(stream)
                end
            end
        end
    end
    return TFRecordReader{T}(chnl)
end

#####
# TFRecordWriter
#####

# Writer handle: wraps the output stream that serialized records are written to.
# The stream type is a parameter, so the field is concretely typed.
struct TFRecordWriter{S<:IO}
    io::S
end

"""
    TFRecordWriter(s; compression=nothing, bufsize=1024*1024)

Open the file `s` for writing (truncating it if it exists) and return a writer
for it.

Supported `compression` methods are: `:gzip` or `:zlib`.
Default value is `nothing`, which means do not do compression.
`bufsize` is used to set the size of buffer used by an internal
`BufferedOutputStream`, the default value is `1M` (1024*1024).
You may want to change it to a larger value when writing large datasets,
for example `100M`.

Throws an `ArgumentError` for any other `compression` value. The check happens
before the file is opened, so an invalid call leaves an existing file untouched.
"""
function TFRecordWriter(s::AbstractString; compression=nothing, bufsize=1024*1024)
    # Validate first: `open(s, "w")` truncates the target and would leak the
    # handle if the error were raised afterwards.
    (isnothing(compression) || compression == :gzip || compression == :zlib) ||
        throw(ArgumentError("unsupported compression method: $compression"))
    io = BufferedOutputStream(open(s, "w"), bufsize)
    if compression == :gzip
        io = GzipCompressorStream(io)
    elseif compression == :zlib
        io = ZlibCompressorStream(io)
    end
    return TFRecordWriter(io)
end

# Closing the writer closes the stream it wraps.
function Base.close(w::TFRecordWriter)
    return close(w.io)
end

# Generic fallback: turn `x` into an `Example` with `convert`, then write that
# through the `Example` method.
function Base.write(w::TFRecordWriter, x)
    example = convert(Example, x)
    return write(w, example)
end

"""
    write(w::TFRecordWriter, x::Example)

Serialize `x` with `writeproto` and append it to the writer's stream as one
record:

```
uint64 n
uint32 masked_crc32_of_n
byte   data[n]
uint32 masked_crc32_of_data
```

All integers are written little-endian. Returns the total number of bytes
written for the record.
"""
function Base.write(w::TFRecordWriter, x::Example)
    payload = IOBuffer()
    writeproto(payload, x)
    data = take!(payload)

    # The length field is always a little-endian uint64, independent of the
    # host's `Int` width and byte order.
    lenbuf = IOBuffer()
    write(lenbuf, htol(UInt64(length(data))))
    header = take!(lenbuf)

    n_crc = htol(mask(crc32c(header)))
    data_crc = htol(mask(crc32c(data)))

    nbytes = write(w.io, header)
    nbytes += write(w.io, n_crc)
    nbytes += write(w.io, data)
    nbytes += write(w.io, data_crc)
    return nbytes
end

#####
# convert
#####

# Copy the bytes of `s` into a fresh `Vector{UInt8}`. A copy is used instead of
# `unsafe_wrap`, which aliases the immutable string's memory and only has a
# method for `String`, so `SubString`s and other `AbstractString`s work too.
_feature_bytes(s::AbstractString) = Vector{UInt8}(codeunits(String(s)))

# Scalars become single-element lists. `Bool` is stored as 0/1 in an int64 list,
# strings as their bytes in a bytes list.
Base.convert(::Type{Feature}, x::Int) = Feature(; int64_list=Int64List(value=[x]))
Base.convert(::Type{Feature}, x::Bool) = Feature(; int64_list=Int64List(value=[Int(x)]))
Base.convert(::Type{Feature}, x::Float32) = Feature(; float_list=FloatList(value=[x]))
Base.convert(::Type{Feature}, x::AbstractString) = Feature(; bytes_list=BytesList(value=[_feature_bytes(x)]))

# Vectors map onto the list message of the matching element type.
Base.convert(::Type{Feature}, x::Vector{Int}) = Feature(; int64_list=Int64List(value=x))
Base.convert(::Type{Feature}, x::Vector{Bool}) = Feature(; int64_list=Int64List(value=convert(Vector{Int}, x)))
Base.convert(::Type{Feature}, x::Vector{Float32}) = Feature(; float_list=FloatList(value=x))
Base.convert(::Type{Feature}, x::Vector{<:AbstractString}) = Feature(; bytes_list=BytesList(value=[_feature_bytes(s) for s in x]))
Base.convert(::Type{Feature}, x::Vector{Array{UInt8,1}}) = Feature(; bytes_list=BytesList(value=x))

# Convert every value of `x` to a `Feature` (keys are kept as they are) and wrap
# the resulting dictionary in a `Features` message.
function Base.convert(::Type{Features}, x::Dict)
    entries = Dict(name => convert(Feature, value) for (name, value) in x)
    return Features(; feature=entries)
end

# Build an `Example` whose `features` field is `x` converted to `Features`.
function Base.convert(::Type{Example}, x::Dict)
    example = Example()
    example.features = convert(Features, x)
    return example
end
73 changes: 73 additions & 0 deletions src/jlout/example_pb.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# syntax: proto3
using ProtoBuf
import ProtoBuf.meta

# Message type holding a list of byte arrays in `value`.
# The keyword constructor creates an uninitialized instance, calls `fillunset`
# on it and, when keyword arguments are given, applies them through
# `ProtoBuf._protobuild` before returning the instance.
mutable struct BytesList <: ProtoType
value::Base.Vector{Array{UInt8,1}}
BytesList(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o)
end #mutable struct BytesList

# Message type holding a list of `Float32` values in `value`.
# The keyword constructor creates an uninitialized instance, calls `fillunset`
# on it and, when keyword arguments are given, applies them through
# `ProtoBuf._protobuild` before returning the instance.
mutable struct FloatList <: ProtoType
value::Base.Vector{Float32}
FloatList(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o)
end #mutable struct FloatList
# Field list passed to `meta` in the position where other messages pass
# `ProtoBuf.DEF_PACK` — presumably marks `value` as a packed field; confirm
# against ProtoBuf.jl.
const __pack_FloatList = Symbol[:value]
meta(t::Type{FloatList}) = meta(t, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, true, __pack_FloatList, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES, ProtoBuf.DEF_FIELD_TYPES)

# Message type holding a list of `Int64` values in `value`.
# The keyword constructor creates an uninitialized instance, calls `fillunset`
# on it and, when keyword arguments are given, applies them through
# `ProtoBuf._protobuild` before returning the instance.
mutable struct Int64List <: ProtoType
value::Base.Vector{Int64}
Int64List(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o)
end #mutable struct Int64List
# Field list passed to `meta` in the position where other messages pass
# `ProtoBuf.DEF_PACK` — presumably marks `value` as a packed field; confirm
# against ProtoBuf.jl.
const __pack_Int64List = Symbol[:value]
meta(t::Type{Int64List}) = meta(t, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, true, __pack_Int64List, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES, ProtoBuf.DEF_FIELD_TYPES)

# Message type with three alternative payload fields: a bytes list, a float
# list and an int64 list.
# The keyword constructor creates an uninitialized instance, calls `fillunset`
# on it and, when keyword arguments are given, applies them through
# `ProtoBuf._protobuild` before returning the instance.
mutable struct Feature <: ProtoType
bytes_list::BytesList
float_list::FloatList
int64_list::Int64List
Feature(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o)
end #mutable struct Feature
# One entry per field, each set to 1, together with a single oneof name `kind`:
# presumably all three fields belong to the same oneof group — confirm against
# ProtoBuf.jl's `meta` documentation.
const __oneofs_Feature = Int[1,1,1]
const __oneof_names_Feature = [Symbol("kind")]
meta(t::Type{Feature}) = meta(t, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, true, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, __oneofs_Feature, __oneof_names_Feature, ProtoBuf.DEF_FIELD_TYPES)

# Key/value entry type (string key, `Feature` value) for the `feature` map of
# `Features`.
# The keyword constructor creates an uninitialized instance, calls `fillunset`
# on it and, when keyword arguments are given, applies them through
# `ProtoBuf._protobuild` before returning the instance.
mutable struct Features_FeatureEntry <: ProtoType
key::AbstractString
value::Feature
Features_FeatureEntry(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o)
end #mutable struct Features_FeatureEntry (mapentry)

# Message type mapping feature names to `Feature` values via the `feature`
# dictionary.
# The keyword constructor creates an uninitialized instance, calls `fillunset`
# on it and, when keyword arguments are given, applies them through
# `ProtoBuf._protobuild` before returning the instance.
mutable struct Features <: ProtoType
feature::Base.Dict{AbstractString,Feature} # map entry
Features(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o)
end #mutable struct Features

# Message type wrapping a single `Features` value; this is the type the
# reader decodes records into and the writer serializes.
# The keyword constructor creates an uninitialized instance, calls `fillunset`
# on it and, when keyword arguments are given, applies them through
# `ProtoBuf._protobuild` before returning the instance.
mutable struct Example <: ProtoType
features::Features
Example(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o)
end #mutable struct Example

# Message type holding a vector of `Feature` values in `feature`.
# The keyword constructor creates an uninitialized instance, calls `fillunset`
# on it and, when keyword arguments are given, applies them through
# `ProtoBuf._protobuild` before returning the instance.
mutable struct FeatureList <: ProtoType
feature::Base.Vector{Feature}
FeatureList(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o)
end #mutable struct FeatureList

# Key/value entry type (string key, `FeatureList` value) for the
# `feature_list` map of `FeatureLists`.
# The keyword constructor creates an uninitialized instance, calls `fillunset`
# on it and, when keyword arguments are given, applies them through
# `ProtoBuf._protobuild` before returning the instance.
mutable struct FeatureLists_FeatureListEntry <: ProtoType
key::AbstractString
value::FeatureList
FeatureLists_FeatureListEntry(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o)
end #mutable struct FeatureLists_FeatureListEntry (mapentry)

# Message type mapping names to `FeatureList` values via the `feature_list`
# dictionary.
# The keyword constructor creates an uninitialized instance, calls `fillunset`
# on it and, when keyword arguments are given, applies them through
# `ProtoBuf._protobuild` before returning the instance.
mutable struct FeatureLists <: ProtoType
feature_list::Base.Dict{AbstractString,FeatureList} # map entry
FeatureLists(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o)
end #mutable struct FeatureLists

# Message type combining a `Features` value (`context`) with a `FeatureLists`
# value (`feature_lists`).
# The keyword constructor creates an uninitialized instance, calls `fillunset`
# on it and, when keyword arguments are given, applies them through
# `ProtoBuf._protobuild` before returning the instance.
mutable struct SequenceExample <: ProtoType
context::Features
feature_lists::FeatureLists
SequenceExample(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o)
end #mutable struct SequenceExample

export Example, SequenceExample, BytesList, FloatList, Int64List, Feature, Features_FeatureEntry, Features, FeatureList, FeatureLists_FeatureListEntry, FeatureLists
# mapentries: "FeatureLists_FeatureListEntry" => ("AbstractString", "FeatureList"), "Features_FeatureEntry" => ("AbstractString", "Feature")
Loading

0 comments on commit 92f72ea

Please sign in to comment.