Skip to content

Commit

Permalink
Alternative aggregate functions to calculate histogram values. (#475)
Browse files Browse the repository at this point in the history
* Alternative aggregate functions to calculate histogram values.

* Reorder expected json

* Alternative aggregate functions to calculate histogram values.

* Alternative aggregate functions to calculate histogram values

* Alternative aggregate functions to calculate histogram values

---------

Co-authored-by: Aliaksei Kalotkin <[email protected]>
  • Loading branch information
akalotkin and Aliaksei Kalotkin authored Jul 7, 2023
1 parent f53283e commit d3dbe2d
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 17 deletions.
79 changes: 68 additions & 11 deletions src/main/scala/com/amazon/deequ/analyzers/Histogram.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,20 @@

package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Histogram.{AggregateFunction, Count}
import com.amazon.deequ.analyzers.runners.{IllegalAnalyzerParameterException, MetricCalculationException}
import com.amazon.deequ.metrics.{Distribution, DistributionValue, HistogramMetric}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.functions.{col, sum}
import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructType}
import org.apache.spark.sql.{DataFrame, Row}

import scala.util.{Failure, Try}

/**
* Histogram is the summary of values in a column of a DataFrame. Groups the given column's values,
* and calculates, for each distinct value, either the number of rows with that value or the sum
* of the values of another column over those rows, together with that value's share of the total.
*
* @param column Column to do histogram analysis on
* @param binningUdf Optional binning function to run before grouping to re-categorize the
Expand All @@ -37,13 +40,15 @@ import scala.util.{Failure, Try}
* maxBins sets the N.
* This limit does not affect what is being returned as number of bins. It
* always returns the distinct value count.
* @param aggregateFunction function that implements aggregation logic.
*/
case class Histogram(
column: String,
binningUdf: Option[UserDefinedFunction] = None,
maxDetailBins: Integer = Histogram.MaximumAllowedDetailBins,
where: Option[String] = None,
computeFrequenciesAsRatio: Boolean = true)
computeFrequenciesAsRatio: Boolean = true,
aggregateFunction: AggregateFunction = Count)
extends Analyzer[FrequenciesAndNumRows, HistogramMetric]
with FilterableAnalyzer {

Expand All @@ -58,19 +63,15 @@ case class Histogram(

// TODO figure out a way to pass this in if its known before hand
val totalCount = if (computeFrequenciesAsRatio) {
data.count()
aggregateFunction.total(data)
} else {
1
}

val frequencies = data
val df = data
.transform(filterOptional(where))
.transform(binOptional(binningUdf))
.select(col(column).cast(StringType))
.na.fill(Histogram.NullFieldReplacement)
.groupBy(column)
.count()
.withColumnRenamed("count", Analyzers.COUNT_COL)
val frequencies = query(df)

Some(FrequenciesAndNumRows(frequencies, totalCount))
}
Expand Down Expand Up @@ -125,11 +126,67 @@ case class Histogram(
case _ => data
}
}

/** Runs the configured aggregate function's grouping query for this analyzer's column. */
private def query(data: DataFrame): DataFrame =
  aggregateFunction.query(column, data)
}

object Histogram {
  /** Placeholder written into the (string-cast) grouping column wherever its value is null. */
  val NullFieldReplacement = "NullValue"
  /** Default value of the `maxDetailBins` parameter of [[Histogram]]. */
  val MaximumAllowedDetailBins = 1000
  /** Name reported by [[Count]] through `function()`. */
  val count_function = "count"
  /** Name reported by [[Sum]] through `function()`. */
  val sum_function = "sum"

  /**
   * Strategy that decides how the value of each histogram bin is computed.
   */
  sealed trait AggregateFunction {
    /**
     * Groups `data` by `column` and computes one aggregated value per group.
     *
     * @param column name of the column whose values define the bins
     * @param data   input data
     * @return one row per bin: the bin value followed by its aggregate
     */
    def query(column: String, data: DataFrame): DataFrame

    /**
     * Aggregate over the whole of `data`; Histogram uses it as the total count when
     * frequencies are computed as ratios.
     */
    def total(data: DataFrame): Long

    /** Column that is aggregated, if the function needs one besides the grouping column. */
    def aggregateColumn(): Option[String]

    /** Name identifying this aggregate function. */
    def function(): String
  }

  /** Default aggregation: every bin holds the number of rows with that value. */
  case object Count extends AggregateFunction {
    override def query(column: String, data: DataFrame): DataFrame = {
      data
        .select(col(column).cast(StringType))
        .na.fill(Histogram.NullFieldReplacement)
        .groupBy(column)
        .count()
        .withColumnRenamed("count", Analyzers.COUNT_COL)
    }

    override def aggregateColumn(): Option[String] = None

    override def function(): String = count_function

    override def total(data: DataFrame): Long = {
      data.count()
    }
  }

  /**
   * Aggregation in which every bin holds the sum of `aggColumn` over the rows of that bin.
   *
   * @param aggColumn name of the column to sum; its values are cast to Long before summing
   */
  case class Sum(aggColumn: String) extends AggregateFunction {
    override def query(column: String, data: DataFrame): DataFrame = {
      data
        .select(col(column).cast(StringType), col(aggColumn).cast(LongType))
        .na.fill(Histogram.NullFieldReplacement)
        .groupBy(column)
        // Name the aggregate explicitly. groupBy(...).sum(c) yields a column called "sum(c)",
        // so renaming a column called "count" afterwards never matched anything and the
        // result did not carry the Analyzers.COUNT_COL name that Count.query produces.
        .agg(sum(col(aggColumn)).as(Analyzers.COUNT_COL))
    }

    override def total(data: DataFrame): Long = {
      // Cast to Long exactly as query() does, so that the total stays consistent with the
      // per-bin sums and getLong does not fail on a non-integral source column.
      val totalRow = data.select(sum(col(aggColumn).cast(LongType))).first()
      // The sum over zero rows (or over nulls only) is null; report it as 0.
      if (totalRow.isNullAt(0)) 0L else totalRow.getLong(0)
    }

    override def aggregateColumn(): Option[String] = {
      Some(aggColumn)
    }

    override def function(): String = sum_function
  }
}

object OrderByAbsoluteCount extends Ordering[Row] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import scala.collection._
import scala.collection.JavaConverters._
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList, Map => JMap}
import JsonSerializationConstants._
import com.amazon.deequ.analyzers.Histogram.{AggregateFunction, Count => HistogramCount, Sum => HistogramSum}
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.expr

Expand Down Expand Up @@ -302,6 +303,12 @@ private[deequ] object AnalyzerSerializer
result.addProperty(ANALYZER_NAME_FIELD, "Histogram")
result.addProperty(COLUMN_FIELD, histogram.column)
result.addProperty("maxDetailBins", histogram.maxDetailBins)
// Count is the initial and default aggregate function for Histogram.
// For Count the fields below are left out of the JSON to preserve backward compatibility.
if (histogram.aggregateFunction != Histogram.Count) {
result.addProperty("aggregateFunction", histogram.aggregateFunction.function())
result.addProperty("aggregateColumn", histogram.aggregateFunction.aggregateColumn().get)
}

case _ : Histogram =>
throw new IllegalArgumentException("Unable to serialize Histogram with binningUdf!")
Expand Down Expand Up @@ -436,7 +443,10 @@ private[deequ] object AnalyzerDeserializer
Histogram(
json.get(COLUMN_FIELD).getAsString,
None,
json.get("maxDetailBins").getAsInt)
json.get("maxDetailBins").getAsInt,
aggregateFunction = createAggregateFunction(
getOptionalStringParam(json, "aggregateFunction").getOrElse(Histogram.count_function),
getOptionalStringParam(json, "aggregateColumn").getOrElse("")))

case "DataType" =>
DataType(
Expand Down Expand Up @@ -489,12 +499,24 @@ private[deequ] object AnalyzerDeserializer
}

/** Reads the optional `where` filter of a serialized analyzer; None when the field is absent. */
private[this] def getOptionalWhereParam(jsonObject: JsonObject): Option[String] = {
  getOptionalStringParam(jsonObject, WHERE_FIELD)
}

/**
 * Reads an optional string property from a JSON object.
 *
 * @param jsonObject object to read from
 * @param field      name of the property
 * @return the property's string value, or None when the object has no such property
 */
private[this] def getOptionalStringParam(jsonObject: JsonObject, field: String): Option[String] = {
  Some(field)
    .filter(name => jsonObject.has(name))
    .flatMap(name => Option(jsonObject.get(name).getAsString))
}

/**
 * Maps a serialized aggregate function name back to its AggregateFunction.
 *
 * @param function        name of the aggregate function
 * @param aggregateColumn column handed to the sum aggregate; not used for count
 * @throws IllegalArgumentException if the name matches neither count nor sum
 */
private[this] def createAggregateFunction(function: String, aggregateColumn: String): AggregateFunction = {
  function match {
    case Histogram.count_function =>
      HistogramCount
    case Histogram.sum_function =>
      HistogramSum(aggregateColumn)
    case unknown =>
      throw new IllegalArgumentException(s"Wrong aggregate function name: $unknown")
  }
}
}

private[deequ] object MetricSerializer extends JsonSerializer[Metric[_]] {
Expand Down
20 changes: 20 additions & 0 deletions src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,26 @@ class AnalyzerTests extends AnyWordSpec with Matchers with SparkContextSpec with
}
}

"compute correct sum metrics " in withSparkSession { sparkSession =>
  // Bins are the values of "product"; each bin holds the sum of "units" for that product.
  val dfFull = getDateDf(sparkSession)
  val histogram = Histogram("product", aggregateFunction = Histogram.Sum("units")).calculate(dfFull)
  assert(histogram.value.isSuccess)

  val distribution = histogram.value.get
  // Sum of the three expected per-product totals asserted below.
  val totalUnits = 55 + 20 + 60

  assert(distribution.numberOfBins == 3)
  assert(distribution.values.size == 3)
  assert(distribution.values.keys == Set("Furniture", "Cosmetics", "Electronics"))

  assert(distribution("Furniture").absolute == 55)
  assert(distribution("Furniture").ratio == 55.0 / totalUnits)

  assert(distribution("Cosmetics").absolute == 20)
  assert(distribution("Cosmetics").ratio == 20.0 / totalUnits)

  assert(distribution("Electronics").absolute == 60)
  assert(distribution("Electronics").ratio == 60.0 / totalUnits)
}

"compute correct metrics on numeric values" in withSparkSession { sparkSession =>
val dfFull = getDfWithNumericValues(sparkSession)
val histogram = Histogram("att2").calculate(dfFull)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ class AnalyzerContextTest extends AnyWordSpec
|{"entity":"Column","instance":"item","name":"Distinctness","value":1.0},
|{"entity":"Column","instance":"att1","name":"Completeness","value":1.0},
|{"entity":"Multicolumn","instance":"att1,att2","name":"Uniqueness","value":0.25},
|{"entity":"Dataset","instance":"*","name":"Size (where: att2 == 'd')","value":1.0},
|{"entity":"Dataset","instance":"*","name":"Size","value":4.0},
|{"entity":"Column","instance":"att1","name":"Histogram.bins","value":2.0},
|{"entity":"Column","instance":"att1","name":"Histogram.abs.a","value":3.0},
|{"entity":"Column","instance":"att1","name":"Histogram.ratio.a","value":0.75},
|{"entity":"Column","instance":"att1","name":"Histogram.abs.b","value":1.0},
|{"entity":"Column","instance":"att1","name":"Histogram.ratio.b","value":0.25},
|{"entity":"Dataset","instance":"*","name":"Size (where: att2 == 'd')","value":1.0},
|{"entity":"Dataset","instance":"*","name":"Size","value":4.0}
|{"entity":"Column","instance":"att1","name":"Histogram.ratio.b","value":0.25}
|]"""
.stripMargin.replaceAll("\n", "")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,112 @@ class AnalysisResultSerdeTest extends FlatSpec with Matchers {
assertCorrectlyConvertsAnalysisResults(Seq(result))
}

// JSON form of an analysis result holding a Histogram analyzer on "columnA" that uses the
// "sum" aggregate function over "columnB", together with its HistogramMetric.
// Used by the Histogram sum serialization and deserialization tests.
val histogramSumJson =
"""[
| {
| "resultKey": {
| "dataSetDate": 0,
| "tags": {}
| },
| "analyzerContext": {
| "metricMap": [
| {
| "analyzer": {
| "analyzerName": "Histogram",
| "column": "columnA",
| "maxDetailBins": 1000,
| "aggregateFunction": "sum",
| "aggregateColumn": "columnB"
| },
| "metric": {
| "metricName": "HistogramMetric",
| "column": "columnA",
| "numberOfBins": 10,
| "value": {
| "numberOfBins": 10,
| "values": {
| "some": {
| "absolute": 10,
| "ratio": 0.5
| }
| }
| }
| }
| }
| ]
| }
| }
|]""".stripMargin
// JSON form of an analysis result holding a default (count) Histogram analyzer on "columnA".
// It carries no "aggregateFunction" / "aggregateColumn" fields.
// Used by the Histogram count serialization and deserialization tests.
val histogramCountJson =
"""[
| {
| "resultKey": {
| "dataSetDate": 0,
| "tags": {}
| },
| "analyzerContext": {
| "metricMap": [
| {
| "analyzer": {
| "analyzerName": "Histogram",
| "column": "columnA",
| "maxDetailBins": 1000
| },
| "metric": {
| "metricName": "HistogramMetric",
| "column": "columnA",
| "numberOfBins": 10,
| "value": {
| "numberOfBins": 10,
| "values": {
| "some": {
| "absolute": 10,
| "ratio": 0.5
| }
| }
| }
| }
| }
| ]
| }
| }
|]""".stripMargin

// Builds the analysis result shared by the Histogram serde tests below: the given analyzer
// paired with a fixed HistogramMetric for "columnA", stored under ResultKey(0).
private def histogramAnalysisResult(analyzer: Histogram): AnalysisResult = {
  val metric = HistogramMetric("columnA", Success(Distribution(Map("some" -> DistributionValue(10, 0.5)), 10)))
  new AnalysisResult(ResultKey(0), AnalyzerContext(Map(analyzer -> metric)))
}

"Histogram serialization" should "be backward compatible for count" in {
  val result = histogramAnalysisResult(Histogram("columnA"))
  assert(serialize(Seq(result)) == histogramCountJson)
}

"Histogram serialization" should "properly serialize sum" in {
  val result = histogramAnalysisResult(Histogram("columnA", aggregateFunction = Histogram.Sum("columnB")))
  assert(serialize(Seq(result)) == histogramSumJson)
}

"Histogram deserialization" should "be backward compatible for count" in {
  val expected = histogramAnalysisResult(Histogram("columnA"))
  assert(deserialize(histogramCountJson) == List(expected))
}

"Histogram deserialization" should "properly deserialize sum" in {
  val expected = histogramAnalysisResult(Histogram("columnA", aggregateFunction = Histogram.Sum("columnB")))
  assert(deserialize(histogramSumJson) == List(expected))
}


def assertCorrectlyConvertsAnalysisResults(
analysisResults: Seq[AnalysisResult],
shouldFail: Boolean = false)
Expand Down

0 comments on commit d3dbe2d

Please sign in to comment.