Skip to content

Commit

Permalink
Fix revert of logratio transforms (#296)
Browse files Browse the repository at this point in the history
  • Loading branch information
juliohm authored Jan 22, 2025
1 parent 8758da8 commit 1ab1fe9
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 30 deletions.
48 changes: 29 additions & 19 deletions src/transforms/logratio.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,56 +23,66 @@ assertions(::LogRatio) = [scitypeassert(Continuous)]

function applyfeat(transform::LogRatio, feat, prep)
cols = Tables.columns(feat)
onames = Tables.columnnames(cols)
varnames = collect(onames)
names = Tables.columnnames(cols)
vars = collect(names)

# perform closure for full revertibility
cfeat, ccache = apply(Closure(), feat)

# reference variable
rvar = refvar(transform, varnames)
_assert(rvar ∈ varnames, "invalid reference variable")
rind = findfirst(==(rvar), varnames)
rvar = refvar(transform, vars)
_assert(rvar ∈ vars, "invalid reference variable")

# reference index
rind = findfirst(==(rvar), vars)

# permute columns if necessary
perm = rind ≠ lastindex(varnames)
perm = rind ≠ lastindex(vars)
pfeat = if perm
popat!(varnames, rind)
push!(varnames, rvar)
feat |> Select(varnames)
popat!(vars, rind)
push!(vars, rvar)
cfeat |> Select(vars)
else
feat
cfeat
end

# apply transform
X = Tables.matrix(pfeat)
Y = applymatrix(transform, X)

# new variable names
newnames = newvars(transform, varnames)
newnames = newvars(transform, vars)

# return same table type
𝒯 = (; zip(newnames, eachcol(Y))...)
newfeat = 𝒯 |> Tables.materializer(feat)

newfeat, (rind, perm, onames)
newfeat, (ccache, perm, rind, vars)
end

function revertfeat(transform::LogRatio, newfeat, fcache)
# retrieve cache
ccache, perm, rind, vars = fcache

# revert transform
Y = Tables.matrix(newfeat)
X = revertmatrix(transform, Y)

# retrieve cache
rind, perm, onames = fcache
pfeat = (; zip(vars, eachcol(X))...)

# revert the permutation if necessary
if perm
n = length(onames)
cfeat = if perm
n = length(vars)
inds = collect(1:(n - 1))
insert!(inds, rind, n)
X = X[:, inds]
pfeat |> Select(inds)
else
pfeat
end

# revert closure for full revertibility
𝒯 = revert(Closure(), cfeat, ccache)

# return same table type
𝒯 = (; zip(onames, eachcol(X))...)
𝒯 |> Tables.materializer(newfeat)
end

Expand Down
36 changes: 25 additions & 11 deletions test/transforms/logratio.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,22 @@
n, c = apply(T, t)
@test Tables.schema(n).names == (:ARL1, :ARL2)
@test n == t |> ALR(:c)
talr = revert(T, n, c)
r = revert(T, n, c)
@test Tables.matrix(r) ≈ Tables.matrix(t)

T = CLR()
n, c = apply(T, t)
@test Tables.schema(n).names == (:CLR1, :CLR2, :CLR3)
tclr = revert(T, n, c)
r = revert(T, n, c)
@test Tables.matrix(r) ≈ Tables.matrix(t)

T = ILR()
n, c = apply(T, t)
@test Tables.schema(n).names == (:ILR1, :ILR2)
@test n == t |> ILR(:c)
tilr = revert(T, n, c)
@test Tables.matrix(talr) ≈ Tables.matrix(tclr)
@test Tables.matrix(tclr) ≈ Tables.matrix(tilr)
@test Tables.matrix(talr) ≈ Tables.matrix(tilr)
r = revert(T, n, c)
@test Tables.matrix(r) ≈ Tables.matrix(t)

# permute columns
a = [1.0, 0.0, 1.0]
b = [2.0, 2.0, 2.0]
c = [3.0, 3.0, 0.0]
Expand All @@ -35,10 +36,23 @@

T = ALR(:c)
n1, c1 = apply(T, t1)
r1 = revert(T, n1, c1)
n2, c2 = apply(T, t2)
r2 = revert(T, n2, c2)
@test n1 == n2
@test Tables.matrix(r1) ≈ Tables.matrix(t1)
@test Tables.schema(r1).names == (:a, :c, :b)
@test Tables.matrix(r2) ≈ Tables.matrix(t2)
@test Tables.schema(r2).names == (:c, :a, :b)

T = ILR(:c)
n1, c1 = apply(T, t1)
r1 = revert(T, n1, c1)
n2, c2 = apply(T, t2)
r2 = revert(T, n2, c2)
@test n1 == n2
tₒ = revert(T, n1, c1)
@test Tables.schema(tₒ).names == (:a, :c, :b)
tₒ = revert(T, n2, c2)
@test Tables.schema(tₒ).names == (:c, :a, :b)
@test Tables.matrix(r1) ≈ Tables.matrix(t1)
@test Tables.schema(r1).names == (:a, :c, :b)
@test Tables.matrix(r2) ≈ Tables.matrix(t2)
@test Tables.schema(r2).names == (:c, :a, :b)
end

0 comments on commit 1ab1fe9

Please sign in to comment.