Skip to content

Commit

Permalink
fix distinct aggregation bug
Browse files Browse the repository at this point in the history
  • Loading branch information
atangwbd committed Apr 4, 2024
1 parent 6847ef9 commit 9e98ae3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 227 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,72 +21,37 @@ class DistinctAggregate(val metricCol: String) extends AggregationSpec {

override def isCalculateAggregateNeeded: Boolean = true


override def calculateAggregate(aggregate: Any, dataType: DataType): Any = {
if (aggregate == null) {
aggregate
} else {
dataType match {
case ArrayType(IntegerType, false) => aggregate
case ArrayType(LongType, false) => aggregate
case ArrayType(DoubleType, false) => aggregate
case ArrayType(FloatType, false) => aggregate
case ArrayType(StringType, false) => aggregate
case _ => throw new RuntimeException(s"Invalid data type for DISTINCT metric col $metricCol. " +
s"Only Array[Int], Array[Long], Array[Double], Array[Float] and Array[String] are supported, but got ${dataType.typeName}")
val result = dataType match {
case IntegerType => aggregate.asInstanceOf[Set[Int]]
case LongType => aggregate.asInstanceOf[Set[Long]]
case DoubleType => aggregate.asInstanceOf[Set[Double]]
case FloatType => aggregate.asInstanceOf[Set[Float]]
case StringType => aggregate.asInstanceOf[Set[String]]
case _ => throw new RuntimeException(s"Invalid data type for COUNT_DISTINCT metric col $metricCol. " +
s"Only Int, Long, Double, Float, and String are supported, but got ${dataType.typeName}")
}
result.mkString(",")
}
}

/*
Record is what we get from SlidingWindowFeatureUtils. Aggregate is what we return here the first time.
The data types of both should match. This is a limitation of Feathr.
*/
override def agg(aggregate: Any, record: Any, dataType: DataType): Any = {
override def agg(aggregate: Any, record: Any, dataType: DataType): Any = {
if (aggregate == null) {
val wrappedArray = record.asInstanceOf[mutable.WrappedArray[Int]]
return ArrayBuffer(wrappedArray: _*)
Set(record)
} else if (record == null) {
aggregate
} else {
dataType match {
case ArrayType(IntegerType, false) =>
val set1 = aggregate.asInstanceOf[mutable.ArrayBuffer[Int]].toSet
val set2 = record.asInstanceOf[mutable.WrappedArray[Int]].toArray.toSet
val set3 = set1.union(set2)
val new_aggregate = ArrayBuffer(set3.toSeq: _*)
return new_aggregate

case ArrayType(LongType, false) =>
val set1 = aggregate.asInstanceOf[mutable.WrappedArray[Long]]toSet
val set2 = record.asInstanceOf[mutable.WrappedArray[Long]]toSet
val set3 = set1.union(set2)
val new_aggregate = ArrayBuffer(set3.toSeq: _*)
return new_aggregate

case ArrayType(DoubleType, false) =>
val set1 = aggregate.asInstanceOf[mutable.WrappedArray[Double]].toSet
val set2 = record.asInstanceOf[mutable.WrappedArray[Double]].toSet
val set3 = set1.union(set2)
val new_aggregate = ArrayBuffer(set3.toSeq: _*)
return new_aggregate

case ArrayType(FloatType, false) =>
val set1 = aggregate.asInstanceOf[mutable.WrappedArray[Float]].toSet
val set2 = record.asInstanceOf[mutable.WrappedArray[Float]].toSet
val set3 = set1.union(set2)
val new_aggregate = ArrayBuffer(set3.toSeq: _*)
return new_aggregate

case ArrayType(StringType, false) =>
val set1 = aggregate.asInstanceOf[mutable.ArrayBuffer[String]].toSet
val set2 = record.asInstanceOf[mutable.WrappedArray[String]].toArray.toSet
val set3 = set1.union(set2)
val new_aggregate = ArrayBuffer(set3.toSeq: _*)
return new_aggregate

case IntegerType => aggregate.asInstanceOf[Set[Int]] + record.asInstanceOf[Int]
case LongType => aggregate.asInstanceOf[Set[Long]] + record.asInstanceOf[Long]
case DoubleType => aggregate.asInstanceOf[Set[Double]] + record.asInstanceOf[Double]
case FloatType => aggregate.asInstanceOf[Set[Float]] + record.asInstanceOf[Float]
case StringType=> aggregate.asInstanceOf[Set[String]] + record.asInstanceOf[String]
case _ => throw new RuntimeException(s"Invalid data type for DISTINCT metric col $metricCol. " +
s"Only Array[Int], Array[Long], Array[Double], Array[Float] and Array[String] are supported, but got ${dataType.typeName}")
s"Only Int, Long, Double, Float, and String are supported, but got ${dataType.typeName}")
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1888,180 +1888,5 @@ class SlidingWindowAggIntegTest extends FeathrIntegTest {

validateRows(dfs.select(keyField, features: _*).collect().sortBy(row => row.getAs[Int](keyField)), expectedRows)
}

/**
 * Integration test for the DISTINCT sliding-window aggregation over the "Id" column.
 *
 * Defines feature `g` (aggregation: DISTINCT, window: 10d) on the `swaSource` time-series source,
 * joins it onto the "featuresWithFilterObs.avro.json" observation data keyed by `x`, and compares
 * the collected rows against `expectedRows`.
 *
 * NOTE(review): the test name says "Integers", but the expected schema declares
 * ArrayType(FloatType) and the expected values are Float literals -- confirm which is intended.
 */
@Test
def testSWADistinctIntegers(): Unit = {
// Feature definition config: one time-series source and one anchor exposing feature `g`.
val featureDefAsString =
"""
|sources: {
| swaSource: {
| location: { path: "generation/daily/" }
| isTimeSeries: true
| timeWindowParameters: {
| timestampColumn: "timestamp"
| timestampColumnFormat: "yyyy-MM-dd"
| }
| }
|}
|anchors: {
| swaAnchorWithKeyExtractor: {
| source: "swaSource"
| key: [x]
| features: {
| g: {
| def: "Id" // the column that contains the raw view count
| aggregation: DISTINCT
| window: 10d
| }
| }
| }
|}
""".stripMargin

val features = Seq("g")
val keyField = "x"
// Join config: requests every entry of `features` for key `keyField`, using the
// observation "timestamp" column (format yyyy-MM-dd) as the join time.
val featureJoinAsString =
s"""
| settings: {
| joinTimeSettings: {
| timestampColumn: {
| def: timestamp
| format: yyyy-MM-dd
| }
| }
|}
|features: [
| {
| key: [$keyField],
| featureList: [${features.mkString(",")}]
| }
|]
""".stripMargin


/**
 * Expected output, as encoded in `expectedRows` below:
 * +---+--------------+
 * | x | g            |
 * +---+--------------+
 * | 1 | [10.0, 11.0] |
 * | 2 | [10.0, 11.0] |
 * | 3 | [9.0]        |
 * +---+--------------+
 */
// NOTE(review): the key field is declared as LongType here, while the expected rows carry Int
// literals and the result is sorted via getAs[Int] -- confirm the intended key type.
val expectedSchema = StructType(
Seq(
StructField(keyField, LongType),
StructField(features.last, ArrayType(FloatType, false))
))
import scala.collection.mutable.WrappedArray
val expectedRows = Array(
new GenericRowWithSchema(Array(1, Array(10.0f, 11.0f)), expectedSchema),
new GenericRowWithSchema(Array(2, Array(10.0f, 11.0f)), expectedSchema),
new GenericRowWithSchema(Array(3, Array(9.0f)), expectedSchema)
)

val dfs = runLocalFeatureJoinForTest(featureJoinAsString, featureDefAsString, "featuresWithFilterObs.avro.json").data
// Sort the collected rows by key so they line up positionally with expectedRows.
val result = dfs.select(keyField, features: _*).collect().sortBy(row => row.getAs[Int](keyField))
// Rebuild each row with the feature column converted from WrappedArray to a plain Array.
val actualRows = result.map(row => Row(row.get(0), row.get(1).asInstanceOf[WrappedArray[Float]].toArray))
validateComplexRows(actualRows, expectedRows)

}

/**
 * Integration test for the DISTINCT sliding-window aggregation over the string column "user".
 *
 * Defines feature `f` (aggregation: DISTINCT, window: 10d) on the `swaSource` time-series source,
 * joins it onto the "featuresWithFilterObs.avro.json" observation data keyed by `x`, and compares
 * the distinct user arrays against `expectedRows`.
 */
@Test
def testSWADistinctStrings(): Unit = {
// Feature definition config: one time-series source and one anchor exposing feature `f`.
val featureDefAsString =
"""
|sources: {
| swaSource: {
| location: { path: "generation/daily/" }
| isTimeSeries: true
| timeWindowParameters: {
| timestampColumn: "timestamp"
| timestampColumnFormat: "yyyy-MM-dd"
| }
| }
|}
|anchors: {
| swaAnchorWithKeyExtractor: {
| source: "swaSource"
| key: [x]
| features: {
| f: {
| def: "user" // the column that contains the user as string
| aggregation: DISTINCT
| window: 10d
| }
| }
| }
|}
""".stripMargin

val features = Seq("f")
val keyField = "x"
// Join config: requests every entry of `features` for key `keyField`, using the
// observation "timestamp" column (format yyyy-MM-dd) as the join time.
val featureJoinAsString =
s"""
| settings: {
| joinTimeSettings: {
| timestampColumn: {
| def: timestamp
| format: yyyy-MM-dd
| }
| }
|}
|features: [
| {
| key: [$keyField],
| featureList: [${features.mkString(",")}]
| }
|]
""".stripMargin


/**
 * Expected output, as encoded in `expectedRows` below:
 * +---+------------------+
 * | x | f                |
 * +---+------------------+
 * | 1 | [user10, user11] |
 * | 2 | [user10, user11] |
 * | 3 | [user9]          |
 * +---+------------------+
 */
val expectedSchema = StructType(
Seq(
StructField(keyField, LongType),
StructField(features.last, ArrayType(StringType, false))
))
import scala.collection.mutable.WrappedArray
val expectedRows = Array(
new GenericRowWithSchema(Array(1, Array("user10", "user11")), expectedSchema),
new GenericRowWithSchema(Array(2, Array("user10", "user11")), expectedSchema),
new GenericRowWithSchema(Array(3, Array("user9")), expectedSchema)
)
val dfs = runLocalFeatureJoinForTest(featureJoinAsString, featureDefAsString, "featuresWithFilterObs.avro.json").data
// Sort the collected rows by key so they line up positionally with expectedRows.
val result = dfs.select(keyField, features: _*).collect().sortBy(row => row.getAs[Int](keyField))

// The feature column is read as a struct whose first field holds the distinct values; unwrap it
// into a plain Array. The key of each rebuilt row is its 1-based position in the sorted result
// (NOTE(review): presumably equal to the real keys 1..3 -- verify against the fixture data).
// Sizing the output from `result` instead of a fixed 3-slot array means an unexpected row count
// is reported by validateComplexRows rather than as an index error or a null entry.
val actualRows = result.zipWithIndex.map { case (row, position) =>
  val featureStruct = row.get(1).asInstanceOf[GenericRowWithSchema]
  val distinctUsers = featureStruct.get(0).asInstanceOf[WrappedArray[String]].toArray
  Row(position + 1, distinctUsers)
}

validateComplexRows(actualRows, expectedRows)

}

}

0 comments on commit 9e98ae3

Please sign in to comment.