diff --git a/src/transforms/map.jl b/src/transforms/map.jl index b06b9a08..ba12915a 100644 --- a/src/transforms/map.jl +++ b/src/transforms/map.jl @@ -48,11 +48,14 @@ Map() = throw(ArgumentError("cannot create Map transform without arguments")) const TargetName = Union{Symbol,AbstractString} const PairWithTarget = Pair{<:Any,<:Pair{<:Function,<:TargetName}} const PairWithoutTarget = Pair{<:Any,<:Function} -const MapPair = Union{PairWithTarget,PairWithoutTarget} +const PairFunctionTarget = Pair{<:Function,<:TargetName} +const MapPair = Union{PairWithTarget,PairWithoutTarget,PairFunctionTarget,Function} # utility functions _extract(p::PairWithTarget) = selector(first(p)), first(last(p)), Symbol(last(last(p))) _extract(p::PairWithoutTarget) = selector(first(p)), last(p), nothing +_extract(p::PairFunctionTarget) = AllSelector(), first(p), Symbol(last(p)) +_extract(p::Function) = AllSelector(), p, nothing function Map(pairs::MapPair...) tuples = map(_extract, pairs) @@ -93,6 +96,10 @@ function applyfeat(transform::Map, feat, prep) mapped = map(selectors, funs, targets) do selector, fun, target snames = selector(names) newname = isnothing(target) ? _makename(snames, fun) : target + if selector isa AllSelector + newcolumn = map(fun, Tables.rows(cols)) + return newname => newcolumn + end scolumns = (Tables.getcolumn(cols, nm) for nm in snames) newcolumn = map(fun, scolumns...) newname => newcolumn diff --git a/test/transforms/map.jl b/test/transforms/map.jl index b19ec36c..797abffc 100644 --- a/test/transforms/map.jl +++ b/test/transforms/map.jl @@ -98,4 +98,33 @@ # error: cannot create Map transform without arguments @test_throws ArgumentError Map() + + # row functions + ## no target + frow = row -> row.a + row.b - row.c + fname = replace(string(frow), "#" => "f") + colname = Symbol(fname, :_a,:_b,:_c,:_d) + T = Map(frow) + n, c = apply(T, t) + @test Tables.schema(n).names == (:a, :b, :c, :d, colname) + @test Tables.getcolumn(n, colname) == frow.(t) + + ## no target with extra functions + T = Map(frow, :a => (a->a) => :A) + n, c = apply(T, t) + Tables.schema(n).names == (:a, :b, :c, :d, colname,:A) + Tables.getcolumn(n, colname) == frow.(t) + + ## target column + T = Map((row -> sum(row)) => :summation) + n, c = apply(T, t) + @test Tables.schema(n).names == (:a, :b, :c, :d, :summation) + @test map(row->sum(row),t) == n.summation + + ## target column with extra function + T = Map((row -> row.a + row.b) => :a_plus_b, :a => (a -> a) => :A) + n, c = apply(T, t) + @test Tables.schema(n).names == (:a, :b, :c, :d, :a_plus_b,:A) + @test map(row->row.a + row.b,t) == n.a_plus_b + end