Skip to content

Commit

Permalink
Cleaned up Code
Browse files Browse the repository at this point in the history
  • Loading branch information
vamshiwmd committed Jul 19, 2023
1 parent e86767f commit e167dbe
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ private[offline] object SlidingWindowFeatureUtils {
case AggregationType.BUCKETED_COUNT_DISTINCT => new DummyAggregate(featureDef)
case AggregationType.BUCKETED_SUM => new DummyAggregate(featureDef)
case AggregationType.DISTINCT =>
// val rewrittenDef = s"CASE WHEN ${featureDef} IS NOT NULL THEN array(${featureDef}) ELSE NULL END "
val rewrittenDef = s"CASE WHEN ${featureDef} IS NOT NULL THEN array(${featureDef}) ELSE NULL END "
new DistinctAggregate(rewrittenDef)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,10 @@ class DistinctAggregate(val metricCol: String) extends AggregationSpec {


override def calculateAggregate(aggregate: Any, dataType: DataType): Any = {
//print("dataType" + dataType)
println("dataType class: "+ dataType.getClass)
//println("aggregate class: "+ aggregate.getClass)
if (aggregate == null) {
aggregate
} else {
dataType match {
// case IntegerType => aggregate.asInstanceOf[Set[Int]]
// case LongType => aggregate.asInstanceOf[Set[Long]]
// case DoubleType => aggregate.asInstanceOf[Set[Double]]
// case FloatType => aggregate.asInstanceOf[Set[Float]]
// case StringType => aggregate.asInstanceOf[Set[String]]
case ArrayType(IntegerType, false) => aggregate
case ArrayType(LongType, false) => aggregate
case ArrayType(DoubleType, false) => aggregate
Expand All @@ -46,90 +38,50 @@ class DistinctAggregate(val metricCol: String) extends AggregationSpec {
}
}

// record is what we get from SlidingWindowFeatureUtils. Aggregate data type is what we return here for first time
/*
Record is what we get from SlidingWindowFeatureUtils. Aggregate is what we return here the first time.
The data types of both should match. This is a limitation of Feathr.
*/
override def agg(aggregate: Any, record: Any, dataType: DataType): Any = {
//print("dataType" + dataType)
if (dataType != null) {
println("dataType class: "+ dataType.getClass)
}
if(aggregate != null) {
println("aggregate class: "+ aggregate.getClass)
aggregate match {
case set: Set[_] =>
set.foreach(println)
case wrappedArray: scala.collection.mutable.WrappedArray[_] =>
wrappedArray.foreach(println)
case _ =>
print(" ")
}
}
if(record != null) {
println("record class class: "+ record.getClass)
println("record " + record)
record match {
case wrappedArray: scala.collection.mutable.WrappedArray[_] =>
wrappedArray.foreach(println)
case _ =>
print(" ")
}
}
if (aggregate == null) {
val wrappedArray = record.asInstanceOf[mutable.WrappedArray[Int]]
return ArrayBuffer(wrappedArray: _*)
}
else if (record == null) {
} else if (record == null) {
aggregate
} else {
dataType match {
// case IntegerType => aggregate.asInstanceOf[Set[Int]] + record.asInstanceOf[Int]
// case LongType => aggregate.asInstanceOf[Set[Long]] + record.asInstanceOf[Long]
// case DoubleType => aggregate.asInstanceOf[Set[Double]] + record.asInstanceOf[Double]
// case FloatType => aggregate.asInstanceOf[Set[Float]] + record.asInstanceOf[Float]
// case StringType=> aggregate.asInstanceOf[Set[String]] + record.asInstanceOf[String]

case ArrayType(IntegerType, false) =>
// val set1 = aggregate.asInstanceOf[Set[Int]]
print("Testing: ")
// val set1 = aggregate.asInstanceOf[mutable.WrappedArray[Int]].toSet
// val set2 = record.asInstanceOf[mutable.WrappedArray[Int]].toSet
// val test1 = aggregate.asInstanceOf[mutable.WrappedArray[Int]].toArray
// val set1 = aggregate.asInstanceOf[mutable.WrappedArray[Int]].toArray.toSet
val set1 = aggregate.asInstanceOf[mutable.ArrayBuffer[Int]].toSet
val set2 = record.asInstanceOf[mutable.WrappedArray[Int]].toArray.toSet

val set3 = set1.union(set2)
// val new_aggregate = mutable.WrappedArray.make(set3.toArray)
// val new_aggregate = set3.toArray
val new_aggregate = ArrayBuffer(set3.toSeq: _*)
return new_aggregate

case ArrayType(LongType, false) =>
val set1 = aggregate.asInstanceOf[mutable.WrappedArray[Long]]toSet
val set2 = record.asInstanceOf[mutable.WrappedArray[Long]]toSet
val set3 = set1.union(set2)
val new_aggregate = mutable.WrappedArray.make(set3)
//val new_aggregate = set3.toArray
val new_aggregate = ArrayBuffer(set3.toSeq: _*)
return new_aggregate

case ArrayType(DoubleType, false) =>
val set1 = aggregate.asInstanceOf[mutable.WrappedArray[Double]].toSet
val set2 = record.asInstanceOf[mutable.WrappedArray[Double]].toSet
val set3 = set1.union(set2)
val new_aggregate = mutable.WrappedArray.make(set3.toArray)
val new_aggregate = ArrayBuffer(set3.toSeq: _*)
return new_aggregate

case ArrayType(FloatType, false) =>
val set1 = aggregate.asInstanceOf[mutable.WrappedArray[Float]].toSet
val set2 = record.asInstanceOf[mutable.WrappedArray[Float]].toSet
val set3 = set1.union(set2)
val new_aggregate = mutable.WrappedArray.make(set3)
val new_aggregate = ArrayBuffer(set3.toSeq: _*)
return new_aggregate

case ArrayType(StringType, false) =>
val set1 = aggregate.asInstanceOf[mutable.ArrayBuffer[String]].toSet
val set2 = record.asInstanceOf[mutable.WrappedArray[String]].toArray.toSet
val set3 = set1.union(set2)
//val new_aggregate = mutable.WrappedArray.make(set3)
val new_aggregate = ArrayBuffer(set3.toSeq: _*)
return new_aggregate

Expand All @@ -143,23 +95,4 @@ class DistinctAggregate(val metricCol: String) extends AggregationSpec {
throw new RuntimeException("Method deagg for DISTINCT aggregate is not implemented because DISTINCT is " +
"not an incremental aggregation.")
}
}



















}
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ object AssertFeatureUtils {
def validateComplexRows(actualRows: Array[Row], expectedRows: Array[GenericRowWithSchema]): Unit = {
assertNotNull(actualRows)
assertEquals(actualRows.length, expectedRows.length)
// val expected_modified = expectedRows.map(row => Row(row.get(0), row.get(1).asInstanceOf[Array[Float]].toArray))
for ((actual, expected) <- actualRows zip expectedRows) {
for( (field, index) <- expected.schema.fields.zipWithIndex) {
val actualValue = actual.get(index)
Expand All @@ -199,7 +198,7 @@ object AssertFeatureUtils {
} else if(field.dataType == ArrayType(StringType, false)) {
assertStringArrayEquals(actualValue.asInstanceOf[Array[String]], expectedValue.asInstanceOf[Array[String]])
} else {
// Unsupported

}
}
}
Expand Down

0 comments on commit e167dbe

Please sign in to comment.