Skip to content

Commit

Permalink
Simplification of icn opt
Browse files Browse the repository at this point in the history
  • Loading branch information
Azzaare committed Jul 17, 2021
1 parent f0f5f11 commit ca47a9a
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 20 deletions.
11 changes: 3 additions & 8 deletions src/icn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,18 +122,13 @@ function _compose(icn::ICN)
end
end

l = length(funcs[1])

composition = (x; X=zeros(length(x), l), param=nothing, dom_size) -> if l == 1
x |> (y -> funcs[1][1](y; param)) |> funcs[3][1] |>
(y -> funcs[4][1](y; param, dom_size, nvars=length(x)))
else
fill!(@view(X[1:length(x), 1:l]), 0.0)
function composition(x; X=zeros(length(x), length(funcs[1])), param=nothing, dom_size)
tr_in(Tuple(funcs[1]), X, x, param)
for i in 1:length(x)
X[i,1] = funcs[2][1](@view X[i,:])
end
funcs[3][1](@view X[:, 1]) |> (y -> funcs[4][1](y; param, dom_size, nvars=length(x)))
funcs[3][1](@view X[:, 1]) |>
(y -> funcs[4][1](y; param, dom_size, nvars=length(x)))
end

return composition, symbols
Expand Down
12 changes: 2 additions & 10 deletions src/learn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,24 +111,16 @@ function compose_to_string(symbols, name)
ag = reduce_symbols(symbols[3], ", ", false; prefix=CN * "ag_")
co = reduce_symbols(symbols[4], ", ", false; prefix=CN * "co_")

return if tr_length == 1
"""
function $name(x; X = zeros(length(x), $tr_length), param=nothing, dom_size)
x |> (y -> $tr[1](y; param)) |> $ag |> (y -> $co(y; param, dom_size, nvars=length(x)))
end
"""
else
"""
output = """
function $name(x; X = zeros(length(x), $tr_length), param=nothing, dom_size)
fill!(@view(X[1:length(x), 1:$tr_length]), 0.0)
$(CN)tr_in(Tuple($tr), X, x, param)
for i in 1:length(x)
X[i,1] = $ar(@view X[i,:])
end
return $ag(@view X[:, 1]) |> (y -> $co(y; param, dom_size, nvars=length(x)))
end
"""
end
return output
end

"""
Expand Down
4 changes: 2 additions & 2 deletions test/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ funcs_param_dom = [
for (f, results) in funcs_param_dom
@info f
for (key, vals) in enumerate(data)
@info "Updated" f(vals.first, param=vals.second[1], dom_size=vals.second[2]) results key
# @info "Updated" f(vals.first, param=vals.second[1], dom_size=vals.second[2]) results key
@test f(vals.first, param=vals.second[1], dom_size=vals.second[2]) results[key]
end
end
Expand All @@ -176,7 +176,7 @@ funcs_dom = [
for (f, results) in funcs_dom
@info f
for (key, vals) in enumerate(data)
@info "Updated" f(vals.first, dom_size=vals.second[2]) results key
# @info "Updated" f(vals.first, dom_size=vals.second[2]) results key
@test f(vals.first, dom_size=vals.second[2]) results[key]
end
end

0 comments on commit ca47a9a

Please sign in to comment.