Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into fix-nullable-annota…
Browse files Browse the repository at this point in the history
…tion

* origin/master:
  Added return code checking to header parsing.  Throw error if the operation failed.
  Pass includePath to gcc so it can find include files that might not be on system path. Issue google-deepmind#105.
  Fix undefined behavior by not casting -1 to unsigned int directly
  Remove test for threads as it requires threads to be installed and hdf5 compiled with --enable-threadsafety
  Add torch serialization so hdf5files can be read from multiple threads
  Prevent H5Dget_space from being called twice
  • Loading branch information
iskra-vitaly committed Mar 10, 2019
2 parents 6613b24 + 54e32ab commit 835f2b0
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 9 deletions.
30 changes: 26 additions & 4 deletions doc/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,32 @@ Checking the type of torch.Tensor without loading the data:
local factory = myFile:read('/path/to/data'):getTensorFactory()
myFile:close()

### Reading HDF5 file from multiple threads

If you want to use HDF5 from multiple threads, you will need a thread-safe build of the underlying HDF5 library. Otherwise, you will get random crashes. See the [HDF5 docs](https://support.hdfgroup.org/ftp/HDF5/current18/src/unpacked/release_docs/INSTALL) for how to build a thread-safe version.

If you want to do this from torch you will also need to install torch [threads](https://github.com/torch/threads). Then you can do the following:

-- Open the file once on the main thread; the handle is shared with workers.
local mainfile = hdf5.open('/path/to/read.h5','r')
local nthreads = 2
local data = nil
-- Worker body runs on a pool thread. __threadid is a global injected by the
-- torch threads library; each worker reads dataset "data<threadid>".
local worker = function(h5file)
torch.setnumthreads(1)
print(__threadid)
return h5file:read("data" .. __threadid):all()
end
-- Every pool thread must require torch and hdf5 before touching the handle.
local pool = threads.Threads(nthreads, function(threadid) require'torch' require'hdf5'end)
pool:specific(true)

-- Submit one job per thread; the endcallback stores the worker's result.
for i=1,nthreads do
pool:addjob(i, worker, function(_data) data = _data end, mainfile)
end
-- Drain the queue; each dataset in this example has 10 rows.
for i=1,nthreads do
pool:dojob()
print(data:size(1)==10)
end
mainfile:close()

## Command-line

There are also a number of handy command-line tools.
Expand Down Expand Up @@ -150,7 +176,3 @@ See [this page](http://www.hdfgroup.org/HDF5/doc/RM/Tools.html) for many more HD
## Elsewhere

Libraries for many other languages and tools exist, too. See [this list](http://en.wikipedia.org/wiki/Hierarchical_Data_Format#Interfaces) for more information.

## Thread-safety

If you want to use HDF5 from multiple threads, you will need a thread-safe build of the underlying HDF5 library. Otherwise, you will get random crashes. See the [HDF5 docs](https://www.hdfgroup.org/hdf5-quest.html#tsafe) for how to build a thread-safe version.
4 changes: 2 additions & 2 deletions luasrc/dataset.lua
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ local function createTensorDataspace(tensor)
return dataspaceID
end

--- Initialise a dataset wrapper.
-- @param parent enclosing HDF5File or HDF5Group (must be non-nil)
-- @param datasetID HDF5 dataset handle (must be non-nil)
-- @param dataspaceID optional dataspace handle already obtained by the
--        caller; when supplied, H5Dget_space is not called a second time
--        for the same dataset.
function HDF5DataSet:__init(parent, datasetID, dataspaceID)
    assert(parent)
    assert(datasetID)
    self._parent = parent
    self._datasetID = datasetID
    -- Reuse the caller-provided dataspace if there is one; fetch otherwise.
    self._dataspaceID = dataspaceID or hdf5.C.H5Dget_space(self._datasetID)
    hdf5._logger.debug("Initialising " .. tostring(self))
end

Expand Down
9 changes: 6 additions & 3 deletions luasrc/ffi.lua
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,12 @@ local function loadHDF5Header(includePath)
if headerPath == nil or not path.isfile(headerPath) then
error("Error: unable to locate HDF5 header file at " .. headerPath)
end
local process = io.popen("gcc -D '_Nullable=' -E " .. headerPath) -- TODO pass -I
local process = io.popen("gcc -D '_Nullable=' -E " .. headerPath .. " -I " .. includePath)
local contents = process:read("*all")
process:close()
local success, errorMsg, returnCode = process:close()
if returnCode ~= 0 then
error("Error: unable to parse HDF5 header file at " .. headerPath)
end

-- Strip out the extra junk that GCC returns
local cdef = ""
Expand Down Expand Up @@ -231,7 +234,7 @@ hdf5.H5F_OBJ_LOCAL = 0x0020 -- Restrict search to objects opened through curr

hdf5.H5P_DEFAULT = 0
hdf5.H5S_ALL = 0
-- H5F_UNLIMITED is defined in C as (hsize_t)(-1); cast through the signed
-- hssize_t first so the -1 -> unsigned conversion is well-defined instead
-- of relying on implementation-defined behavior.
hdf5.H5F_UNLIMITED = ffi.new('hsize_t', ffi.cast('hssize_t',-1))
hdf5.H5S_SELECT_SET = 0

-- This table specifies which exact format a given type of Tensor should be saved as.
Expand Down
15 changes: 15 additions & 0 deletions luasrc/file.lua
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@ function HDF5File:__init(filename, fileID)
end
end

--- Torch serialization hook: persist this object's fields as a plain table.
-- Invoked by torch's serializer as object:__write(file), which lets open
-- HDF5File handles be serialized (e.g. passed to worker threads).
-- @param object the HDF5File instance being written
-- @param self the torch.File being written to
function HDF5File.__write(object, self)
    local var = {}
    for k, v in pairs(object) do
        var[k] = v
    end
    -- NOTE(review): the original call also passed a global `hook`, which is
    -- undefined (always nil); dropped here since File:writeObject ignores it.
    self:writeObject(var, torch.typename(object))
end

--- Torch deserialization hook: restore the fields saved by __write.
-- Invoked by torch's serializer as object:__read(file, versionNumber).
-- @param object the freshly allocated HDF5File instance to populate
-- @param self the torch.File being read from
-- @param versionNumber serialization format version (unused)
function HDF5File.__read(object, self, versionNumber)
    local saved = self:readObject()
    for key, value in pairs(saved) do
        object[key] = value
    end
end

function HDF5File:filename()
return self._filename
end
Expand Down
15 changes: 15 additions & 0 deletions luasrc/group.lua
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ function HDF5Group:__init(parent, groupID)
callback:free()
end

--- Torch serialization hook: persist this object's fields as a plain table.
-- Invoked by torch's serializer as object:__write(file); mirrors
-- HDF5File.__write so groups can be serialized alongside files.
-- @param object the HDF5Group instance being written
-- @param self the torch.File being written to
function HDF5Group.__write(object, self)
    local var = {}
    for k, v in pairs(object) do
        var[k] = v
    end
    -- NOTE(review): the original call also passed a global `hook`, which is
    -- undefined (always nil); dropped here since File:writeObject ignores it.
    self:writeObject(var, torch.typename(object))
end

--- Torch deserialization hook: restore the fields saved by __write.
-- Invoked by torch's serializer as object:__read(file, versionNumber).
-- @param object the freshly allocated HDF5Group instance to populate
-- @param self the torch.File being read from
-- @param versionNumber serialization format version (unused)
function HDF5Group.__read(object, self, versionNumber)
    local saved = self:readObject()
    for key, value in pairs(saved) do
        object[key] = value
    end
end

function HDF5Group:__tostring()
return "[HDF5Group " .. self._groupID .. " " .. hdf5._getObjectName(self._groupID) .. "]"
end
Expand Down
40 changes: 40 additions & 0 deletions tests/testSerialization.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
--[[
Test torch serialization.
]]
require 'hdf5'

local totem = require 'totem'
local tester = totem.Tester()
local myTests = {}
local testUtils = hdf5._testUtils

-- Lua 5.2 compatibility
local unpack = unpack or table.unpack

--- Round-trip an open HDF5File handle through torch's binary serializer
-- and verify the restored handle can still read the stored tensor.
function myTests:testSerialization()
    testUtils.withTmpDir(function(tmpDir)
        local filename = path.join(tmpDir, "foo.h5")
        local h5file = hdf5.open(filename)
        local written = torch.zeros(7, 5)
        h5file:write("data", written)

        -- Serialize the live handle into an in-memory buffer.
        -- NOTE(review): h5file is never closed here — presumably intentional
        -- so the restored handle reads from the same open file; confirm.
        local outfile = torch.MemoryFile()
        outfile:binary()
        outfile:writeObject(h5file)
        local buffer = outfile:storage()
        outfile:close()

        -- Deserialize from the buffer and read the tensor back.
        local infile = torch.MemoryFile(buffer)
        infile:binary()
        local restored = infile:readObject()
        infile:close()
        local roundtripped = restored:read("data"):all()

        restored:close()

        -- Every element of the 7x5 zero tensor must match.
        tester:assert(written:eq(roundtripped):sum() == 7*5)
    end)
end

return tester:add(myTests):run()

0 comments on commit 835f2b0

Please sign in to comment.