Don't request null keys from KVStore #774

Open · wants to merge 2 commits into base: main
Changes from all commits
159 changes: 89 additions & 70 deletions online/src/main/scala/ai/chronon/online/FetcherBase.scala
@@ -202,57 +202,63 @@ class FetcherBase(kvStore: KVStore,
   // 4. Finally converted to outputSchema
   def fetchGroupBys(requests: scala.collection.Seq[Request]): Future[scala.collection.Seq[Response]] = {
     // split a groupBy level request into its kvStore level requests
-    val groupByRequestToKvRequest: Seq[(Request, Try[GroupByRequestMeta])] = requests.iterator.map { request =>
-      val groupByRequestMetaTry: Try[GroupByRequestMeta] = getGroupByServingInfo(request.name)
-        .map { groupByServingInfo =>
-          val context =
-            request.context.getOrElse(Metrics.Context(Metrics.Environment.GroupByFetching, groupByServingInfo.groupBy))
-          context.increment("group_by_request.count")
-          var batchKeyBytes: Array[Byte] = null
-          var streamingKeyBytes: Array[Byte] = null
-          try {
-            // The formats of key bytes for batch requests and key bytes for streaming requests may differ based
-            // on the KVStore implementation, so we encode each distinctly.
-            batchKeyBytes =
-              kvStore.createKeyBytes(request.keys, groupByServingInfo, groupByServingInfo.groupByOps.batchDataset)
-            streamingKeyBytes =
-              kvStore.createKeyBytes(request.keys, groupByServingInfo, groupByServingInfo.groupByOps.streamingDataset)
-          } catch {
-            // TODO: only gets hit in cli path - make this code path just use avro schema to decode keys directly in cli
-            // TODO: Remove this code block
-            case ex: Exception =>
-              val castedKeys = groupByServingInfo.keyChrononSchema.fields.map {
-                case StructField(name, typ) => name -> ColumnAggregator.castTo(request.keys.getOrElse(name, null), typ)
-              }.toMap
-              try {
-                batchKeyBytes =
-                  kvStore.createKeyBytes(castedKeys, groupByServingInfo, groupByServingInfo.groupByOps.batchDataset)
-                streamingKeyBytes =
-                  kvStore.createKeyBytes(castedKeys, groupByServingInfo, groupByServingInfo.groupByOps.streamingDataset)
-              } catch {
-                case exInner: Exception =>
-                  exInner.addSuppressed(ex)
-                  throw new RuntimeException("Couldn't encode request keys or casted keys", exInner)
-              }
-          }
-          val batchRequest = GetRequest(batchKeyBytes, groupByServingInfo.groupByOps.batchDataset)
-          val streamingRequestOpt = groupByServingInfo.groupByOps.inferredAccuracy match {
-            // fetch batch(ir) and streaming(input) and aggregate
-            case Accuracy.TEMPORAL =>
-              Some(
-                GetRequest(streamingKeyBytes,
-                           groupByServingInfo.groupByOps.streamingDataset,
-                           Some(groupByServingInfo.batchEndTsMillis)))
-            // no further aggregation is required - the value in KvStore is good as is
-            case Accuracy.SNAPSHOT => None
-          }
-          GroupByRequestMeta(groupByServingInfo, batchRequest, streamingRequestOpt, request.atMillis, context)
-        }
-      if (groupByRequestMetaTry.isFailure) {
-        request.context.foreach(_.increment("group_by_serving_info_failure.count"))
-      }
-      request -> groupByRequestMetaTry
-    }.toSeq
+    val groupByRequestToKvRequest: Seq[(Request, Try[GroupByRequestMeta])] = requests.iterator
+      .filter(r => r.keys == null || r.keys.values == null || r.keys.values.exists(_ != null))
+      .map { request =>
+        val groupByRequestMetaTry: Try[GroupByRequestMeta] = getGroupByServingInfo(request.name)
+          .map { groupByServingInfo =>
+            val context =
+              request.context.getOrElse(
+                Metrics.Context(Metrics.Environment.GroupByFetching, groupByServingInfo.groupBy))
+            context.increment("group_by_request.count")
+            var batchKeyBytes: Array[Byte] = null
+            var streamingKeyBytes: Array[Byte] = null
+            try {
+              // The formats of key bytes for batch requests and key bytes for streaming requests may differ based
+              // on the KVStore implementation, so we encode each distinctly.
+              batchKeyBytes =
+                kvStore.createKeyBytes(request.keys, groupByServingInfo, groupByServingInfo.groupByOps.batchDataset)
+              streamingKeyBytes =
+                kvStore.createKeyBytes(request.keys, groupByServingInfo, groupByServingInfo.groupByOps.streamingDataset)
+            } catch {
+              // TODO: only gets hit in cli path - make this code path just use avro schema to decode keys directly in cli
+              // TODO: Remove this code block
+              case ex: Exception =>
+                val castedKeys = groupByServingInfo.keyChrononSchema.fields.map {
+                  case StructField(name, typ) =>
+                    name -> ColumnAggregator.castTo(request.keys.getOrElse(name, null), typ)
+                }.toMap
+                try {
+                  batchKeyBytes =
+                    kvStore.createKeyBytes(castedKeys, groupByServingInfo, groupByServingInfo.groupByOps.batchDataset)
+                  streamingKeyBytes = kvStore.createKeyBytes(castedKeys,
+                                                             groupByServingInfo,
+                                                             groupByServingInfo.groupByOps.streamingDataset)
+                } catch {
+                  case exInner: Exception =>
+                    exInner.addSuppressed(ex)
+                    throw new RuntimeException("Couldn't encode request keys or casted keys", exInner)
+                }
+            }
+            val batchRequest = GetRequest(batchKeyBytes, groupByServingInfo.groupByOps.batchDataset)
+            val streamingRequestOpt = groupByServingInfo.groupByOps.inferredAccuracy match {
+              // fetch batch(ir) and streaming(input) and aggregate
+              case Accuracy.TEMPORAL =>
+                Some(
+                  GetRequest(streamingKeyBytes,
+                             groupByServingInfo.groupByOps.streamingDataset,
+                             Some(groupByServingInfo.batchEndTsMillis)))
+              // no further aggregation is required - the value in KvStore is good as is
+              case Accuracy.SNAPSHOT => None
+            }
+            GroupByRequestMeta(groupByServingInfo, batchRequest, streamingRequestOpt, request.atMillis, context)
+          }
+        if (groupByRequestMetaTry.isFailure) {
+          request.context.foreach(_.increment("group_by_serving_info_failure.count"))
+        }
+        request -> groupByRequestMetaTry
+      }
+      .toSeq
     val allRequests: Seq[GetRequest] = groupByRequestToKvRequest.flatMap {
       case (_, Success(GroupByRequestMeta(_, batchRequest, streamingRequestOpt, _, _))) =>
         Some(batchRequest) ++ streamingRequestOpt
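Note on the new filter: the predicate keeps a request when the key map itself is null, when its values collection is null, or when at least one key value is non-null; only requests whose keys are all null are dropped before they reach the KV store. A minimal sketch of the predicate in isolation (the keep helper and the sample key maps below are illustrative, not part of this PR):

// Sketch only: the null-key filter from fetchGroupBys, isolated for clarity.
def keep(keys: Map[String, AnyRef]): Boolean =
  keys == null || keys.values == null || keys.values.exists(_ != null)

keep(null)                                  // true  - no key map at all, passed through
keep(Map("email" -> "a@b.com"))             // true  - has a non-null key value
keep(Map("email" -> null, "id" -> "42"))    // true  - at least one value is non-null
keep(Map("email" -> null))                  // false - all values null, skip the KV store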
@@ -435,28 +441,8 @@ class FetcherBase(kvStore: KVStore,
           case Right(keyMissingException) => {
             Map(keyMissingException.requestName + "_exception" -> keyMissingException.getMessage)
           }
-          case Left(PrefixedRequest(prefix, groupByRequest)) => {
-            responseMap
-              .getOrElse(groupByRequest,
-                         Failure(new IllegalStateException(
-                           s"Couldn't find a groupBy response for $groupByRequest in response map")))
-              .map { valueMap =>
-                if (valueMap != null) {
-                  valueMap.map { case (aggName, aggValue) => prefix + "_" + aggName -> aggValue }
-                } else {
-                  Map.empty[String, AnyRef]
-                }
-              }
-              // prefix feature names
-              .recover { // capture exception as a key
-                case ex: Throwable =>
-                  if (debug || Math.random() < 0.001) {
-                    logger.error(s"Failed to fetch $groupByRequest", ex)
-                  }
-                  Map(groupByRequest.name + "_exception" -> ex.traceString)
-              }
-              .get
-          }
+          case Left(PrefixedRequest(prefix, groupByRequest)) =>
+            parseGroupByResponse(prefix, groupByRequest, responseMap)
         }.toMap
       }
       joinValuesTry match {
@@ -476,6 +462,39 @@ class FetcherBase(kvStore: KVStore,
     }
   }

+  def parseGroupByResponse(prefix: String,
+                           groupByRequest: Request,
+                           responseMap: Map[Request, Try[Map[String, AnyRef]]]): Map[String, AnyRef] = {
+    // Group bys with all null keys won't be requested from the KV store and we don't expect a response.
+    val isRequiredRequest = groupByRequest.keys.values.exists(_ != null) || groupByRequest.keys.isEmpty
+
+    val response: Try[Map[String, AnyRef]] = responseMap.get(groupByRequest) match {
+      case Some(value) => value
+      case None =>
+        if (isRequiredRequest)
+          Failure(new IllegalStateException(s"Couldn't find a groupBy response for $groupByRequest in response map"))
+        else Success(null)
+    }
+
+    response
+      .map { valueMap =>
+        if (valueMap != null) {
+          valueMap.map { case (aggName, aggValue) => prefix + "_" + aggName -> aggValue }
+        } else {
+          Map.empty[String, AnyRef]
+        }
+      }
+      // prefix feature names
+      .recover { // capture exception as a key
+        case ex: Throwable =>
+          if (debug || Math.random() < 0.001) {
+            println(s"Failed to fetch $groupByRequest with \n${ex.traceString}")
+          }
+          Map(groupByRequest.name + "_exception" -> ex.traceString)
+      }
+      .get
+  }
+
   /**
    * Fetch method to simulate a random access interface for Chronon
    * by distributing requests to relevant GroupBys. This is a batch
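With parseGroupByResponse factored out, a missing response is no longer always an error: a group by whose keys were all null was filtered out upstream and maps to an empty feature map, while any other missing or failed response is captured as a <name>_exception entry instead of a thrown exception. A hedged sketch of how a caller might separate the two kinds of entries (splitFeatures is illustrative, not part of the codebase):

// Sketch only: split parseGroupByResponse output into feature values and captured errors,
// relying on the "_exception" suffix convention used in the recover block above.
def splitFeatures(parsed: Map[String, AnyRef]): (Map[String, AnyRef], Map[String, AnyRef]) =
  parsed.partition { case (name, _) => !name.endsWith("_exception") }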
48 changes: 47 additions & 1 deletion online/src/test/scala/ai/chronon/online/FetcherBaseTest.scala
@@ -28,7 +28,7 @@ import org.scalatestplus.mockito.MockitoSugar

 import scala.concurrent.duration.DurationInt
 import scala.concurrent.{Await, ExecutionContext, Future}
-import scala.util.{Failure, Success}
+import scala.util.{Failure, Success, Try}

 class FetcherBaseTest extends MockitoSugar with Matchers {
   val GroupBy = "relevance.short_term_user_features"
@@ -141,4 +141,50 @@ class FetcherBaseTest extends MockitoSugar with Matchers {
     actualRequest.get.name shouldBe query.groupByName + "." + query.columnName
     actualRequest.get.keys shouldBe query.keyMapping.get
   }
+
+  @Test
+  def testParsingGroupByResponse_HappyHase(): Unit = {
+    val baseFetcher = new FetcherBase(mock[KVStore])
+    val request = Request(name = "name", keys = Map("email" -> "email"), atMillis = None, context = None)
+    val response: Map[Request, Try[Map[String, AnyRef]]] = Map(
+      request -> Success(Map(
+        "key" -> "value"
+      ))
+    )
+
+    val result = baseFetcher.parseGroupByResponse("prefix", request, response)
+    result shouldBe Map("prefix_key" -> "value")
+  }
+
+  @Test
+  def testParsingGroupByResponse_NullKey(): Unit = {
+    val baseFetcher = new FetcherBase(mock[KVStore])
+    val request = Request(name = "name", keys = Map("email" -> null), atMillis = None, context = None)
+    val request2 = Request(name = "name2", keys = Map("email" -> null), atMillis = None, context = None)
+
+    val response: Map[Request, Try[Map[String, AnyRef]]] = Map(
+      request2 -> Success(Map(
+        "key" -> "value"
+      ))
+    )
+
+    val result = baseFetcher.parseGroupByResponse("prefix", request, response)
+    result shouldBe Map()
+  }
+
+  @Test
+  def testParsingGroupByResponse_MissingKey(): Unit = {
+    val baseFetcher = new FetcherBase(mock[KVStore])
+    val request = Request(name = "name", keys = Map("email" -> "email"), atMillis = None, context = None)
+    val request2 = Request(name = "name2", keys = Map("email" -> "email"), atMillis = None, context = None)
+
+    val response: Map[Request, Try[Map[String, AnyRef]]] = Map(
+      request2 -> Success(Map(
+        "key" -> "value"
+      ))
+    )
+
+    val result = baseFetcher.parseGroupByResponse("prefix", request, response)
+    result.keySet shouldBe Set("name_exception")
+  }
 }

Collaborator review comment on testParsingGroupByResponse_HappyHase: nit: s/HappyHase/HappyCase
7 changes: 5 additions & 2 deletions spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala
@@ -443,7 +443,9 @@ class FetcherTest extends TestCase {

     val listingEventData = Seq(
       Row(1L, toTs("2021-04-10 03:10:00"), "2021-04-10"),
-      Row(2L, toTs("2021-04-10 03:10:00"), "2021-04-10")
+      Row(2L, toTs("2021-04-10 03:10:00"), "2021-04-10"),
+      Row(2L, toTs("2021-04-10 03:10:00"), "2021-04-10"),
+      Row(null, toTs("2021-04-10 03:10:00"), "2021-04-10")
     )
     val ratingEventData = Seq(
       // 1L listing id event data
@@ -464,7 +466,8 @@ class FetcherTest extends TestCase {
       Row(2L, toTs("2021-04-10 02:30:00"), 5, "2021-04-10"),
       Row(2L, toTs("2021-04-10 02:30:00"), 8, "2021-04-10"),
       Row(2L, toTs("2021-04-10 02:30:00"), 8, "2021-04-10"),
-      Row(2L, toTs("2021-04-07 00:30:00"), 10, "2021-04-10") // dated 4/10 but excluded from avg agg based on ts
+      Row(2L, toTs("2021-04-07 00:30:00"), 10, "2021-04-10"), // dated 4/10 but excluded from avg agg based on ts
+      Row(null, toTs("2021-04-10 02:30:00"), 8, "2021-04-10")
     )
     // Schemas
     // {..., event (generic event column), ...}
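The Row(null, ...) entries added above give the end-to-end test events whose listing id is null, so the join produces a group-by request with an all-null key map. Under the new filter that request never reaches the KV store, and parseGroupByResponse returns an empty feature map for it instead of an exception entry. A quick cross-check against the predicate (the key name below is assumed for illustration, not taken from this diff):

// Hypothetical key map for the Row(null, ...) event; "listing" is an assumed key name.
val nullListingKeys: Map[String, AnyRef] = Map("listing" -> null)
nullListingKeys.values.exists(_ != null) // false, so fetchGroupBys filters the request out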