forked from torch/torch7
-
Notifications
You must be signed in to change notification settings - Fork 0
/
init.lua
96 lines (81 loc) · 2.29 KB
/
init.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
-- Make this work with LuaJIT in Lua 5.2 compatibility mode, where
-- string.gfind (already deprecated in 5.1) is not available under that
-- name: fall back to string.gmatch whenever gfind is missing.
string.gfind = string.gfind or string.gmatch

-- libtorch is loaded with paths.require (not plain require) to appease mkl.
require "paths"
paths.require "libtorch"
--- package stuff
--- Locate the directory that holds a package's Lua entry file.
--
-- Every template in `package.path` is expanded with `name` and probed with
-- `io.open`; the directory part of the first file that opens is returned.
-- When `name` is omitted, the lookup is done for 'torch' and the parent of
-- that directory is returned instead.
--
-- @param name package name substituted for every '?' in the templates
--             (optional)
-- @return the directory path without its trailing separator, or nil when no
--         matching file exists or the match has no directory part
function torch.packageLuaPath(name)
   -- Strip the last path component. '/' and '\\' (Windows) are both treated
   -- as separators, so a path mixing the two is cut at its last separator.
   local function parentdir(path)
      return string.match(path, "(.*)[/\\]")
   end

   if not name then
      -- Search once and reuse the result; a failed lookup yields nil, just
      -- like the named case, instead of an error raised by string.match.
      local torchdir = torch.packageLuaPath('torch')
      return torchdir and parentdir(torchdir)
   end

   -- A function replacement keeps any '%' in `name` literal; a replacement
   -- string would interpret it as an escape character.
   local function expand()
      return name
   end

   for template in string.gmatch(package.path, "[^;]+") do
      local candidate = string.gsub(template, "%?", expand)
      local f = io.open(candidate)
      if f then
         -- The handle is only an existence probe.
         f:close()
         return parentdir(candidate)
      end
   end
end
--- Global helper: run `file` through paths.dofile.
--
-- @param file  file name handed to paths.dofile
-- @param depth optional number added to the base value 3 that is passed as
--              the second argument of paths.dofile (presumably a stack
--              level used to resolve `file` relative to the caller; confirm
--              against the paths package)
function include(file, depth)
   local extra = depth or 0
   paths.dofile(file, 3 + extra)
end
--- Run a Lua file that lives in a package's directory.
--
-- The directory comes from torch.packageLuaPath(package); `file` is appended
-- to it with a '/' separator and the resulting path is executed with dofile.
--
-- @param package package name handed to torch.packageLuaPath (note that the
--                parameter shadows the global `package` library inside this
--                function)
-- @param file    file name relative to the package directory
-- Raises an error when the package directory cannot be determined.
function torch.include(package, file)
   local dir = torch.packageLuaPath(package)
   if not dir then
      -- Fail with a readable message instead of a nil-concatenation error.
      error(string.format("torch.include: cannot locate package <%s>",
                          tostring(package)), 2)
   end
   dofile(dir .. '/' .. file)
end
--- Declare a torch class named `tname`, optionally derived from `parenttname`.
--
-- Two local builders are handed to torch.newmetatable:
--   * factory:     creates a fresh table and passes it to
--                  torch.setmetatable together with `tname`;
--   * constructor: does the same and then calls the object's `__init`
--                  method, when one is found, with the caller's arguments.
--
-- @param tname       name of the class to register
-- @param parenttname name of the parent class (optional)
-- @return the value returned by torch.newmetatable, followed by
--         torch.getmetatable(parenttname), or nil when no parent was given
function torch.class(tname, parenttname)
   local function factory()
      local obj = {}
      torch.setmetatable(obj, tname)
      return obj
   end

   local function constructor(...)
      local obj = factory()
      if obj.__init then
         obj:__init(...)
      end
      return obj
   end

   local classmt = torch.newmetatable(tname, parenttname, constructor, nil, factory)
   if not parenttname then
      return classmt, nil
   end
   -- Parentheses keep exactly one value from the lookup.
   return classmt, (torch.getmetatable(parenttname))
end
--- Select the type used for torch.Tensor and torch.Storage.
--
-- torch.Tensor becomes the constructor table registered under `typename`.
-- torch.Storage becomes the constructor table registered under the type
-- name of the storage of `torch.Tensor(1)`.
--
-- @param typename string naming a torch object type, e.g.
--                 'torch.DoubleTensor'
-- Raises an error when `typename` is not a string or when no constructor
-- table is registered for it.
function torch.setdefaulttensortype(typename)
   assert(type(typename) == 'string', 'string expected')

   -- Reject unknown type names before touching torch.Tensor/torch.Storage.
   if not torch.getconstructortable(typename) then
      error(string.format("<%s> is not a string describing a torch object", typename))
   end

   torch.Tensor = torch.getconstructortable(typename)
   -- The storage type is read off an instance of the new default tensor.
   local sample = torch.Tensor(1)
   torch.Storage = torch.getconstructortable(torch.typename(sample:storage()))
end
--- Type name of a value, with torch.typename taking precedence.
--
-- @param obj any value
-- @return the result of torch.typename(obj) when it is truthy, otherwise the
--         plain Lua type name given by type(obj)
function torch.type(obj)
   return torch.typename(obj) or type(obj)
end
-- Start with 'torch.DoubleTensor' as the default tensor type.
torch.setdefaulttensortype('torch.DoubleTensor')

-- Load the remaining Lua sources of the package, in this order. The loop
-- runs directly in the chunk, so include() is called from the same stack
-- frame as a sequence of individual calls would be.
for _, file in ipairs{
   'Tensor.lua',
   'File.lua',
   'CmdLine.lua',
   'FFI.lua',
   'Tester.lua',
   'test.lua',
} do
   include(file)
end

return torch