From 1e1d81264a1408628224c129809207ad1f17f653 Mon Sep 17 00:00:00 2001 From: rdsharma26 <65777064+rdsharma26@users.noreply.github.com> Date: Wed, 5 Jul 2023 17:08:27 -0400 Subject: [PATCH] Replace Spark SQL isNull check with Spark Scala based DSL (#493) - This is to ensure columns with spaces in their names get their names escaped correctly in the where condition. - Added a test to verify. --- .../com/amazon/deequ/analyzers/Analyzer.scala | 18 +++++++++----- .../amazon/deequ/analyzers/MaxLength.scala | 5 ++-- .../amazon/deequ/analyzers/MinLength.scala | 9 ++++--- .../deequ/profiles/ColumnProfilerTest.scala | 24 +++++++++++++++++-- .../amazon/deequ/utils/FixtureSupport.scala | 13 ++++++++++ 5 files changed, 56 insertions(+), 13 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala index 12c938617..b3e44c4c7 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala @@ -465,20 +465,26 @@ private[deequ] object Analyzers { conditionalSelection(col(selection), where) } - def conditionalSelection(selection: Column, where: Option[String], replaceWith: Double): Column = { - val conditionColumn = where.map(expr) - conditionColumn + def conditionSelectionGivenColumn(selection: Column, where: Option[Column], replaceWith: Double): Column = { + where .map { condition => when(condition, replaceWith).otherwise(selection) } .getOrElse(selection) } - def conditionalSelection(selection: Column, where: Option[String], replaceWith: String): Column = { - val conditionColumn = where.map(expr) - conditionColumn + def conditionSelectionGivenColumn(selection: Column, where: Option[Column], replaceWith: String): Column = { + where .map { condition => when(condition, replaceWith).otherwise(selection) } .getOrElse(selection) } + def conditionalSelection(selection: Column, where: Option[String], replaceWith: Double): Column = { + conditionSelectionGivenColumn(selection, where.map(expr), replaceWith) + } + + def conditionalSelection(selection: Column, where: Option[String], replaceWith: String): Column = { + conditionSelectionGivenColumn(selection, where.map(expr), replaceWith) + } + def conditionalSelection(selection: Column, condition: Option[String]): Column = { val conditionColumn = condition.map { expression => expr(expression) } conditionalSelectionFromColumns(selection, conditionColumn) diff --git a/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala b/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala index 4898243aa..47ed71a69 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala @@ -49,12 +49,13 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptio override def filterCondition: Option[String] = where private def criterion(nullBehavior: NullBehavior): Column = { + val isNullCheck = col(column).isNull nullBehavior match { case NullBehavior.Fail => val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType) - conditionalSelection(colLengths, Option(s"${column} IS NULL"), replaceWith = Double.MaxValue) + conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MaxValue) case NullBehavior.EmptyString => - length(conditionalSelection(col(column), Option(s"${column} IS NULL"), replaceWith = "")).cast(DoubleType) + length(conditionSelectionGivenColumn(col(column), Option(isNullCheck), replaceWith = "")).cast(DoubleType) case _ => length(conditionalSelection(column, where)).cast(DoubleType) } } diff --git a/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala b/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala index 97178149d..b63c4b4be 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala @@ -22,7 +22,9 @@ import com.amazon.deequ.analyzers.Preconditions.hasColumn import com.amazon.deequ.analyzers.Preconditions.isString import org.apache.spark.sql.Column import org.apache.spark.sql.Row -import org.apache.spark.sql.functions.{col, length, min} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.length +import org.apache.spark.sql.functions.min import org.apache.spark.sql.types.DoubleType import org.apache.spark.sql.types.StructType @@ -47,12 +49,13 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptio override def filterCondition: Option[String] = where private[deequ] def criterion(nullBehavior: NullBehavior): Column = { + val isNullCheck = col(column).isNull nullBehavior match { case NullBehavior.Fail => val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType) - conditionalSelection(colLengths, Option(s"${column} IS NULL"), replaceWith = Double.MinValue) + conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MinValue) case NullBehavior.EmptyString => - length(conditionalSelection(col(column), Option(s"${column} IS NULL"), replaceWith = "")).cast(DoubleType) + length(conditionSelectionGivenColumn(col(column), Option(isNullCheck), replaceWith = "")).cast(DoubleType) case _ => length(conditionalSelection(column, where)).cast(DoubleType) } } diff --git a/src/test/scala/com/amazon/deequ/profiles/ColumnProfilerTest.scala b/src/test/scala/com/amazon/deequ/profiles/ColumnProfilerTest.scala index 492be17d1..6eabc8f8a 100644 --- a/src/test/scala/com/amazon/deequ/profiles/ColumnProfilerTest.scala +++ b/src/test/scala/com/amazon/deequ/profiles/ColumnProfilerTest.scala @@ -78,6 +78,27 @@ class ColumnProfilerTest extends WordSpec with Matchers with SparkContextSpec assert(actualColumnProfile == expectedColumnProfile) } + "return correct StringColumnProfile for column names with spaces" in withSparkSession { session => + val data = getDfCompleteAndInCompleteColumnsWithSpacesInNames(session) + val columnNames = data.columns.toSeq + + val lengthMap = Map( + "att 1" -> (1, 3), + "att 2" -> (0, 7) + ) + + lengthMap.foreach { case (columnName, (minLength, maxLength)) => + val actualColumnProfile = ColumnProfiler.profile(data, Option(columnNames), false, 1) + .profiles(columnName) + + assert(actualColumnProfile.isInstanceOf[StringColumnProfile]) + val actualStringColumnProfile = actualColumnProfile.asInstanceOf[StringColumnProfile] + + assert(actualStringColumnProfile.minLength.contains(minLength)) + assert(actualStringColumnProfile.maxLength.contains(maxLength)) + } + } + "return correct columnProfiles with predefined dataType" in withSparkSession { session => val data = getDfCompleteAndInCompleteColumns(session) @@ -131,7 +152,6 @@ class ColumnProfilerTest extends WordSpec with Matchers with SparkContextSpec assert(actualColumnProfile == expectedColumnProfile) } - "return correct NumericColumnProfiles for numeric String DataType columns" in withSparkSession { session => @@ -171,6 +191,7 @@ class ColumnProfilerTest extends WordSpec with Matchers with SparkContextSpec assertProfilesEqual(expectedColumnProfile, actualColumnProfile.asInstanceOf[NumericColumnProfile]) } + "return correct NumericColumnProfiles for numeric String DataType columns when " + "kllProfiling disabled" in withSparkSession { session => @@ -562,7 +583,6 @@ class ColumnProfilerTest extends WordSpec with Matchers with SparkContextSpec ) assertSameColumnProfiles(columnProfiles.profiles, expectedProfiles) - } private[this] def assertSameColumnProfiles( diff --git a/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala b/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala index 9d64d8320..9b6ad9d4e 100644 --- a/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala +++ b/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala @@ -159,6 +159,19 @@ trait FixtureSupport { ).toDF("item", "att1", "att2") } + def getDfCompleteAndInCompleteColumnsWithSpacesInNames(sparkSession: SparkSession): DataFrame = { + import sparkSession.implicits._ + + Seq( + ("1", "ab", "abc1"), + ("2", "bc", null), + ("3", "ab", "def2ghi"), + ("4", "ab", null), + ("5", "bcd", "ab"), + ("6", "a", "pqrs") + ).toDF("some item", "att 1", "att 2") + } + def getDfCompleteAndInCompleteColumnsAndVarLengthStrings(sparkSession: SparkSession): DataFrame = { import sparkSession.implicits._