From af402d8f5ea4cb698cd2156fddbe511fb7f1a831 Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Mon, 22 Apr 2024 08:15:51 -0500 Subject: [PATCH] Let big data gen set nullability recursively Signed-off-by: Robert (Bobby) Evans --- .../spark/sql/tests/datagen/bigDataGen.scala | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/datagen/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala b/datagen/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala index da8f9461e2e..91335afe4e6 100644 --- a/datagen/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala +++ b/datagen/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala @@ -609,6 +609,15 @@ abstract class DataGen(var conf: ColumnConf, this } + def setNullProbabilityRecursively(probability: Double): DataGen = { + this.userProvidedNullGen = Some(NullProbabilityGenerationFunction(probability)) + children.foreach { + case (_, dataGen) => + dataGen.setNullProbabilityRecursively(probability) + } + this + } + /** * Set a specific location to seed mapping for the value generation. */ @@ -672,6 +681,7 @@ abstract class DataGen(var conf: ColumnConf, * Get the default value generator for this specific data gen. */ protected def getValGen: GeneratorFunction + def children: Seq[(String, DataGen)] /** * Get the final ready to use GeneratorFunction for the data generator. @@ -823,6 +833,8 @@ class BooleanGen(conf: ColumnConf, override def dataType: DataType = BooleanType override protected def getValGen: GeneratorFunction = BooleanGenFunc() + + override def children: Seq[(String, DataGen)] = Seq.empty } /** @@ -878,6 +890,8 @@ class ByteGen(conf: ColumnConf, extends DataGen(conf, defaultValueRange) { override def getValGen: GeneratorFunction = ByteGenFunc() override def dataType: DataType = ByteType + + override def children: Seq[(String, DataGen)] = Seq.empty } /** @@ -935,6 +949,8 @@ class ShortGen(conf: ColumnConf, override def getValGen: GeneratorFunction = ShortGenFunc() override def dataType: DataType = ShortType + + override def children: Seq[(String, DataGen)] = Seq.empty } /** @@ -991,6 +1007,8 @@ class IntGen(conf: ColumnConf, override def getValGen: GeneratorFunction = IntGenFunc() override def dataType: DataType = IntegerType + + override def children: Seq[(String, DataGen)] = Seq.empty } /** @@ -1045,6 +1063,8 @@ class LongGen(conf: ColumnConf, override def getValGen: GeneratorFunction = LongGenFunc() override def dataType: DataType = LongType + + override def children: Seq[(String, DataGen)] = Seq.empty } case class Decimal32GenFunc( @@ -1284,6 +1304,8 @@ class DecimalGen(dt: DecimalType, val max = DecimalGen.genMaxUnscaled(dt.precision) DecimalGenFunc(dt.precision, dt.scale, -max, max) } + + override def children: Seq[(String, DataGen)] = Seq.empty } /** @@ -1341,6 +1363,8 @@ class TimestampGen(conf: ColumnConf, override protected def getValGen: GeneratorFunction = TimestampGenFunc() override def dataType: DataType = TimestampType + + override def children: Seq[(String, DataGen)] = Seq.empty } object BigDataGenConsts { @@ -1418,6 +1442,8 @@ class DateGen(conf: ColumnConf, override protected def getValGen: GeneratorFunction = DateGenFunc() override def dataType: DataType = DateType + + override def children: Seq[(String, DataGen)] = Seq.empty } /** @@ -1440,6 +1466,8 @@ class DoubleGen(conf: ColumnConf, defaultValueRange: Option[(Any, Any)]) override def dataType: DataType = DoubleType override protected def getValGen: GeneratorFunction = DoubleGenFunc() + + override def children: Seq[(String, DataGen)] = Seq.empty } /** @@ -1462,6 +1490,8 @@ class FloatGen(conf: ColumnConf, defaultValueRange: Option[(Any, Any)]) override def dataType: DataType = FloatType override protected def getValGen: GeneratorFunction = FloatGenFunc() + + override def children: Seq[(String, DataGen)] = Seq.empty } trait JSONType { @@ -1648,6 +1678,8 @@ class StringGen(conf: ColumnConf, defaultValueRange: Option[(Any, Any)]) override def dataType: DataType = StringType override protected def getValGen: GeneratorFunction = ASCIIGenFunc() + + override def children: Seq[(String, DataGen)] = Seq.empty } case class StructGenFunc(childGens: Array[GeneratorFunction]) extends GeneratorFunction { @@ -1752,6 +1784,8 @@ class ArrayGen(child: DataGen, None } } + + override def children: Seq[(String, DataGen)] = Seq(("data", child)) } case class MapGenFunc( @@ -1816,6 +1850,8 @@ class MapGen(key: DataGen, None } } + + override def children: Seq[(String, DataGen)] = Seq(("key", key), ("value", value)) } @@ -1864,6 +1900,11 @@ class ColumnGen(val dataGen: DataGen) { this } + def setNullProbabilityRecursively(probability: Double): ColumnGen = { + dataGen.setNullProbabilityRecursively(probability) + this + } + def setNullGen(f: NullGeneratorFunction): ColumnGen = { dataGen.setNullGen(f) this @@ -1973,6 +2014,14 @@ class TableGen(val columns: Seq[(String, ColumnGen)], numRows: Long) { this } + def setNullProbabilityRecursively(probability: Double): TableGen = { + columns.foreach { + case (_, columnGen) => + columnGen.setNullProbabilityRecursively(probability) + } + this + } + /** * Convert this table into a `DataFrame` that can be * written out or used directly. Writing it out to parquet