Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,8 @@ public fun <T, R> Pivot<T>.aggregate(separate: Boolean = false, body: Selector<A
/**
 * Aggregates the groups of this [Grouped] using [body] and returns the result as a [DataFrame].
 *
 * The receiver is cast to [GroupBy] and converted to a data frame; the `groups` column of that
 * frame is selected and passed, together with [body], to `aggregateGroupBy`.
 *
 * @param body aggregation logic forwarded unchanged to `aggregateGroupBy`.
 * @return the frame produced by `aggregateGroupBy`, cast to `DataFrame<T>`.
 */
@Refine
@Interpretable("Aggregate")
public fun <T, R> Grouped<T>.aggregate(body: AggregateGroupedBody<T, R>): DataFrame<T> =
    aggregateGroupBy(
        df = (this as GroupBy<*, *>).toDataFrame(),
        selector = { groups.cast() },
        body = body,
    ).cast()
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import org.jetbrains.kotlinx.dataframe.aggregation.AggregateGroupedBody
import org.jetbrains.kotlinx.dataframe.aggregation.NamedValue
import org.jetbrains.kotlinx.dataframe.api.GroupBy
import org.jetbrains.kotlinx.dataframe.api.GroupedRowFilter
import org.jetbrains.kotlinx.dataframe.api.asFrameColumn
import org.jetbrains.kotlinx.dataframe.api.asGroupBy
import org.jetbrains.kotlinx.dataframe.api.cast
import org.jetbrains.kotlinx.dataframe.api.concat
import org.jetbrains.kotlinx.dataframe.api.convert
import org.jetbrains.kotlinx.dataframe.api.getColumn
Expand All @@ -18,6 +20,7 @@ import org.jetbrains.kotlinx.dataframe.api.isColumnGroup
import org.jetbrains.kotlinx.dataframe.api.pathOf
import org.jetbrains.kotlinx.dataframe.api.remove
import org.jetbrains.kotlinx.dataframe.api.rename
import org.jetbrains.kotlinx.dataframe.api.take
import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
import org.jetbrains.kotlinx.dataframe.impl.aggregation.AggregatableInternal
import org.jetbrains.kotlinx.dataframe.impl.aggregation.GroupByReceiverImpl
Expand All @@ -27,8 +30,10 @@ import org.jetbrains.kotlinx.dataframe.impl.api.GroupedDataRowImpl
import org.jetbrains.kotlinx.dataframe.impl.api.insertImpl
import org.jetbrains.kotlinx.dataframe.impl.api.removeImpl
import org.jetbrains.kotlinx.dataframe.impl.columns.toColumnSet
import org.jetbrains.kotlinx.dataframe.impl.schema.createEmptyDataFrame
import org.jetbrains.kotlinx.dataframe.ncol
import org.jetbrains.kotlinx.dataframe.nrow
import org.jetbrains.kotlinx.dataframe.size
import org.jetbrains.kotlinx.dataframe.values

/**
Expand Down Expand Up @@ -70,18 +75,28 @@ internal class GroupByImpl<T, G>(
internal fun <T, G, R> aggregateGroupBy(
df: DataFrame<T>,
selector: ColumnSelector<T, DataFrame<G>?>,
removeColumns: Boolean,
body: AggregateGroupedBody<G, R>,
): DataFrame<T> {
val defaultAggregateName = "aggregated"

val groupedDfIsEmpty = df.size().nrow == 0
val column = df.getColumn(selector)

val removed = df.removeImpl(columns = selector)

val hasKeyColumns = removed.df.ncol > 0

val groupedFrame = column.values.map {
val groups =
if (groupedDfIsEmpty) {
// if the grouped dataframe is empty, make sure the provided AggregateGroupedBody is called at least once
// to create aggregated columns. We empty them below.
listOf(
column.asFrameColumn().schema.value
.createEmptyDataFrame()
.cast(),
)
} else {
column.values
}

val groupedFrame = groups.map {
if (it == null) {
null
} else {
Expand All @@ -101,17 +116,20 @@ internal fun <T, G, R> aggregateGroupBy(
builder.compute()
}
}.concat()
.let {
// empty the aggregated columns that were created by calling the provided AggregateGroupedBody once
// if the grouped dataframe is empty
if (groupedDfIsEmpty) it.take(0) else it
}

val removedNode = removed.removedColumns.single()
val insertPath = removedNode.pathFromRoot().dropLast(1)

if (!removeColumns) removedNode.data.wasRemoved = false

val columnsToInsert = groupedFrame.getColumnsWithPaths {
colsAtAnyDepth().filter { !it.isColumnGroup() }
}.map {
ColumnToInsert(insertPath + it.path, it, removedNode)
}
val src = if (removeColumns) removed.df else df
return src.insertImpl(columnsToInsert)

return removed.df.insertImpl(columnsToInsert)
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ import org.jetbrains.kotlinx.dataframe.api.fill
import org.jetbrains.kotlinx.dataframe.api.fillNulls
import org.jetbrains.kotlinx.dataframe.api.filter
import org.jetbrains.kotlinx.dataframe.api.first
import org.jetbrains.kotlinx.dataframe.api.firstOrNull
import org.jetbrains.kotlinx.dataframe.api.forEach
import org.jetbrains.kotlinx.dataframe.api.forEachIndexed
import org.jetbrains.kotlinx.dataframe.api.frameColumn
Expand Down Expand Up @@ -94,14 +95,17 @@ import org.jetbrains.kotlinx.dataframe.api.match
import org.jetbrains.kotlinx.dataframe.api.matches
import org.jetbrains.kotlinx.dataframe.api.max
import org.jetbrains.kotlinx.dataframe.api.maxBy
import org.jetbrains.kotlinx.dataframe.api.maxByOrNull
import org.jetbrains.kotlinx.dataframe.api.mean
import org.jetbrains.kotlinx.dataframe.api.meanFor
import org.jetbrains.kotlinx.dataframe.api.meanOf
import org.jetbrains.kotlinx.dataframe.api.median
import org.jetbrains.kotlinx.dataframe.api.medianOrNull
import org.jetbrains.kotlinx.dataframe.api.merge
import org.jetbrains.kotlinx.dataframe.api.min
import org.jetbrains.kotlinx.dataframe.api.minBy
import org.jetbrains.kotlinx.dataframe.api.minOf
import org.jetbrains.kotlinx.dataframe.api.minOrNull
import org.jetbrains.kotlinx.dataframe.api.minus
import org.jetbrains.kotlinx.dataframe.api.move
import org.jetbrains.kotlinx.dataframe.api.moveTo
Expand Down Expand Up @@ -710,6 +714,37 @@ class DataFrameTests : BaseTest() {
res.size() shouldBe 2
}

// Issue #1531
/**
 * Grouping a frame with no rows and aggregating it must still yield the aggregated columns,
 * and the resulting frame must itself contain no rows.
 */
@Test
fun `groupBy empty df should generate empty aggregation cols`() {
    // A frame that keeps the columns of `typed` but has zero rows.
    val empty = typed.take(0)
    val resDf = empty.groupBy { name }.aggregate {
        count() into "n"
        count { age > 25 } into "old count"
        medianOrNull { age } into "median age"
        minOrNull { age } into "min age"
        all { weight != null } into "all with weights"
        maxByOrNull { age }?.city into "oldest origin"
        sortBy { age }.firstOrNull()?.city into "youngest origin"
        // No pivot-derived column appears in the expected names below
        // (presumably because there are no city values to pivot on).
        pivot { city.map { "from $it" } }.count()
        age.toList() into "ages"
    }

    // Call the debug helper before asserting so it still runs when an assertion fails.
    resDf.alsoDebug()

    resDf.columnNames() shouldBe listOf(
        "name",
        "n",
        "old count",
        "median age",
        "min age",
        "all with weights",
        "oldest origin",
        "youngest origin",
        "ages",
    )

    // The aggregated columns must exist but hold no values.
    resDf.rowsCount() shouldBe 0
}

@Test
fun `groupBy`() {
fun AnyFrame.check() {
Expand Down
Loading