Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for chars and strings of chars #60

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions luasrc/dataset.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ local unpack = unpack or table.unpack
local HDF5DataSet = torch.class("hdf5.HDF5DataSet")

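--[[ Get the sizes and max sizes of an HDF5 dataspace, returning them in Lua tables.
     For datasets whose datatype is STRING, the datatype size is appended to the sizes table (but not to the max sizes table) as one extra trailing dimension. ]]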
local function getDataspaceSize(nDims, spaceID)
local function getDataspaceSize(nDims, spaceID, datasetID)
local size_t = hdf5.ffi.typeof("hsize_t[" .. nDims .. "]")
local dims = size_t()
local maxDims = size_t()
Expand All @@ -19,6 +19,12 @@ local function getDataspaceSize(nDims, spaceID)
size[k] = tonumber(dims[k-1])
maxSize[k] = tonumber(maxDims[k-1])
end

local typeID = hdf5.C.H5Dget_type(datasetID)
if hdf5._datatypeName(typeID) == 'STRING' then
size[nDims+1] = tonumber(hdf5.C.H5Tget_size(typeID))
end

return size, maxSize
end

Expand Down Expand Up @@ -65,7 +71,7 @@ function HDF5DataSet:all()

-- Create a new tensor of the correct type and size
local nDims = hdf5.C.H5Sget_simple_extent_ndims(self._dataspaceID)
local size = getDataspaceSize(nDims, self._dataspaceID)
local size = getDataspaceSize(nDims, self._dataspaceID, self._datasetID)
local factory, nativeType = self:getTensorFactory()

local tensor = factory():resize(unpack(size))
Expand Down Expand Up @@ -177,6 +183,6 @@ end

function HDF5DataSet:dataspaceSize()
local nDims = hdf5.C.H5Sget_simple_extent_ndims(self._dataspaceID)
local size = getDataspaceSize(nDims, self._dataspaceID)
local size = getDataspaceSize(nDims, self._dataspaceID, self._datasetID)
return size
end
7 changes: 6 additions & 1 deletion luasrc/ffi.lua
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ addConstants('h5t', {
'VLEN',
'ARRAY',
'NCLASSES',
'SGN_NONE',
'SGN_2',
}, addH5t)
local function addG(x) return addH5t(x) .. "_g" end

Expand Down Expand Up @@ -310,7 +312,8 @@ function hdf5._getTorchType(typeID)
local size = tonumber(hdf5.C.H5Tget_size(typeID))
if className == 'INTEGER' then
if size == 1 then
return 'torch.ByteTensor'
local signed = hdf5.C.H5Tget_sign(typeID) == hdf5.h5t.SGN_2
return signed and 'torch.CharTensor' or 'torch.ByteTensor'
end
if size == 2 then
return 'torch.ShortTensor'
Expand All @@ -330,6 +333,8 @@ function hdf5._getTorchType(typeID)
return 'torch.DoubleTensor'
end
error("Cannot support reading float data with size = " .. size .. " bytes")
elseif className == 'STRING' then
return 'torch.CharTensor'

else
error("Reading data of class " .. tostring(className) .. "(" .. typeID .. ") is unsupported")
Expand Down
2 changes: 0 additions & 2 deletions tests/testData.lua
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,12 @@ local function intTensorEqual(typename, a, b)
return a:add(-b):apply(function(x) return math.abs(tonumber(x)) end):sum() == 0
end

--[[ Not supported yet
function myTests:testCharTensor()
local k = 0
local testData = torch.CharTensor(4, 6):apply(function() k = k + 1; return k end)
local got = writeAndReread(testData)
tester:assert(intTensorEqual("torch.CharTensor", got, testData), "Data read does not match data written!")
end
]]

function myTests:testByteTensor()
local k = 0
Expand Down