WIP
ezvz committed Feb 25, 2024
1 parent f6c7b0e commit ec7a744
Showing 4 changed files with 124 additions and 36 deletions.
1 change: 1 addition & 0 deletions api/src/main/scala/ai/chronon/api/Constants.scala
@@ -63,4 +63,5 @@ object Constants {
val LabelViewPropertyFeatureTable: String = "feature_table"
val LabelViewPropertyKeyLabelTable: String = "label_table"
val ChrononRunDs: String = "CHRONON_RUN_DS"
val SmallJoinCutoff: Int = 5000
}
55 changes: 44 additions & 11 deletions spark/src/main/scala/ai/chronon/spark/Join.scala
@@ -16,6 +8 @@

package ai.chronon.spark

import java.util

import org.slf4j.LoggerFactory
import ai.chronon.api
import ai.chronon.api.Extensions._
@@ -37,6 +39,8 @@ import scala.jdk.CollectionConverters.{asJavaIterableConverter, asScalaBufferCon
import scala.util.ScalaJavaConversions.{IterableOps, ListOps, MapOps}
import scala.util.{Failure, Success}

import ai.chronon.api.Constants.SmallJoinCutoff

/*
* hashes: a list containing bootstrap hashes that represent the list of bootstrap parts that a record has matched
* during the bootstrap join
@@ -186,14 +190,24 @@ class Join(joinConf: api.Join,
coveringSetsPerJoinPart
}

def getAllLeftSideKeyNames(): Seq[String] = {
joinConf.getJoinParts.asScala.flatMap { joinPart =>
if (joinPart.keyMapping != null) {
joinPart.keyMapping.asScala.keys.toSeq
} else {
joinPart.groupBy.getKeyColumns.asScala
}
}
}

def injectKeyFilter(leftDf: DataFrame, joinPart: api.JoinPart): Unit = {
// Modifies the joinPart in place: injects a key filter into the WHERE clauses of its GroupBy sources.

val groupByKeyNames = joinPart.groupBy.getKeyColumns.asScala

// In case the joinPart uses a keymapping
val leftSideKeyNames: Map[String, String] = if (joinPart.keyMapping != null) {
joinPart.keyMapping.asScala.toMap
joinPart.rightToLeft
} else {
groupByKeyNames.map { k =>
(k, k)
@@ -210,9 +224,12 @@
val joinSelects: Map[String, String] = Option(joinConf.left.rootQuery.getQuerySelects).getOrElse(Map.empty[String, String])

groupByKeyExpressions.map{ case (keyName, groupByKeyExpression) =>
println("---------------------------------------")
println(s"Left side keynames ${leftSideKeyNames.mkString(",")}")
println(s"keyName: $keyName, expression: $groupByKeyExpressions")
println("---------------------------------------")
val leftSideKeyName = leftSideKeyNames.get(keyName).get
val leftSelectExpression = joinSelects.getOrElse(leftSideKeyName, keyName)
val values = leftDf.select(leftSelectExpression).collect().map(row => row(0))
val values = leftDf.select(leftSideKeyName).collect().map(row => row(0))

// Check for null keys, warn if found, err if all null
val (notNullValues, nullValues) = values.partition(_ != null)
@@ -230,11 +247,15 @@

// Form the final WHERE clause for injection
s"$groupByKeyExpression in (${valueSet.mkString(sep = ",")})"
}.foreach(source.rootQuery.getWheres.add(_))
}.foreach { whereClause =>
val currentWheres = Option(source.rootQuery.getWheres).getOrElse(new util.ArrayList[String]())
currentWheres.add(whereClause)
source.rootQuery.setWheres(currentWheres)
}
}
}

override def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo): DataFrame = {
override def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo, runSmallMode: Boolean = false): DataFrame = {
val leftTaggedDf = if (leftDf.schema.names.contains(Constants.TimeColumn)) {
leftDf.withTimeBasedColumn(Constants.TimePartitionColumn)
} else {
@@ -251,7 +272,7 @@
val bootstrapCoveringSets = findBootstrapSetCoverings(bootstrapDf, bootstrapInfo, leftRange)

// compute a single bloomfilter at join level if there is no bootstrap operation
val joinLevelBloomMapOpt = if (bootstrapDf.columns.contains(Constants.MatchedHashes)) {
lazy val joinLevelBloomMapOpt = if (bootstrapDf.columns.contains(Constants.MatchedHashes)) {
// do not compute if any bootstrap is involved
None
} else {
Expand All @@ -266,8 +287,16 @@ class Join(joinConf: api.Join,
}
}

val parallelism = if (runSmallMode) {
// Max out parallelism
joinConf.getJoinParts.asScala.length
} else {
tableUtils.joinPartParallelism
}

implicit val executionContext: ExecutionContextExecutorService =
ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(tableUtils.joinPartParallelism))
ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(parallelism))


val joinedDfTry = tableUtils
.wrapWithCache("Computing left parts for bootstrap table", bootstrapDf) {
@@ -301,10 +330,14 @@
s"Macro ${Constants.ChrononRunDs} is only supported for single day join, current range is ${leftRange}")
}

// If left DF is small, hardcode the key filter into the joinPart's GroupBy's where clause.
if (unfilledLeftDf.isDefined && unfilledLeftDf.get.df.)
val df =
computeRightTable(unfilledLeftDf, joinPart, leftRange, joinLevelBloomMapOpt).map(df => joinPart -> df)
val (bloomFilterOpt, skipFilter) = if (runSmallMode) {
// If left DF is small, hardcode the key filter into the joinPart's GroupBy's where clause.
injectKeyFilter(leftDf, joinPart)
(None, true)
} else {
(joinLevelBloomMapOpt, false)
}
val df = computeRightTable(unfilledLeftDf, joinPart, leftRange, bloomFilterOpt, skipFilter).map(df => joinPart -> df)
Thread.currentThread().setName(s"done-$threadName")
df
}
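
Note: the injectKeyFilter change above collects the small left side's key values and appends an IN predicate to each GroupBy source's WHERE list. Below is a minimal stand-alone sketch of that predicate construction (plain Scala, no Spark; the quote helper and the sample values are illustrative assumptions, not part of this commit).

object KeyFilterSketch {
  // Quote string values so they are valid SQL literals; pass other values through as-is.
  // Illustrative only -- the committed code relies on the collected values' toString.
  private def quote(v: Any): String = v match {
    case s: String => s"'${s.replace("'", "''")}'"
    case other     => other.toString
  }

  // Build a "<expr> in (...)" predicate from the distinct, non-null key values found
  // on the left side, mirroring what injectKeyFilter appends to the GroupBy source's WHERE clauses.
  def keyFilterClause(groupByKeyExpression: String, leftKeyValues: Seq[Any]): String = {
    val distinctValues = leftKeyValues.filter(_ != null).distinct
    require(distinctValues.nonEmpty, s"No non-null values found for $groupByKeyExpression")
    s"$groupByKeyExpression in (${distinctValues.map(quote).mkString(", ")})"
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical key values collected from the (small) left DataFrame.
    val leftUserIds: Seq[Any] = Seq(101L, 102L, 103L, null, 101L)
    println(keyFilterClause("user_id", leftUserIds))
    // prints: user_id in (101, 102, 103)
  }
}
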
75 changes: 53 additions & 22 deletions spark/src/main/scala/ai/chronon/spark/JoinBase.scala
@@ -28,11 +28,13 @@ import com.google.gson.Gson
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.util.sketch.BloomFilter

import java.time.Instant

import scala.collection.JavaConverters._
import scala.collection.Seq

import ai.chronon.api.Constants.SmallJoinCutoff

abstract class JoinBase(joinConf: api.Join,
endPartition: String,
tableUtils: TableUtils,
@@ -117,13 +119,14 @@ abstract class JoinBase(joinConf: api.Join,
def computeRightTable(leftDf: Option[DfWithStats],
joinPart: JoinPart,
leftRange: PartitionRange,
joinLevelBloomMapOpt: Option[Map[String, BloomFilter]]): Option[DataFrame] = {
joinLevelBloomMapOpt: Option[Map[String, BloomFilter]],
skipBloom: Boolean = false): Option[DataFrame] = {

val partTable = joinConf.partOutputTable(joinPart)
val partMetrics = Metrics.Context(metrics, joinPart)
if (joinPart.groupBy.aggregations == null) {
// for non-aggregation cases, we directly read from the source table and there is no intermediate join part table
computeJoinPart(leftDf, joinPart, joinLevelBloomMapOpt)
computeJoinPart(leftDf, joinPart, joinLevelBloomMapOpt, skipBloom)
} else {
// in Events <> batch GB case, the partition dates are offset by 1
val shiftDays =
@@ -145,15 +148,19 @@
skipFirstHole = false
)
.getOrElse(Seq())
val partitionCount = unfilledRanges.map(_.partitions.length).sum

// todo: undo this, just for debugging
val unfilledRangeCombined = Seq(PartitionRange(unfilledRanges.minBy(_.start).start, unfilledRanges.maxBy(_.end).end)(tableUtils))

val partitionCount = unfilledRangeCombined.map(_.partitions.length).sum
if (partitionCount > 0) {
val start = System.currentTimeMillis()
unfilledRanges
unfilledRangeCombined
.foreach(unfilledRange => {
val leftUnfilledRange = unfilledRange.shift(-shiftDays)
val prunedLeft = leftDf.flatMap(_.prunePartitions(leftUnfilledRange))
val filledDf =
computeJoinPart(prunedLeft, joinPart, joinLevelBloomMapOpt)
computeJoinPart(prunedLeft, joinPart, joinLevelBloomMapOpt, skipBloom)
// Cache join part data into intermediate table
if (filledDf.isDefined) {
logger.info(s"Writing to join part table: $partTable for partition range $unfilledRange")
@@ -182,7 +189,8 @@

def computeJoinPart(leftDfWithStats: Option[DfWithStats],
joinPart: JoinPart,
joinLevelBloomMapOpt: Option[Map[String, BloomFilter]]): Option[DataFrame] = {
joinLevelBloomMapOpt: Option[Map[String, BloomFilter]],
skipBloom: Boolean = false): Option[DataFrame] = {

if (leftDfWithStats.isEmpty) {
// happens when all rows are already filled by bootstrap tables
@@ -196,14 +204,17 @@

logger.info(
s"\nBackfill is required for ${joinPart.groupBy.metaData.name} for $rowCount rows on range $unfilledRange")
val rightBloomMap =
val rightBloomMap = if (skipBloom) {
None
} else {
JoinUtils.genBloomFilterIfNeeded(leftDf,
joinPart,
joinConf,
rowCount,
unfilledRange,
tableUtils,
joinLevelBloomMapOpt)
joinPart,
joinConf,
rowCount,
unfilledRange,
tableUtils,
joinLevelBloomMapOpt)
}
val rightSkewFilter = joinConf.partSkewFilter(joinPart)
def genGroupBy(partitionRange: PartitionRange) =
GroupBy.from(joinPart.groupBy,
@@ -286,7 +297,7 @@
Some(rightDfWithDerivations)
}

def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo): DataFrame
def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo, runSmallMode: Boolean = false): DataFrame

def computeJoin(stepDays: Option[Int] = None, overrideStartPartition: Option[String] = None): DataFrame = {

@@ -302,8 +313,8 @@
val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
val analyzer = new Analyzer(tableUtils, joinConf, today, today, silenceMode = true)
try {
analyzer.analyzeJoin(joinConf, validationAssert = true)
metrics.gauge(Metrics.Name.validationSuccess, 1)
//analyzer.analyzeJoin(joinConf, validationAssert = true)
//metrics.gauge(Metrics.Name.validationSuccess, 1)
logger.info("Join conf validation succeeded. No error found.")
} catch {
case ex: AssertionError =>
@@ -323,7 +334,7 @@
// detect holes and chunks to fill
// OverrideStartPartition is used to replace the start partition of the join config. This is useful when
// 1 - User would like to test run with different start partition
// 2 - User has entity table which is accumulative and only want to run backfill for the latest partition
// 2 - User has entity table which is cumulative and only want to run backfill for the latest partition
val rangeToFill = JoinUtils.getRangesToFill(joinConf.left,
tableUtils,
endPartition,
@@ -349,16 +360,36 @@
// build bootstrap info once for the entire job
val bootstrapInfo = BootstrapInfo.from(joinConf, rangeToFill, tableUtils, leftSchema, mutationScan = mutationScan)

logger.info(s"Join ranges to compute: ${stepRanges.map { _.toString }.pretty}")
stepRanges.zipWithIndex.foreach {
val wholeRange = PartitionRange(unfilledRanges.minBy(_.start).start, unfilledRanges.maxBy(_.end).end)(tableUtils)

val runSmallMode = {
val thresholdCount = leftDf(joinConf, wholeRange, tableUtils, limit = Some(SmallJoinCutoff + 1)).get.count()
val result = thresholdCount <= SmallJoinCutoff
if (result) {
logger.info(s"Counted $thresholdCount rows, running join in small mode.")
tableUtils.shouldRepartition = false
} else {
logger.info(s"Counted greater than $SmallJoinCutoff rows, proceeding with normal computation.")
}
result
}

val effectiveRanges = if (runSmallMode) {
Seq(wholeRange)
} else {
stepRanges
}

logger.info(s"Join ranges to compute: ${effectiveRanges.map { _.toString }.pretty}")
effectiveRanges.zipWithIndex.foreach {
case (range, index) =>
val startMillis = System.currentTimeMillis()
val progress = s"| [${index + 1}/${stepRanges.size}]"
val progress = s"| [${index + 1}/${effectiveRanges.size}]"
logger.info(s"Computing join for range: ${range.toString} $progress")
leftDf(joinConf, range, tableUtils).map { leftDfInRange =>
if (showDf) leftDfInRange.prettyPrint()
// set autoExpand = true to ensure backward compatibility due to column ordering changes
computeRange(leftDfInRange, range, bootstrapInfo).save(outputTable, tableProps, autoExpand = true)
computeRange(leftDfInRange, range, bootstrapInfo, runSmallMode).save(outputTable, tableProps, autoExpand = true)
val elapsedMins = (System.currentTimeMillis() - startMillis) / (60 * 1000)
metrics.gauge(Metrics.Name.LatencyMinutes, elapsedMins)
metrics.gauge(Metrics.Name.PartitionCount, range.partitions.length)
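
Note: the small-mode check introduced above reads at most SmallJoinCutoff + 1 left rows and switches behavior on the result. Below is a hedged sketch of that decision in isolation; the local SparkSession, synthetic left table, and column names are assumptions for illustration, while the real code limits the left-source scan via leftDf(joinConf, ..., limit = Some(SmallJoinCutoff + 1)).

import org.apache.spark.sql.SparkSession

object SmallModeSketch {
  val SmallJoinCutoff: Int = 5000 // mirrors the new constant in Constants.scala

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("small-mode-sketch").getOrCreate()
    import spark.implicits._

    // Hypothetical left side of the join; in Chronon this comes from joinConf.left.
    val leftDf = (1 to 12000).map(i => (i.toLong, s"2024-02-2${i % 5}")).toDF("user_id", "ds")

    // Read at most cutoff + 1 rows: enough to know whether we crossed the threshold
    // without counting the full table.
    val thresholdCount = leftDf.limit(SmallJoinCutoff + 1).count()
    val runSmallMode = thresholdCount <= SmallJoinCutoff

    if (runSmallMode)
      println(s"Counted $thresholdCount rows, running join in small mode.")
    else
      println(s"Counted more than $SmallJoinCutoff rows, proceeding with normal computation.")

    spark.stop()
  }
}
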
29 changes: 26 additions & 3 deletions spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -16,6 +8 @@

package ai.chronon.spark

import java.io.{PrintWriter, StringWriter}

import org.slf4j.LoggerFactory
import ai.chronon.aggregator.windowing.TsUtils
import ai.chronon.api.{Constants, PartitionSpec}
@@ -29,10 +31,10 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import org.apache.spark.storage.StorageLevel

import java.time.format.DateTimeFormatter
import java.time.{Instant, ZoneId}
import java.util.concurrent.{ExecutorService, Executors}

import scala.collection.{Seq, mutable}
import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}
import scala.util.{Failure, Success, Try}
@@ -46,6 +48 @@ case class TableUtils(sparkSession: SparkSession) {
.withZone(ZoneId.systemDefault())
val partitionColumn: String =
sparkSession.conf.get("spark.chronon.partition.column", "ds")
var shouldRepartition: Boolean = false
//sparkSession.conf.get("spark.chronon.repartition", "true").toBoolean
private val partitionFormat: String =
sparkSession.conf.get("spark.chronon.partition.format", "yyyy-MM-dd")
val partitionSpec: PartitionSpec = PartitionSpec(partitionFormat, WindowUtils.Day.millis)
@@ -76,6 +80,9 @@
sparkSession.sparkContext.setLogLevel("ERROR")
// converts String-s like "a=b/c=d" to Map("a" -> "b", "c" -> "d")

def setRepartition(setTo: Boolean): Unit = {
this.shouldRepartition = setTo
}
def preAggRepartition(df: DataFrame): DataFrame =
if (df.rdd.getNumPartitions < aggregationParallelism) {
df.repartition(aggregationParallelism)
@@ -322,8 +329,15 @@

def sql(query: String): DataFrame = {
val partitionCount = sparkSession.sparkContext.getConf.getInt("spark.default.parallelism", 1000)
val sw = new StringWriter()
val pw = new PrintWriter(sw)
new Throwable().printStackTrace(pw)
val stackTraceString = sw.toString
val stackTraceStringPretty = stackTraceString.split("\n").filter(_.contains("chronon")).map(_.replace("at ai.chronon.spark.", "")).mkString("\n")

logger.info(
s"\n----[Running query coalesced into at most $partitionCount partitions]----\n$query\n----[End of Query]----\n")
s"\n----[Running query coalesced into at most $partitionCount partitions]----\n$query\n----[End of Query]----\n\n Query call path (not an error stack trace): \n$stackTraceStringPretty \n\n --------")

val df = sparkSession.sql(query).coalesce(partitionCount)
df
}
@@ -383,7 +397,15 @@
saveMode: SaveMode,
stats: Option[DfStats]): Unit = {
wrapWithCache(s"repartition & write to $tableName", df) {
repartitionAndWriteInternal(df, tableName, saveMode, stats)
if (shouldRepartition) {
logger.info(s"Repartitioning before writing...")
repartitionAndWriteInternal(df, tableName, saveMode, stats)
} else {
logger.info(s"Skipping repartition...")
df.write.mode(saveMode).insertInto(tableName)
logger.info(s"Finished writing to $tableName")
}

}.get
}

@@ -392,6 +414,7 @@
saveMode: SaveMode,
stats: Option[DfStats]): Unit = {
// get row count and table partition count statistics

val (rowCount: Long, tablePartitionCount: Int) =
if (df.schema.fieldNames.contains(partitionColumn)) {
if (stats.isDefined && stats.get.partitionRange.wellDefined) {
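
Note: the extra logging added to TableUtils.sql prints a call path by filtering the current stack trace down to Chronon frames. Below is a small self-contained sketch of that filtering; the object and method names are hypothetical, and only the filter and replace strings mirror the diff.

import java.io.{PrintWriter, StringWriter}

object CallPathSketch {
  // Capture the current stack trace as a string and keep only frames that mention
  // "chronon", trimming the common package prefix for readability.
  def chrononCallPath(): String = {
    val sw = new StringWriter()
    val pw = new PrintWriter(sw)
    new Throwable().printStackTrace(pw)
    pw.flush()
    sw.toString
      .split("\n")
      .filter(_.contains("chronon"))
      .map(_.replace("at ai.chronon.spark.", ""))
      .mkString("\n")
  }

  def main(args: Array[String]): Unit = {
    // Outside of a Chronon job this prints an empty path; inside one it shows
    // which Chronon methods led to the SQL call.
    println(s"Query call path (not an error stack trace):\n${chrononCallPath()}")
  }
}
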
