SPARKC-312: Implementing FilterOptimizer #1019

Open · wants to merge 7 commits into master
6 changes: 6 additions & 0 deletions doc/reference.md
@@ -86,6 +86,12 @@ may also be used. ("127.0.0.1,192.168.0.1")

<table class="table">
<tr><th>Property Name</th><th>Default</th><th>Description</th></tr>
<tr>
<td><code>sql.enable.where.clause.optimization</code></td>
<td>false</td>
<td>The connector will try to optimize the SQL query's `where` clause to increase
the number of filters that can be pushed down. Experimental.</td>
</tr>
<tr>
<td><code>sql.pushdown.additionalClasses</code></td>
<td></td>
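For illustration, a minimal sketch of enabling the experimental optimization for a single DataFrame read; keyspace and table names are hypothetical, and the full property name (with the spark.cassandra prefix) is the one defined in CassandraSourceRelation below:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder.getOrCreate()
val metrics = spark.read
  .format("org.apache.spark.sql.cassandra")
  .options(Map(
    "keyspace" -> "test",
    "table" -> "metrics",
    "pushdown" -> "true",
    "spark.cassandra.sql.enable.where.clause.optimization" -> "true"))
  .load()
// With the flag enabled, a predicate such as "a = 5 and (b > 5 or b < 3)" may be
// rewritten so that each disjunct becomes its own fully pushed-down scan, and the
// scans are unioned.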
@@ -6,7 +6,14 @@ import com.datastax.spark.connector.embedded.YamlTransformations
import com.datastax.spark.connector.rdd.{CassandraTableScanRDD, CqlWhereClause}
import com.datastax.spark.connector.util.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution._
import org.apache.spark.sql.cassandra.CassandraSourceRelation
import org.apache.spark.sql.execution.{
FilterExec,
RDDScanExec,
RowDataSourceScanExec,
SparkPlan,
WholeStageCodegenExec
}

import scala.concurrent.Future

@@ -34,14 +41,23 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with
s"""CREATE TABLE IF NOT EXISTS $ks.fields
|(k INT, a TEXT, b TEXT, c TEXT, d TEXT, e TEXT, PRIMARY KEY (k)) """
.stripMargin)
},
Future {
session.execute(
s"""CREATE TABLE IF NOT EXISTS $ks.metrics
|(k TEXT, a INT, b INT, c INT, PRIMARY KEY (k, a, b)) """
.stripMargin)
}
)
}
}

val colorOptions = Map("keyspace" -> ks, "table" -> "colors")
val fieldsOptions = Map("keyspace" -> ks, "table" -> "fields")
val metricsOptions = Map("keyspace" -> ks, "table" -> "metrics")
val withPushdown = Map("pushdown" -> "true")
val withWhereClauseOptimizationEnabled =
Map(CassandraSourceRelation.EnableWhereClauseOptimizationParam.name -> "true")
val withoutPushdown = Map("pushdown" -> "false")

"CassandraPrunedFilteredScan" should "pushdown predicates for clustering keys" in {
@@ -76,6 +92,19 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with
cts.get.selectedColumnNames should contain theSameElementsAs Seq("b", "c", "d")
}

it should "optimize table scan if all filters can be pushed down" in {
val metricsDF = sparkSession.read
.format(cassandraFormat)
.options(metricsOptions ++ withPushdown ++ withWhereClauseOptimizationEnabled)
.load()
val df = metricsDF.filter("a = 5 and (b > 5 or b < 3)")
val executionPlan = df.queryExecution.executedPlan
val cts = findAllCassandraTableScanRDD(executionPlan)
cts.nonEmpty shouldBe true
cts.head.where shouldBe CqlWhereClause(Seq(""""a" = ? AND "b" > ?"""), List(5, 5))
cts.last.where shouldBe CqlWhereClause(Seq(""""a" = ? AND "b" < ?"""), List(5, 3))
}

def findCassandraTableScanRDD(sparkPlan: SparkPlan): Option[CassandraTableScanRDD[_]] = {
def _findCassandraTableScanRDD(rdd: RDD[_]): Option[CassandraTableScanRDD[_]] = {
rdd match {
@@ -94,4 +123,22 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with
}
}

def findAllCassandraTableScanRDD(sparkPlan: SparkPlan): List[CassandraTableScanRDD[_]] = {
def _findAllCassandraTableScanRDD(rdd: RDD[_]): List[CassandraTableScanRDD[_]] = {
rdd match {
case ctsrdd: CassandraTableScanRDD[_] => List(ctsrdd)
case other: RDD[_] => other.dependencies.iterator
.flatMap(dep => _findAllCassandraTableScanRDD(dep.rdd)).toList
}
}

sparkPlan match {
case prdd: RDDScanExec => _findAllCassandraTableScanRDD(prdd.rdd)
case prdd: RowDataSourceScanExec => _findAllCassandraTableScanRDD(prdd.rdd)
case filter: FilterExec => findAllCassandraTableScanRDD(filter.child)
case wsc: WholeStageCodegenExec => findAllCassandraTableScanRDD(wsc.child)
case _ => List.empty
}
}

}
@@ -3,17 +3,7 @@ package org.apache.spark.sql.cassandra
import java.net.InetAddress
import java.util.UUID

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.cassandra.CassandraSQLRow.CassandraSQLRowReader
import org.apache.spark.sql.cassandra.DataTypeConverter._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SQLContext, sources}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.SparkConf

import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf, Schema}
import com.datastax.spark.connector.rdd.partitioner.CassandraPartitionGenerator._
import com.datastax.spark.connector.rdd.partitioner.DataSizeEstimates
import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.forSystemLocalPartitioner
import com.datastax.spark.connector.rdd.{CassandraRDD, ReadConf}
@@ -22,6 +12,14 @@ import com.datastax.spark.connector.util.Quote._
import com.datastax.spark.connector.util.{ConfigParameter, Logging, ReflectionUtil}
import com.datastax.spark.connector.writer.{SqlRowWriter, WriteConf}
import com.datastax.spark.connector.{SomeColumns, _}
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.cassandra.CassandraSQLRow.CassandraSQLRowReader
import org.apache.spark.sql.cassandra.DataTypeConverter._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SQLContext, sources}
import org.apache.spark.unsafe.types.UTF8String

/**
* Implements [[BaseRelation]], [[InsertableRelation]] and [[PrunedFilteredScan]]
@@ -34,6 +32,7 @@ private[cassandra] class CassandraSourceRelation(
userSpecifiedSchema: Option[StructType],
filterPushdown: Boolean,
tableSizeInBytes: Option[Long],
enableWhereClauseOptimization: Boolean,
connector: CassandraConnector,
readConf: ReadConf,
writeConf: WriteConf,
@@ -79,7 +78,7 @@ private[cassandra] class CassandraSourceRelation(
def buildScan(): RDD[Row] = baseRdd.asInstanceOf[RDD[Row]]

override def unhandledFilters(filters: Array[Filter]): Array[Filter] = filterPushdown match {
case true => predicatePushDown(filters).handledBySpark.toArray
case true => analyzePredicates(filters).head.handledBySpark.toArray
case false => filters
}

@@ -127,13 +126,32 @@ private[cassandra] class CassandraSourceRelation(
finalPushdown
}

private def analyzePredicates(filters: Array[Filter]): List[AnalyzedPredicates] = {
if (enableWhereClauseOptimization) {
// Rewrite the filters into disjunctive normal form and push each conjunctive
// group down separately, but only when every group leaves the same set of
// filters for Spark to evaluate; otherwise fall back to plain pushdown.
val optimizedFilters = FiltersOptimizer.build(filters)
val partitions = optimizedFilters.map(predicatePushDown)
val allHandledBySparkAreTheSame = partitions.map(_.handledBySpark).sliding(2).forall { tuple =>
tuple.head == tuple.last
}
if (partitions.nonEmpty && allHandledBySparkAreTheSame) {
partitions
} else {
List(predicatePushDown(filters))
}
} else {
List(predicatePushDown(filters))
}
}

override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
val prunedRdd = maybeSelect(baseRdd, requiredColumns)
val prunedFilteredRdd = {
if(filterPushdown) {
val pushdownFilters = predicatePushDown(filters).handledByCassandra.toArray
val filteredRdd = maybePushdownFilters(prunedRdd, pushdownFilters)
filteredRdd.asInstanceOf[RDD[Row]]
val predicateGroups = analyzePredicates(filters)
predicateGroups.map { predicates =>
val pushdownFilters = predicates.handledByCassandra.toArray
maybePushdownFilters(prunedRdd, pushdownFilters).asInstanceOf[RDD[Row]]
}.reduce(_ union _)
} else {
prunedRdd
}
@@ -231,9 +249,19 @@ object CassandraSourceRelation {
""".stripMargin
)

val EnableWhereClauseOptimizationParam = ConfigParameter[Boolean](
name = "spark.cassandra.sql.enable.where.clause.optimization",
section = ReferenceSection,
default = false,
description =
"""The connector will try to optimize the SQL query's `where` clause to increase
| the number of filters that can be pushed down. Experimental.""".stripMargin
)

val Properties = Seq(
AdditionalCassandraPushDownRulesParam,
TableSizeInBytesParam
TableSizeInBytesParam,
EnableWhereClauseOptimizationParam
)

val defaultClusterName = "default"
@@ -249,6 +277,9 @@ object CassandraSourceRelation {
val conf =
consolidateConfs(sparkConf, sqlConf, tableRef, options.cassandraConfs)
val tableSizeInBytesString = conf.getOption(TableSizeInBytesParam.name)
val enableWhereClauseOptimization =
conf.getOption(EnableWhereClauseOptimizationParam.name)
.map(_.equalsIgnoreCase("true")).getOrElse(false)
val cassandraConnector =
new CassandraConnector(CassandraConnectorConf(conf))
val tableSizeInBytes = tableSizeInBytesString match {
@@ -274,6 +305,7 @@ object CassandraSourceRelation {
userSpecifiedSchema = schema,
filterPushdown = options.pushdown,
tableSizeInBytes = tableSizeInBytes,
enableWhereClauseOptimization = enableWhereClauseOptimization,
connector = cassandraConnector,
readConf = readConf,
writeConf = writeConf,
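A hedged sketch of setting the same flag globally through the SparkConf; consolidateConfs above also reads the Spark configuration, so a session-wide setting should apply to every Cassandra-backed DataFrame unless overridden per read:

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.cassandra.sql.enable.where.clause.optimization", "true")
// Per-read options (as in the integration test above) should take precedence for a given table.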
@@ -0,0 +1,100 @@
package org.apache.spark.sql.cassandra

import org.apache.spark.sql.sources._

/**
* The optimizer tries to transform a pushdown filter into a `sum of products`
* (disjunctive normal form), so that a filter like
* '(field1 < 3 OR field1 > 7) AND (field2 = 'val1' OR field2 = 'val2')'
* becomes the equivalent
* 'field1 < 3 AND field2 = "val1" OR field1 < 3 AND field2 = "val2" OR
* field1 > 7 AND field2 = "val1" OR field1 > 7 AND field2 = "val2"'
*
*/
object FiltersOptimizer {

/**
* @param filters Array of logical statements [[org.apache.spark.sql.sources.Filter]]
* that form the `where` clause joined by the `AND` operator, for example:
* val Array(f1, f2, ... fn) = ... // such that `where f1 AND f2 AND ... AND fn`
* @return a list of conjunctive filter groups, one per disjunct of the resulting DNF
*/
def build(filters: Array[Filter]): List[Array[Filter]] = {
if (filters.nonEmpty) {
val ast = filters.reduce((left, right) => And(left, right))
(toNNF andThen toDNF andThen traverse andThen groupByAnd).apply(ast)
} else {
List.empty
}
}

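/**
* Applies the distributive law: conjunction is distributed over disjunction,
* so that And-ing two expressions already in DNF again yields a DNF expression.
*/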
private[cassandra] def dist(predL: Filter, predR: Filter): Filter = (predL, predR) match {
case (Or(l, r), p) => Or(dist(l, p), dist(r, p))
case (p, Or(l, r)) => Or(dist(p, l), dist(p, r))
case (l, r) => And(l, r)
}

/** The 'toNNF' function converts expressions to negation normal form. This
* function is total: it's defined for all expressions, not just those which
* only use negation, conjunction and disjunction, although all expressions in
* negation normal form do in fact only use those connectives.
*
* Then de Morgan's laws are applied to convert negated
* conjunctions and disjunctions into the conjunction or disjunction of the
* negation of their conjuncts: ¬(φ ∧ ψ) is converted to (¬φ ∨ ¬ψ)
* while ¬(φ ∨ ψ) becomes (¬φ ∧ ¬ψ).
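* For example, Not(And(IsNull("a"), EqualTo("b", 1))) is rewritten to
* Or(IsNotNull("a"), Not(EqualTo("b", 1))); the negated equality is kept as a
* literal because Spark's pushdown filters have no positive counterpart for it.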
*/
private[cassandra] val toNNF: Filter => Filter = {
case a@(EqualTo(_, _) | EqualNullSafe(_, _) | GreaterThan(_, _) |
GreaterThanOrEqual(_, _) | LessThan(_, _) | LessThanOrEqual(_, _) |
In(_, _) | IsNull(_) | IsNotNull(_) |
StringStartsWith(_, _) | StringEndsWith(_, _) | StringContains(_, _)) => a
case a@Not(EqualTo(_, _) | EqualNullSafe(_, _) | In(_, _) |
StringStartsWith(_, _) | StringEndsWith(_, _) | StringContains(_, _)) => a
case Not(GreaterThan(a, v)) => LessThanOrEqual(a, v)
case Not(LessThanOrEqual(a, v)) => GreaterThan(a, v)
case Not(LessThan(a, v)) => GreaterThanOrEqual(a, v)
case Not(GreaterThanOrEqual(a, v)) => LessThan(a, v)
case Not(IsNull(a)) => IsNotNull(a)
case Not(IsNotNull(a)) => IsNull(a)
case Not(Not(p)) => p
case And(l, r) => And(toNNF(l), toNNF(r))
case Not(And(l, r)) => toNNF(Or(Not(l), Not(r)))
case Or(l, r) => Or(toNNF(l), toNNF(r))
case Not(Or(l, r)) => toNNF(And(Not(l), Not(r)))
case p => p
}

/** The 'toDNF' function converts expressions to disjunctive normal form: a
* disjunction of clauses, where a clause is a conjunction of literals
* (variables and negated variables).
*
* The conversion is carried out by first converting the expression into
* negation normal form, and then applying the distributive law.
*/
private[cassandra] val toDNF: Filter => Filter = {
case And(l, r) => dist(toDNF(l), toDNF(r))
case Or(l, r) => Or(toDNF(l), toDNF(r))
case p => p
}

/**
* Traverses the AST and collects its top-level disjunctive clauses into a list
*/
private[cassandra] val traverse: Filter => List[Filter] = {
case Or(l, r) => traverse(l) ++ traverse(r)
case a => a :: Nil
}

/**
* Group all conjunctive clauses into Array[Filter]
* f1 && f2 && ... && fn => Array(f1, f2, ... fn)
*/
private[cassandra] val andToArray: Filter => Array[Filter] = {
case And(l, r) => andToArray(l) ++ andToArray(r)
case a => Array(a)
}

private[cassandra] val groupByAnd: List[Filter] => List[Array[Filter]] = _.map(andToArray)

}
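As an illustration of the pipeline above (toNNF, then toDNF, then traverse and groupByAnd), a hedged sketch of the groups build is expected to produce for the filter from the object's scaladoc; the ordering of the groups is an assumption:

import org.apache.spark.sql.sources._

// (field1 < 3 OR field1 > 7) AND (field2 = 'val1' OR field2 = 'val2')
val filters: Array[Filter] = Array(
  Or(LessThan("field1", 3), GreaterThan("field1", 7)),
  Or(EqualTo("field2", "val1"), EqualTo("field2", "val2")))

val groups: List[Array[Filter]] = FiltersOptimizer.build(filters)
// Expected: one conjunctive group per disjunct of the DNF, e.g.
//   Array(LessThan("field1", 3),    EqualTo("field2", "val1"))
//   Array(LessThan("field1", 3),    EqualTo("field2", "val2"))
//   Array(GreaterThan("field1", 7), EqualTo("field2", "val1"))
//   Array(GreaterThan("field1", 7), EqualTo("field2", "val2"))
// CassandraSourceRelation then issues one pushed-down scan per group and unions the results.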
@@ -0,0 +1,63 @@
package org.apache.spark.sql.cassandra

import org.apache.spark.sql.sources._

import org.scalacheck._
import org.scalacheck.Prop.forAll
import org.scalatest.prop.PropertyChecks
import org.scalatest.{FlatSpec, ShouldMatchers}

class FiltersOptimizerCheck extends FlatSpec with PropertyChecks with ShouldMatchers {

// For testing purposes
case object True extends Filter
case object False extends Filter

val genFullTree = for {
size <- Gen.choose(0, 500)
tree <- genTree(size)
} yield tree

def genTree(maxDepth: Int): Gen[Filter] =
if (maxDepth == 0) leaf else Gen.oneOf(leaf, genAnd(maxDepth), genOr(maxDepth), genNot(maxDepth))

def genAnd(maxDepth: Int): Gen[Filter] = for {
depthL <- Gen.choose(0, maxDepth - 1)
depthR <- Gen.choose(0, maxDepth - 1)
left <- genTree(depthL)
right <- genTree(depthR)
} yield And(left, right)

def genOr(maxDepth: Int): Gen[Filter] = for {
depthL <- Gen.choose(0, maxDepth - 1)
depthR <- Gen.choose(0, maxDepth - 1)
left <- genTree(depthL)
right <- genTree(depthR)
} yield Or(left, right)

def genNot(maxDepth: Int): Gen[Filter] = for {
depth <- Gen.choose(0, maxDepth - 1)
expr <- genTree(depth)
} yield Not(expr)

def leaf: Gen[Filter] = Gen.oneOf(True, False)

/**
* Evaluates the logical ADT to a Boolean
*/
private def eval(clause: Filter): Boolean = clause match {
case And(left, right) => eval(left) && eval(right)
case Or(left, right) => eval(left) || eval(right)
case Not(predicate) => !eval(predicate)
case True => true
case False => false
}

"FiltersOptimizer" should "generate an equivalent disjunctive normal form for an arbitrary logical statement" in {
forAll(genFullTree){ expr =>
val dnf = (FiltersOptimizer.toNNF andThen FiltersOptimizer.toDNF).apply(expr)
assert(eval(dnf) == eval(expr))
}
}

}