Skip to content

Commit facd997

Browse files
authored
Merge pull request #235 from JuliaML/sample
Refactor `Sample` implementation: avoid `Tables.rowtable`
2 parents a48458a + c21ed50 commit facd997

File tree

1 file changed

+20
-19
lines changed

1 file changed

+20
-19
lines changed

src/transforms/sample.jl

+20-19
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ isrevertible(::Type{<:Sample}) = true
4444

4545
function preprocess(transform::Sample, feat)
4646
# retrieve valid indices
47-
rows = Tables.rowtable(feat)
48-
inds = 1:length(rows)
47+
inds = 1:_nrows(feat)
4948

5049
size = transform.size
5150
weights = transform.weights
@@ -65,36 +64,38 @@ function preprocess(transform::Sample, feat)
6564
end
6665

6766
function applyfeat(::Sample, feat, prep)
68-
# collect all rows
69-
rows = Tables.rowtable(feat)
70-
7167
# preprocessed indices
7268
sinds, rinds = prep
7369

74-
# select rows
75-
srows = view(rows, sinds)
76-
rrows = view(rows, rinds)
70+
# selected and removed rows
71+
srows = Tables.subset(feat, sinds)
72+
rrows = Tables.subset(feat, rinds)
7773

7874
newfeat = srows |> Tables.materializer(feat)
79-
8075
newfeat, (sinds, rinds, rrows)
8176
end
8277

8378
function revertfeat(::Sample, newfeat, fcache)
84-
# collect all rows
85-
rows = Tables.rowtable(newfeat)
86-
79+
cols = Tables.columns(newfeat)
80+
names = Tables.columnnames(cols)
8781
sinds, rinds, rrows = fcache
8882

89-
uinds = sort(unique(sinds))
90-
urows = map(uinds) do i
91-
j = findfirst(==(i), sinds)
92-
rows[j]
83+
# columns with selected rows in original order
84+
uinds = indexin(sort(unique(sinds)), sinds)
85+
columns = map(names) do name
86+
y = Tables.getcolumn(cols, name)
87+
[y[i] for i in uinds]
9388
end
9489

95-
for (i, row) in zip(rinds, rrows)
96-
insert!(urows, i, row)
90+
# insert removed rows into columns
91+
rrcols = Tables.columns(rrows)
92+
for (name, x) in zip(names, columns)
93+
r = Tables.getcolumn(rrcols, name)
94+
for (i, v) in zip(rinds, r)
95+
insert!(x, i, v)
96+
end
9797
end
9898

99-
urows |> Tables.materializer(newfeat)
99+
𝒯 = (; zip(names, columns)...)
100+
𝒯 |> Tables.materializer(newfeat)
100101
end

0 commit comments

Comments
 (0)