diff --git a/src/core.jl b/src/core.jl index 9718cd6..2bddc14 100644 --- a/src/core.jl +++ b/src/core.jl @@ -3,6 +3,7 @@ using Base.Threads using CodecZlib using BufferedStreams using MacroTools: @forward +using ProtoBuf: ProtoType # Ref: https://github.com/tensorflow/tensorflow/blob/295ad2781683835be974faba0a191528d8079768/tensorflow/core/lib/hash/crc32c.h#L50-L59 @@ -62,10 +63,10 @@ read(f, s::String; kwargs...) = read(f, [s]; kw...) function read( f, files::Vector; - compression=nothing, - bufsize=10*1024*1024, - channel_size=1_000, - record_type=Example + compression = nothing, + bufsize = 10 * 1024 * 1024, + channel_size = 1_000, + record_type = Example, ) Channel{record_type}(channel_size) do ch @threads for file_name in files @@ -77,12 +78,13 @@ function read( elseif compression == :zlib io = ZlibDecompressorStream(io) else - isnothing(compression) || throw(ArgumentError("unsupported compression method: $compression")) + isnothing(compression) || + throw(ArgumentError("unsupported compression method: $compression")) end while !eof(io) - example = readproto(IOBuffer(read_record(io)), Example()) - put!(ch, f(example)) + instance = readproto(IOBuffer(read_record(io)), record_type()) + put!(ch, f(instance)) end end end @@ -128,7 +130,7 @@ function write(io::IO, xs) end end -function write(io::IO, x::Example) +function write(io::IO, x::ProtoType) buff = IOBuffer() writeproto(buff, x)