-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from findmyway/add_basic
finish core part
- Loading branch information
Showing
9 changed files
with
398 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
/Manifest.toml | ||
|
||
example.tfrecord |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ language: julia | |
notifications: | ||
email: false | ||
julia: | ||
- 1.0 | ||
- 1.3 | ||
- 1.5 | ||
- nightly | ||
os: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,8 +3,21 @@ uuid = "841416d8-1a6a-485a-b0fc-1328d0f53d5e" | |
authors = ["Jun Tian <[email protected]> and contributors"] | ||
version = "0.1.0" | ||
|
||
[deps] | ||
BufferedStreams = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" | ||
CRC32c = "8bf52ea8-c179-5cab-976a-9e18b702a9bc" | ||
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" | ||
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" | ||
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" | ||
ProtoBuf = "3349acd9-ac6a-5e09-bcdb-63829b23a429" | ||
|
||
[compat] | ||
julia = "1" | ||
julia = "1.3" | ||
BufferedStreams = "1.0" | ||
CodecZlib = "0.7" | ||
Glob = "1.3" | ||
MacroTools = "0.5" | ||
ProtoBuf = "0.8" | ||
|
||
[extras] | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,55 @@ | ||
# TFRecord | ||
|
||
[![Build Status](https://travis-ci.com/JuliaReinforcementLearning/TFRecord.jl.svg?branch=master)](https://travis-ci.com/JuliaReinforcementLearning/TFRecord.jl) | ||
|
||
## Usage | ||
|
||
### Install | ||
|
||
```julia | ||
julia> ] add TFRecord | ||
``` | ||
|
||
### Write TFRecord | ||
|
||
```julia | ||
using TFRecord | ||
|
||
writer = TFRecordWriter("example.tfrecord") | ||
|
||
for i in 1:10 | ||
write(writer, Dict( | ||
"feature1" => rand(Bool), | ||
"feature2" => rand(1:5), | ||
"feature3" => rand(("cat", "dog", "chicken", "horse", "goat")), | ||
"feature4" => randn(Float32), | ||
)) | ||
end | ||
|
||
close(writer) | ||
``` | ||
|
||
Here we write `10` observations into the file `example.tfrecord`. Internally, each dictionary is first converted into a `TFRecord.Example`, a protobuf message type defined by TensorFlow. Note that every key must be an `AbstractString` and every value must be one of the following types: | ||
|
||
- `Bool`, `Int64`, `Float32`, `AbstractString` | ||
- `Vector` of the above types | ||
|
||
For custom data types, you need to convert them into a `TFRecord.Example` first. | ||
|
||
### Read TFRecord | ||
|
||
```julia | ||
reader = TFRecordReader("example.tfrecord") | ||
|
||
for example in reader | ||
println(example) | ||
end | ||
``` | ||
|
||
For more fine-grained control, please read the docstrings: | ||
|
||
```julia | ||
julia> ? TFRecordReader | ||
|
||
julia> ? TFRecordWriter | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# Package entry module: pulls in the two source files that make up the package.
module TFRecord | ||
|
||
# The protobuf message definitions are included before core.jl, which refers to them (e.g. `Example`).
include("jlout/example_pb.jl") | ||
include("core.jl") | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
export TFRecordReader, TFRecordWriter | ||
|
||
using CRC32c | ||
using Glob | ||
using Base.Threads | ||
using CodecZlib | ||
using BufferedStreams | ||
using MacroTools: @forward | ||
|
||
# Ref: https://github.com/tensorflow/tensorflow/blob/295ad2781683835be974faba0a191528d8079768/tensorflow/core/lib/hash/crc32c.h#L50-L59 | ||
|
||
# Constant added to the rotated checksum when masking (see the TensorFlow crc32c.h reference).
const MASK_DELTA = 0xa282ead8

"""
    mask(crc::UInt32)

Return the masked form of `crc`: rotate it right by 15 bits, then add `MASK_DELTA`
(the addition wraps around on `UInt32` overflow).
"""
function mask(crc::UInt32)
    rotated = (crc << 17) | (crc >> 15)
    return rotated + MASK_DELTA
end

"""
    unmask(masked_crc::UInt32)

Invert [`mask`](@ref): subtract `MASK_DELTA`, then rotate the result left by 15 bits.
"""
function unmask(masked_crc::UInt32)
    shifted = masked_crc - MASK_DELTA
    return (shifted << 15) | (shifted >> 17)
end
|
"""
    read_record(io::IO)

Read a single record from `io` and return its payload as a `Vector{UInt8}`.

Ref: https://github.com/tensorflow/tensorflow/blob/295ad2781683835be974faba0a191528d8079768/tensorflow/core/lib/io/record_reader.cc#L164-L199

Each record is stored in the following format (integers are little-endian):

```
uint64 n
uint32 masked_crc32_of_n
byte   data[n]
uint32 masked_crc32_of_data
```

An error is raised when the stream ends in the middle of a record or when either
checksum does not match.
"""
function read_record(io::IO)
    # `read(io, nb)` returns *at most* `nb` bytes, so a short read must be detected explicitly.
    header = read(io, sizeof(UInt64))
    length(header) == sizeof(UInt64) ||
        error("record corrupted: truncated length header, did you set the correct compression?")
    masked_crc32_n = ltoh(read(io, UInt32))
    crc32c(header) == unmask(masked_crc32_n) ||
        error("record corrupted, did you set the correct compression?")

    # The length is stored little-endian; `ltoh` is a no-op on little-endian hosts.
    n = ltoh(reinterpret(UInt64, header)[])
    data = read(io, Int(n)) # !!! watch https://github.com/JuliaIO/TranscodingStreams.jl/pull/104
    length(data) == n ||
        error("record corrupted: expected $n bytes of data but only got $(length(data))")
    masked_crc32_data = ltoh(read(io, UInt32))
    crc32c(data) == unmask(masked_crc32_data) ||
        error("record corrupted, did you set the correct compression?")
    return data
end
|
||
##### | ||
# TFRecordReader | ||
##### | ||
|
"""
    TFRecordReader(s; kwargs...)

Iterate over the records stored in the file(s) matching the glob pattern `s`.
Records are produced through an internal `Channel`, which is also the target of the
forwarded `close`, `iterate`, `isopen` and `take!` methods.

# Keyword Arguments

- `compression=nothing`. No compression by default. Optional values are `:zlib` and `:gzip`.
- `bufsize=1024*1024`. Set the buffer size of the internal `BufferedInputStream`. The default value is `1M`. Suggested value is between `1M`~`100M`.
- `channel_size=1000`. The number of pre-fetched elements.

!!! note
    To enable reading records from multiple files concurrently, remember to set the number of threads correctly (See [JULIA_NUM_THREADS](https://docs.julialang.org/en/v1/manual/environment-variables/#JULIA_NUM_THREADS)). Unfortunately, the feature is currently broken. Please watch https://github.com/JuliaIO/ProtoBuf.jl/issues/140 .
"""
struct TFRecordReader{T}
    # Parametrised on `T` so the field type is concrete and element access is inferrable.
    ch::Channel{T}
end
|
||
@forward TFRecordReader.ch Base.close, Base.iterate, Base.isopen, Base.take!

# The number of records is not known up front. Without this trait the default
# `HasLength()` makes `collect(reader)` call an undefined `length(reader)`.
Base.IteratorSize(::Type{<:TFRecordReader}) = Base.SizeUnknown()
Base.eltype(::Type{TFRecordReader{T}}) where {T} = T
|
||
# Convenience constructor yielding raw `Example`s. Accepts any `AbstractString`
# (e.g. a `SubString`), mirroring `TFRecordWriter(s::AbstractString; ...)`.
TFRecordReader(s::AbstractString; kwargs...) = TFRecordReader{Example}(identity, s; kwargs...)
|
||
# TFRecordReader{T}(f, s; compression=nothing, bufsize=1024*1024, channel_size=1000)
#
# Build a reader over every file matching the glob pattern `s`. Each record is parsed
# into an `Example`, passed through `f`, and the result is put into a `Channel{T}`
# holding up to `channel_size` elements.
#
# Throws an `ArgumentError` for an unsupported `compression` and an error when no file
# matches `s`.
function TFRecordReader{T}(f, s; compression=nothing, bufsize=1024*1024, channel_size=1000) where T
    # Fail fast at the call site; otherwise a bad value only surfaces later from
    # inside the channel task, once the first file is opened.
    (isnothing(compression) || compression == :gzip || compression == :zlib) ||
        throw(ArgumentError("unsupported compression method: $compression"))

    files = glob(s)
    length(files) > 0 || error("can not find any files under: $s")
    chnl = Channel{T}(channel_size) do ch
        # watch https://github.com/JuliaIO/ProtoBuf.jl/issues/140
        #= @threads =# for file_name in files
            open(file_name, "r") do file
                io = BufferedInputStream(file, bufsize)
                if compression == :gzip
                    io = GzipDecompressorStream(io)
                elseif compression == :zlib
                    io = ZlibDecompressorStream(io)
                end

                while !eof(io)
                    example = readproto(IOBuffer(read_record(io)), Example())
                    put!(ch, f(example))
                end
            end
        end
    end
    return TFRecordReader{T}(chnl)
end
|
||
##### | ||
# TFRecordWriter | ||
##### | ||
|
||
# Wraps the output stream that records are written to.
struct TFRecordWriter{X<:IO}
    io::X
end

"""
    TFRecordWriter(s; compression=nothing, bufsize=1024*1024)

Open the file `s` for writing and return a writer for it.

Supported `compression` methods are: `:gzip` or `:zlib`.
Default value is `nothing`, which means do not do compression.
Any other value throws an `ArgumentError` before the file is touched.

`bufsize` is used to set the size of buffer used by an internal
`BufferedOutputStream`, the default value is `1M` (1024*1024).
You may want to change it to a larger value when writing large datasets,
for example `100M`.
"""
function TFRecordWriter(s::AbstractString; compression=nothing, bufsize=1024*1024)
    # Validate first: `open(s, "w")` truncates an existing file, and a throw after
    # opening would also leave the handle unclosed.
    (isnothing(compression) || compression == :gzip || compression == :zlib) ||
        throw(ArgumentError("unsupported compression method: $compression"))

    io = BufferedOutputStream(open(s, "w"), bufsize)
    if compression == :gzip
        io = GzipCompressorStream(io)
    elseif compression == :zlib
        io = ZlibCompressorStream(io)
    end
    return TFRecordWriter(io)
end
|
||
"""
    close(w::TFRecordWriter)

Close the stream wrapped by `w`.
"""
function Base.close(w::TFRecordWriter)
    return close(w.io)
end
|
||
"""
    write(w::TFRecordWriter, x)

Convert `x` to an `Example` and write the result with `w`.
"""
function Base.write(w::TFRecordWriter, x)
    example = convert(Example, x)
    return write(w, example)
end
|
||
"""
    write(w::TFRecordWriter, x::Example)

Serialize `x` and append it to `w` as one record:
an 8-byte little-endian length, the masked CRC32C of those length bytes,
the serialized payload, and the masked CRC32C of the payload.

Return the total number of bytes that make up the record.
"""
function Base.write(w::TFRecordWriter, x::Example)
    buff = IOBuffer()
    writeproto(buff, x)
    data = take!(buff)
    data_crc = mask(crc32c(data))

    # The length field is a little-endian uint64. `length(data)` is an `Int`, which
    # would be written as only 4 bytes on a 32-bit host, so widen it explicitly.
    n = htol(UInt64(length(data)))
    len_buff = IOBuffer()
    write(len_buff, n)
    len_bytes = take!(len_buff)
    n_crc = mask(crc32c(len_bytes))

    write(w.io, len_bytes)
    write(w.io, htol(n_crc))
    write(w.io, data)
    write(w.io, htol(data_crc))
    # Report the whole record size rather than the count of the last write only.
    return length(len_bytes) + sizeof(UInt32) + length(data) + sizeof(UInt32)
end
|
||
##### | ||
# convert | ||
##### | ||
|
||
# Copy the code units of `s` into a fresh byte vector. `unsafe_wrap` would alias the
# string's immutable memory and is not available for every `AbstractString`.
_string_bytes(s::AbstractString) = Vector{UInt8}(codeunits(String(s)))

# Scalars are stored as one-element lists. Integers are widened to `Int64` explicitly,
# because `Int` is `Int32` on 32-bit hosts while `Int64List` holds `Int64` values.
Base.convert(::Type{Feature}, x::Int) = Feature(;int64_list=Int64List(value=Int64[x]))
Base.convert(::Type{Feature}, x::Bool) = Feature(;int64_list=Int64List(value=Int64[x]))
Base.convert(::Type{Feature}, x::Float32) = Feature(;float_list=FloatList(value=Float32[x]))
Base.convert(::Type{Feature}, x::AbstractString) = Feature(;bytes_list=BytesList(value=Vector{UInt8}[_string_bytes(x)]))

# Vectors map onto the list type matching their element type.
Base.convert(::Type{Feature}, x::Vector{Int}) = Feature(;int64_list=Int64List(value=convert(Vector{Int64}, x)))
Base.convert(::Type{Feature}, x::Vector{Bool}) = Feature(;int64_list=Int64List(value=convert(Vector{Int64}, x)))
Base.convert(::Type{Feature}, x::Vector{Float32}) = Feature(;float_list=FloatList(value=x))
# The typed comprehension keeps the element type correct even when `x` is empty.
Base.convert(::Type{Feature}, x::Vector{<:AbstractString}) = Feature(;bytes_list=BytesList(value=Vector{UInt8}[_string_bytes(s) for s in x]))
Base.convert(::Type{Feature}, x::Vector{Array{UInt8,1}}) = Feature(;bytes_list=BytesList(value=x))
|
||
"""
    convert(::Type{Features}, x::Dict)

Build a `Features` message whose `feature` map holds every key of `x` paired with
its value converted to a `Feature`.
"""
function Base.convert(::Type{Features}, x::Dict)
    converted = Dict(name => convert(Feature, val) for (name, val) in x)
    return Features(; feature=converted)
end
|
||
"""
    convert(::Type{Example}, x::Dict)

Create an `Example` whose `features` field is `x` converted to `Features`.
"""
function Base.convert(::Type{Example}, x::Dict)
    example = Example()
    example.features = convert(Features, x)
    return example
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# syntax: proto3 | ||
using ProtoBuf | ||
import ProtoBuf.meta | ||
|
||
# Message with a single `value` field holding a vector of byte arrays.
mutable struct BytesList <: ProtoType
    value::Base.Vector{Array{UInt8,1}}
    # Keyword constructor: fields are only populated when keywords are supplied.
    function BytesList(; kwargs...)
        obj = new()
        fillunset(obj)
        if !isempty(kwargs)
            ProtoBuf._protobuild(obj, kwargs)
        end
        return obj
    end
end #mutable struct BytesList
|
||
# Message with a single `value` field holding a vector of `Float32`.
mutable struct FloatList <: ProtoType
    value::Base.Vector{Float32}
    # Keyword constructor: fields are only populated when keywords are supplied.
    function FloatList(; kwargs...)
        obj = new()
        fillunset(obj)
        if !isempty(kwargs)
            ProtoBuf._protobuild(obj, kwargs)
        end
        return obj
    end
end #mutable struct FloatList
# Field names passed to `meta` in the pack position.
const __pack_FloatList = Symbol[:value]
function meta(t::Type{FloatList})
    return meta(
        t,
        ProtoBuf.DEF_REQ,
        ProtoBuf.DEF_FNUM,
        ProtoBuf.DEF_VAL,
        true,
        __pack_FloatList,
        ProtoBuf.DEF_WTYPES,
        ProtoBuf.DEF_ONEOFS,
        ProtoBuf.DEF_ONEOF_NAMES,
        ProtoBuf.DEF_FIELD_TYPES,
    )
end
|
||
# Message with a single `value` field holding a vector of `Int64`.
mutable struct Int64List <: ProtoType
    value::Base.Vector{Int64}
    # Keyword constructor: fields are only populated when keywords are supplied.
    function Int64List(; kwargs...)
        obj = new()
        fillunset(obj)
        if !isempty(kwargs)
            ProtoBuf._protobuild(obj, kwargs)
        end
        return obj
    end
end #mutable struct Int64List
# Field names passed to `meta` in the pack position.
const __pack_Int64List = Symbol[:value]
function meta(t::Type{Int64List})
    return meta(
        t,
        ProtoBuf.DEF_REQ,
        ProtoBuf.DEF_FNUM,
        ProtoBuf.DEF_VAL,
        true,
        __pack_Int64List,
        ProtoBuf.DEF_WTYPES,
        ProtoBuf.DEF_ONEOFS,
        ProtoBuf.DEF_ONEOF_NAMES,
        ProtoBuf.DEF_FIELD_TYPES,
    )
end
|
||
# Message with three list-typed fields, all assigned to the oneof group named `kind`.
mutable struct Feature <: ProtoType
    bytes_list::BytesList
    float_list::FloatList
    int64_list::Int64List
    # Keyword constructor: fields are only populated when keywords are supplied.
    function Feature(; kwargs...)
        obj = new()
        fillunset(obj)
        if !isempty(kwargs)
            ProtoBuf._protobuild(obj, kwargs)
        end
        return obj
    end
end #mutable struct Feature
# One entry per field, each pointing at the first (and only) oneof name below.
const __oneofs_Feature = Int[1,1,1]
const __oneof_names_Feature = [Symbol("kind")]
function meta(t::Type{Feature})
    return meta(
        t,
        ProtoBuf.DEF_REQ,
        ProtoBuf.DEF_FNUM,
        ProtoBuf.DEF_VAL,
        true,
        ProtoBuf.DEF_PACK,
        ProtoBuf.DEF_WTYPES,
        __oneofs_Feature,
        __oneof_names_Feature,
        ProtoBuf.DEF_FIELD_TYPES,
    )
end
|
||
# Map-entry message pairing a string `key` with a `Feature` `value`.
mutable struct Features_FeatureEntry <: ProtoType
    key::AbstractString
    value::Feature
    # Keyword constructor: fields are only populated when keywords are supplied.
    function Features_FeatureEntry(; kwargs...)
        obj = new()
        fillunset(obj)
        if !isempty(kwargs)
            ProtoBuf._protobuild(obj, kwargs)
        end
        return obj
    end
end #mutable struct Features_FeatureEntry (mapentry)
|
||
# Message with a single `feature` field mapping string keys to `Feature` values.
mutable struct Features <: ProtoType
    feature::Base.Dict{AbstractString,Feature} # map entry
    # Keyword constructor: fields are only populated when keywords are supplied.
    function Features(; kwargs...)
        obj = new()
        fillunset(obj)
        if !isempty(kwargs)
            ProtoBuf._protobuild(obj, kwargs)
        end
        return obj
    end
end #mutable struct Features
|
||
# Message with a single `features` field of type `Features`.
mutable struct Example <: ProtoType
    features::Features
    # Keyword constructor: fields are only populated when keywords are supplied.
    function Example(; kwargs...)
        obj = new()
        fillunset(obj)
        if !isempty(kwargs)
            ProtoBuf._protobuild(obj, kwargs)
        end
        return obj
    end
end #mutable struct Example
|
||
# Message with a single `feature` field holding a vector of `Feature`.
mutable struct FeatureList <: ProtoType
    feature::Base.Vector{Feature}
    # Keyword constructor: fields are only populated when keywords are supplied.
    function FeatureList(; kwargs...)
        obj = new()
        fillunset(obj)
        if !isempty(kwargs)
            ProtoBuf._protobuild(obj, kwargs)
        end
        return obj
    end
end #mutable struct FeatureList
|
||
# Map-entry message pairing a string `key` with a `FeatureList` `value`.
mutable struct FeatureLists_FeatureListEntry <: ProtoType
    key::AbstractString
    value::FeatureList
    # Keyword constructor: fields are only populated when keywords are supplied.
    function FeatureLists_FeatureListEntry(; kwargs...)
        obj = new()
        fillunset(obj)
        if !isempty(kwargs)
            ProtoBuf._protobuild(obj, kwargs)
        end
        return obj
    end
end #mutable struct FeatureLists_FeatureListEntry (mapentry)
|
||
# Message with a single `feature_list` field mapping string keys to `FeatureList` values.
mutable struct FeatureLists <: ProtoType
    feature_list::Base.Dict{AbstractString,FeatureList} # map entry
    # Keyword constructor: fields are only populated when keywords are supplied.
    function FeatureLists(; kwargs...)
        obj = new()
        fillunset(obj)
        if !isempty(kwargs)
            ProtoBuf._protobuild(obj, kwargs)
        end
        return obj
    end
end #mutable struct FeatureLists
|
||
# Message combining a `context` (`Features`) with `feature_lists` (`FeatureLists`).
mutable struct SequenceExample <: ProtoType
    context::Features
    feature_lists::FeatureLists
    # Keyword constructor: fields are only populated when keywords are supplied.
    function SequenceExample(; kwargs...)
        obj = new()
        fillunset(obj)
        if !isempty(kwargs)
            ProtoBuf._protobuild(obj, kwargs)
        end
        return obj
    end
end #mutable struct SequenceExample
|
||
export Example, SequenceExample, BytesList, FloatList, Int64List, Feature, Features_FeatureEntry, Features, FeatureList, FeatureLists_FeatureListEntry, FeatureLists | ||
# mapentries: "FeatureLists_FeatureListEntry" => ("AbstractString", "FeatureList"), "Features_FeatureEntry" => ("AbstractString", "Feature") |
Oops, something went wrong.