Skip to content

Commit ca47a9a

Browse files
committed
Simplification of icn opt
1 parent f0f5f11 commit ca47a9a

File tree

3 files changed

+7
-20
lines changed

3 files changed

+7
-20
lines changed

src/icn.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -122,18 +122,13 @@ function _compose(icn::ICN)
122122
end
123123
end
124124

125-
l = length(funcs[1])
126-
127-
composition = (x; X=zeros(length(x), l), param=nothing, dom_size) -> if l == 1
128-
x |> (y -> funcs[1][1](y; param)) |> funcs[3][1] |>
129-
(y -> funcs[4][1](y; param, dom_size, nvars=length(x)))
130-
else
131-
fill!(@view(X[1:length(x), 1:l]), 0.0)
125+
function composition(x; X=zeros(length(x), length(funcs[1])), param=nothing, dom_size)
132126
tr_in(Tuple(funcs[1]), X, x, param)
133127
for i in 1:length(x)
134128
X[i,1] = funcs[2][1](@view X[i,:])
135129
end
136-
funcs[3][1](@view X[:, 1]) |> (y -> funcs[4][1](y; param, dom_size, nvars=length(x)))
130+
funcs[3][1](@view X[:, 1]) |>
131+
(y -> funcs[4][1](y; param, dom_size, nvars=length(x)))
137132
end
138133

139134
return composition, symbols

src/learn.jl

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -111,24 +111,16 @@ function compose_to_string(symbols, name)
111111
ag = reduce_symbols(symbols[3], ", ", false; prefix=CN * "ag_")
112112
co = reduce_symbols(symbols[4], ", ", false; prefix=CN * "co_")
113113

114-
return if tr_length == 1
115-
"""
116-
function $name(x; X = zeros(length(x), $tr_length), param=nothing, dom_size)
117-
x |> (y -> $tr[1](y; param)) |> $ag |> (y -> $co(y; param, dom_size, nvars=length(x)))
118-
end
119-
"""
120-
else
121-
"""
114+
output = """
122115
function $name(x; X = zeros(length(x), $tr_length), param=nothing, dom_size)
123-
fill!(@view(X[1:length(x), 1:$tr_length]), 0.0)
124116
$(CN)tr_in(Tuple($tr), X, x, param)
125117
for i in 1:length(x)
126118
X[i,1] = $ar(@view X[i,:])
127119
end
128120
return $ag(@view X[:, 1]) |> (y -> $co(y; param, dom_size, nvars=length(x)))
129121
end
130122
"""
131-
end
123+
return output
132124
end
133125

134126
"""

test/layers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ funcs_param_dom = [
164164
for (f, results) in funcs_param_dom
165165
@info f
166166
for (key, vals) in enumerate(data)
167-
@info "Updated" f(vals.first, param=vals.second[1], dom_size=vals.second[2]) results key
167+
# @info "Updated" f(vals.first, param=vals.second[1], dom_size=vals.second[2]) results key
168168
@test f(vals.first, param=vals.second[1], dom_size=vals.second[2]) results[key]
169169
end
170170
end
@@ -176,7 +176,7 @@ funcs_dom = [
176176
for (f, results) in funcs_dom
177177
@info f
178178
for (key, vals) in enumerate(data)
179-
@info "Updated" f(vals.first, dom_size=vals.second[2]) results key
179+
# @info "Updated" f(vals.first, dom_size=vals.second[2]) results key
180180
@test f(vals.first, dom_size=vals.second[2]) results[key]
181181
end
182182
end

0 commit comments

Comments
 (0)