unit test
Donghan Zhang committed Feb 24, 2024
1 parent 766e04d commit 90243c5
Showing 15 changed files with 157 additions and 52 deletions.
4 changes: 2 additions & 2 deletions spark/src/main/scala/ai/chronon/spark/Driver.scala
@@ -247,10 +247,10 @@ object Driver {
!args.runFirstHole(),
selectedJoinParts = args.selectedJoinParts.toOption
)
val df = join.computeJoin(args.stepDays.toOption, args.startPartitionOverride.toOption)
val df = join.computeJoin(args.stepDays.toOption, args.startPartitionOverride.toOption).get

if (args.selectedJoinParts.isDefined) {
logger.info("Selected join parts are populated successfully. No final join is required. Exiting.")
logger.info("Selected join parts are populated successfully. Exiting.")
return
}
if (args.shouldExport()) {
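
Note on the call-site change above: `computeJoin` now returns `Option[DataFrame]`, so the unconditional `.get` assumes a populated result. A minimal sketch of a caller that branches on the `Option` instead; `exportOutput` is a hypothetical helper, not part of this diff, and the sketch assumes the same `args`, `join`, and `logger` values as in Driver.scala above:

// Sketch only: handle both branches of the Option explicitly instead of calling .get.
join.computeJoin(args.stepDays.toOption, args.startPartitionOverride.toOption) match {
  case None =>
    // selectedJoinParts mode: join part tables were written; there is no final join output.
    logger.info("Selected join parts are populated successfully. Exiting.")
  case Some(df) =>
    if (args.shouldExport()) {
      exportOutput(df) // hypothetical export step standing in for the elided code
    }
}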
7 changes: 3 additions & 4 deletions spark/src/main/scala/ai/chronon/spark/GroupBy.scala
@@ -16,22 +16,21 @@

package ai.chronon.spark

import org.slf4j.LoggerFactory
import ai.chronon.aggregator.base.TimeTuple
import ai.chronon.aggregator.row.RowAggregator
import ai.chronon.aggregator.windowing._
import ai.chronon.api
import ai.chronon.api.{Accuracy, Constants, DataModel, ParametricMacro}
import ai.chronon.api.DataModel.{Entities, Events}
import ai.chronon.api.Extensions._
import ai.chronon.api.{Accuracy, Constants, DataModel, ParametricMacro}
import ai.chronon.online.{RowWrapper, SparkConversions}
import ai.chronon.spark.Extensions._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions._
import org.apache.spark.sql
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.util.sketch.BloomFilter
import org.slf4j.LoggerFactory

import java.util
import scala.collection.{Seq, mutable}
@@ -420,7 +419,7 @@ object GroupBy {

val join = new Join(joinConf, endDate, tableUtils, mutationScan = false, showDf = showDf)
if (computeDependency) {
val df = join.computeJoin()
val df = join.computeJoin().get
if (showDf) {
logger.info(
s"printing output data from groupby::join_source: ${groupByConf.metaData.name}::${joinConf.metaData.name}")
17 changes: 7 additions & 10 deletions spark/src/main/scala/ai/chronon/spark/JoinBase.scala
@@ -16,7 +16,6 @@

package ai.chronon.spark

import org.slf4j.LoggerFactory
import ai.chronon.api
import ai.chronon.api.DataModel.{Entities, Events}
import ai.chronon.api.Extensions._
@@ -28,6 +27,7 @@ import com.google.gson.Gson
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.util.sketch.BloomFilter
import org.slf4j.LoggerFactory

import java.time.Instant
import scala.collection.JavaConverters._
@@ -289,7 +289,7 @@ abstract class JoinBase(joinConf: api.Join,

def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo): Option[DataFrame]

def computeJoin(stepDays: Option[Int] = None, overrideStartPartition: Option[String] = None): DataFrame = {
def computeJoin(stepDays: Option[Int] = None, overrideStartPartition: Option[String] = None): Option[DataFrame] = {

assert(Option(joinConf.metaData.team).nonEmpty,
s"join.metaData.team needs to be set for join ${joinConf.metaData.name}")
@@ -338,7 +338,7 @@ abstract class JoinBase(joinConf: api.Join,
def finalResult: DataFrame = tableUtils.sql(rangeToFill.genScanQuery(null, outputTable))
if (unfilledRanges.isEmpty) {
logger.info(s"\nThere is no data to compute based on end partition of ${rangeToFill.end}.\n\n Exiting..")
return finalResult
return Some(finalResult)
}

stepDays.foreach(metrics.gauge("step_days", _))
@@ -361,8 +361,9 @@ abstract class JoinBase(joinConf: api.Join,
// set autoExpand = true to ensure backward compatibility due to column ordering changes
val finalDf = computeRange(leftDfInRange, range, bootstrapInfo)
if (selectedJoinParts.isDefined) {
assert(finalDf.isEmpty, "finalDf should be empty")
assert(finalDf.isEmpty, "The arg `selectedJoinParts` is defined, so no final join is required. `finalDf` should be empty")
logger.info(s"Skipping writing to the output table for range: ${range.toString} $progress")
return None
} else {
finalDf.get.save(outputTable, tableProps, autoExpand = true)
val elapsedMins = (System.currentTimeMillis() - startMillis) / (60 * 1000)
@@ -373,11 +374,7 @@ abstract class JoinBase(joinConf: api.Join,
}
}
}
if (selectedJoinParts.isDefined) {
logger.info(s"Completed join parts: ${selectedJoinParts.get.mkString(", ")}")
} else {
logger.info(s"Wrote to table $outputTable, into partitions: $unfilledRanges")
}
finalResult
logger.info(s"Wrote to table $outputTable, into partitions: $unfilledRanges")
Some(finalResult)
}
}
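
The core contract change is in JoinBase.scala: `computeJoin` now returns `Option[DataFrame]`, yielding `None` when `selectedJoinParts` is set (only the selected join part tables are backfilled and the final join is skipped) and `Some(finalResult)` otherwise. A minimal, self-contained sketch of that shape, using hypothetical stand-ins (`JoinLike`, `backfillJoinParts`, `runFinalJoin`) rather than the real Chronon classes:

// Minimal sketch of the new computeJoin contract; the names below are illustrative.
import org.apache.spark.sql.DataFrame

trait JoinLike {
  def selectedJoinParts: Option[Seq[String]]      // subset of join parts to backfill, if any
  def backfillJoinParts(parts: Seq[String]): Unit // hypothetical: writes only the join part tables
  def runFinalJoin(): DataFrame                   // hypothetical: computes the full joined output

  def computeJoin(): Option[DataFrame] =
    selectedJoinParts match {
      case Some(parts) =>
        backfillJoinParts(parts)
        None                  // no final join output in this mode
      case None =>
        Some(runFinalJoin())  // callers in this commit unwrap with .get
    }
}

Callers that never run in selected-join-parts mode (ConsistencyJob and the tests below) unwrap with `.get`; a caller that might receive `None` would need to handle that case explicitly.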
@@ -79,7 +79,7 @@ class ConsistencyJob(session: SparkSession, joinConf: Join, endDate: String) ext
if (unfilledRanges.isEmpty) return
val join = new chronon.spark.Join(buildComparisonJoin(), unfilledRanges.last.end, TableUtils(session))
logger.info("Starting compute Join for comparison table")
val compareDf = join.computeJoin(Some(30))
val compareDf = join.computeJoin(Some(30)).get
logger.info("======= side-by-side comparison schema =======")
logger.info(compareDf.schema.pretty)
}
@@ -68,7 +68,7 @@ class AnalyzerTest {
val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, enableHitter = true)
val analyzerSchema = analyzer.analyzeJoin(joinConf)._1.map { case (k, v) => s"${k} => ${v}" }.toList.sorted
val join = new Join(joinConf = joinConf, endPartition = oneMonthAgo, tableUtils)
val computed = join.computeJoin()
val computed = join.computeJoin().get
val expectedSchema = computed.schema.fields.map(field => s"${field.name} => ${field.dataType}").sorted
logger.info("=== expected schema =====")
logger.info(expectedSchema.mkString("\n"))
2 changes: 1 addition & 1 deletion spark/src/test/scala/ai/chronon/spark/test/AvroTest.scala
@@ -99,7 +99,7 @@ class AvroTest {
metaData = Builders.MetaData(name = "unit_test.test_decimal", namespace = namespace, team = "chronon")
)
val runner = new Join(joinConf, tableUtils.partitionSpec.minus(today, new Window(40, TimeUnit.DAYS)), tableUtils)
val df = runner.computeJoin()
val df = runner.computeJoin().get
df.printSchema()
}

@@ -215,7 +215,7 @@ class ChainingFetcherTest extends TestCase {
val inMemoryKvStore = kvStoreFunc()
val mockApi = new MockApi(kvStoreFunc, namespace)

val joinedDf = new ai.chronon.spark.Join(joinConf, endDs, tableUtils).computeJoin()
val joinedDf = new ai.chronon.spark.Join(joinConf, endDs, tableUtils).computeJoin().get
val joinTable = s"$namespace.join_test_expected_${joinConf.metaData.cleanName}"
joinedDf.save(joinTable)
logger.info("=== Expected join table computed: === " + joinTable)
@@ -122,7 +122,7 @@ class FetchStatsTest extends TestCase {

// Compute daily join.
val joinJob = new Join(joinConf, today, tableUtils)
joinJob.computeJoin()
joinJob.computeJoin().get
// Load some data.
implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetchStatsTest")
@@ -529,7 +529,7 @@ class FetcherTest extends TestCase {
val inMemoryKvStore = kvStoreFunc()
val mockApi = new MockApi(kvStoreFunc, namespace)

val joinedDf = new ai.chronon.spark.Join(joinConf, endDs, tableUtils).computeJoin()
val joinedDf = new ai.chronon.spark.Join(joinConf, endDs, tableUtils).computeJoin().get
val joinTable = s"$namespace.join_test_expected_${joinConf.metaData.cleanName}"
joinedDf.save(joinTable)
val endDsExpected = tableUtils.sql(s"SELECT * FROM $joinTable WHERE ds='$endDs'")
