SPARKC-312: Implementing FilterOptimizer #1019

Open · wants to merge 7 commits into master · Changes from 4 commits
doc/reference.md (6 additions, 0 deletions)
@@ -92,6 +92,12 @@ may also be used. ("127.0.0.1,192.168.0.1")

<table class="table">
<tr><th>Property Name</th><th>Default</th><th>Description</th></tr>
<tr>
<td><code>sql.enable.where.clause.optimization</code></td>
<td>false</td>
<td>The connector will try to optimize the SQL query's `where` clause to
increase the number of filters that can be pushed down. Experimental.</td>
</tr>
<tr>
<td><code>sql.pushdown.additionalClasses</code></td>
<td></td>
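For context, a minimal sketch of enabling the new property for a single read through reader options, mirroring the integration test below (keyspace and table names are placeholders):

    // Hypothetical usage sketch: enable the experimental where-clause
    // optimization for one DataFrame read.
    val df = sqlContext.read
      .format("org.apache.spark.sql.cassandra")
      .options(Map(
        "keyspace" -> "test",
        "table" -> "metrics",
        "pushdown" -> "true",
        "spark.cassandra.sql.enable.where.clause.optimization" -> "true"))
      .load()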
@@ -32,14 +32,22 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with
s"""CREATE TABLE IF NOT EXISTS $ks.fields
|(k INT, a TEXT, b TEXT, c TEXT, d TEXT, e TEXT, PRIMARY KEY (k)) """
.stripMargin)
},
Future {
session.execute(
s"""CREATE TABLE IF NOT EXISTS $ks.metrics
|(k TEXT, a INT, b INT, c INT, PRIMARY KEY (k, a, b)) """
.stripMargin)
}
)
}
}

val colorOptions = Map("keyspace" -> ks, "table" -> "colors")
val fieldsOptions = Map("keyspace" -> ks, "table" -> "fields")
val metricsOptions = Map("keyspace" -> ks, "table" -> "metrics")
val withPushdown = Map("pushdown" -> "true")
val withWhereClauseOptimizationEnabled = Map("spark.cassandra.sql.enable.where.clause.optimization" -> "true")
Contributor: Replace the string here with the parameter defined in the conf file,
EnableWhereClauseOptimizationParam.name, just in case we change things later :)
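
A sketch of the suggested change, assuming the spec imports EnableWhereClauseOptimizationParam from CassandraSourceRelation:

    val withWhereClauseOptimizationEnabled =
      Map(EnableWhereClauseOptimizationParam.name -> "true")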

val withoutPushdown = Map("pushdown" -> "false")

"CassandraPrunedFilteredScan" should "pushdown predicates for clustering keys" in {
@@ -74,6 +82,16 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with
cts.get.selectedColumnNames should contain theSameElementsAs Seq("b", "c", "d")
}

it should "optimize table scan if all filters can be pushed down" in {
val metricsDF = sqlContext.read.format(cassandraFormat).options(metricsOptions ++ withPushdown ++ withWhereClauseOptimizationEnabled).load()
val df = metricsDF.filter("a = 5 and (b > 5 or b < 3)")
val executionPlan = df.queryExecution.executedPlan
val cts = findAllCassandraTableScanRDD(executionPlan)
cts.nonEmpty shouldBe true
cts.head.where shouldBe CqlWhereClause(Seq(""""a" = ? AND "b" > ?"""), List(5, 5))
cts.last.where shouldBe CqlWhereClause(Seq(""""a" = ? AND "b" < ?"""), List(5, 3))
}

def findCassandraTableScanRDD(sparkPlan: SparkPlan): Option[CassandraTableScanRDD[_]] = {
def _findCassandraTableScanRDD(rdd: RDD[_]): Option[CassandraTableScanRDD[_]] = {
rdd match {
@@ -92,4 +110,22 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with
}
}

}
def findAllCassandraTableScanRDD(sparkPlan: SparkPlan): List[CassandraTableScanRDD[_]] = {
def _findAllCassandraTableScanRDD(rdd: RDD[_]): List[CassandraTableScanRDD[_]] = {
rdd match {
case ctsrdd: CassandraTableScanRDD[_] => List(ctsrdd)
case other: RDD[_] => other.dependencies.iterator
.flatMap(dep => _findAllCassandraTableScanRDD(dep.rdd)).toList
}
}

sparkPlan match {
case prdd: RDDScanExec => _findAllCassandraTableScanRDD(prdd.rdd)
case prdd: RowDataSourceScanExec => _findAllCassandraTableScanRDD(prdd.rdd)
case filter: FilterExec => findAllCassandraTableScanRDD(filter.child)
case wsc: WholeStageCodegenExec => findAllCassandraTableScanRDD(wsc.child)
case _ => List.empty
}
}

}
@@ -3,17 +3,7 @@ package org.apache.spark.sql.cassandra
import java.net.InetAddress
import java.util.UUID

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.cassandra.CassandraSQLRow.CassandraSQLRowReader
import org.apache.spark.sql.cassandra.DataTypeConverter._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SQLContext, sources}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.SparkConf

import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf, Schema}
import com.datastax.spark.connector.rdd.partitioner.CassandraPartitionGenerator._
import com.datastax.spark.connector.rdd.partitioner.DataSizeEstimates
import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.forSystemLocalPartitioner
import com.datastax.spark.connector.rdd.{CassandraRDD, ReadConf}
@@ -22,6 +12,14 @@ import com.datastax.spark.connector.util.Quote._
import com.datastax.spark.connector.util.{ConfigParameter, Logging, ReflectionUtil}
import com.datastax.spark.connector.writer.{SqlRowWriter, WriteConf}
import com.datastax.spark.connector.{SomeColumns, _}
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.cassandra.CassandraSQLRow.CassandraSQLRowReader
import org.apache.spark.sql.cassandra.DataTypeConverter._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SQLContext, sources}
import org.apache.spark.unsafe.types.UTF8String

/**
* Implements [[BaseRelation]], [[InsertableRelation]] and [[PrunedFilteredScan]]
@@ -34,6 +32,7 @@ private[cassandra] class CassandraSourceRelation(
userSpecifiedSchema: Option[StructType],
filterPushdown: Boolean,
tableSizeInBytes: Option[Long],
enableWhereClauseOptimization: Boolean,
connector: CassandraConnector,
readConf: ReadConf,
writeConf: WriteConf,
@@ -79,7 +78,15 @@
def buildScan(): RDD[Row] = baseRdd.asInstanceOf[RDD[Row]]

override def unhandledFilters(filters: Array[Filter]): Array[Filter] = filterPushdown match {
case true => predicatePushDown(filters).handledBySpark.toArray
case true =>
val optimizedFilters = FiltersOptimizer(filters).build()
Contributor: I think it may be better if we got the FiltersOptimizer into the
predicatePushDown function, then I think we could skip having it written in a
bunch of places.

val optimizationCanBeApplied = isOptimizationAvailable(optimizedFilters)
if (optimizationCanBeApplied) {
// every optimized filter set leaves the same filters for Spark, so take the first one
predicatePushDown(optimizedFilters.head).handledBySpark.toArray
} else {
predicatePushDown(filters).handledBySpark.toArray
}
case false => filters
}

@@ -125,12 +132,29 @@
finalPushdown
}

private def isOptimizationAvailable(optimizedFilters: List[Array[Filter]]): Boolean =
enableWhereClauseOptimization && optimizedFilters.size > 1 &&
optimizedFilters.sliding(2).forall { set =>
// check that the filters left for Spark are identical across the per-scan filter sets
predicatePushDown(set.head).handledBySpark == predicatePushDown(set.last).handledBySpark
}


override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
val prunedRdd = maybeSelect(baseRdd, requiredColumns)
val prunedFilteredRdd = {
if(filterPushdown) {
val pushdownFilters = predicatePushDown(filters).handledByCassandra.toArray
val filteredRdd = maybePushdownFilters(prunedRdd, pushdownFilters)
val optimizedFilters = FiltersOptimizer(filters).build()
val optimizationCanBeApplied = isOptimizationAvailable(optimizedFilters)
val filteredRdd = if (optimizationCanBeApplied) {
Contributor: I think this is a rather dangerous optimization sometimes, so I
think we should default to off. For example: a table where x < 3 OR x > 5, and
x ranges from 1 to 10000. Doing two scans here is probably much more expensive
than a single scan.

Contributor Author (@ponkin, Nov 19, 2016): I thought that both scans would be
done in parallel. By default this option is set to false.

optimizedFilters.map { predicate =>
val pushdownFilters = predicatePushDown(predicate).handledByCassandra.toArray
maybePushdownFilters(prunedRdd, pushdownFilters).asInstanceOf[RDD[Row]]
}.reduce(_ union _)
} else {
val pushdownFilters = predicatePushDown(filters).handledByCassandra.toArray
maybePushdownFilters(prunedRdd, pushdownFilters)
}
filteredRdd.asInstanceOf[RDD[Row]]
} else {
prunedRdd
@@ -229,9 +253,19 @@ object CassandraSourceRelation {
""".stripMargin
)

val EnableWhereClauseOptimizationParam = ConfigParameter[Boolean](
name = "spark.cassandra.sql.enable.where.clause.optimization",
section = ReferenceSection,
default = false,
description =
"""Connector will try to optimize sql query `where`-clause, to increase
| number of filters that can be pushed down. Experimental.""".stripMargin
)

val Properties = Seq(
AdditionalCassandraPushDownRulesParam,
TableSizeInBytesParam
TableSizeInBytesParam,
EnableWhereClauseOptimizationParam
)

val defaultClusterName = "default"
@@ -247,6 +281,9 @@
val conf =
consolidateConfs(sparkConf, sqlConf, tableRef, options.cassandraConfs)
val tableSizeInBytesString = conf.getOption(TableSizeInBytesParam.name)
val enableWhereClauseOptimization =
conf.getOption(EnableWhereClauseOptimizationParam.name)
.map(_.equalsIgnoreCase("true")).getOrElse(false)
val cassandraConnector =
new CassandraConnector(CassandraConnectorConf(conf))
val tableSizeInBytes = tableSizeInBytesString match {
@@ -272,6 +309,7 @@
userSpecifiedSchema = schema,
filterPushdown = options.pushdown,
tableSizeInBytes = tableSizeInBytes,
enableWhereClauseOptimization = enableWhereClauseOptimization,
connector = cassandraConnector,
readConf = readConf,
writeConf = writeConf,
@@ -0,0 +1,106 @@
package org.apache.spark.sql.cassandra

import org.apache.spark.sql.sources._


/**
 * The optimizer tries to transform a pushdown filter into a `sum of products`,
 * so that a filter like
 * '(field1 < 3 OR field1 > 7) AND (field2 = 'val1' OR field2 = 'val2')'
 * becomes the equivalent
 * 'field1 < 3 AND field2 = "val1" OR field1 < 3 AND field2 = "val2" OR
 * field1 > 7 AND field2 = "val1" OR field1 > 7 AND field2 = "val2"'
 *
 * @param filters an array of logical statements [[org.apache.spark.sql.sources.Filter]]
 *                that form a `where` clause joined by the `AND` operator, for example:
 *                val Array(f1, f2, ... fn) = ... // such that `where f1 AND f2 AND ... AND fn`
 */
class FiltersOptimizer(filters: Array[Filter]) {
Contributor: I'm a little confused why there is a separate class here, do we
ever use this without calling .build() immediately after?


private val fullFilterAst =
if (filters.nonEmpty) Some(filters.reduce((left, right) => And(left, right))) else None

import FiltersOptimizer._

def build(): List[Array[Filter]] = fullFilterAst match {
case Some(ast) => (toNNF andThen toDNF andThen traverse andThen groupByAnd).apply(ast)
case None => List.empty
}

}
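
To make the transformation concrete, a sketch of what build() returns for the example in the class doc above:

    import org.apache.spark.sql.sources._

    val filters: Array[Filter] = Array(
      Or(LessThan("field1", 3), GreaterThan("field1", 7)),
      Or(EqualTo("field2", "val1"), EqualTo("field2", "val2")))

    // One Array[Filter] per conjunctive clause of the resulting DNF:
    // List(Array(LessThan(field1,3),    EqualTo(field2,val1)),
    //      Array(LessThan(field1,3),    EqualTo(field2,val2)),
    //      Array(GreaterThan(field1,7), EqualTo(field2,val1)),
    //      Array(GreaterThan(field1,7), EqualTo(field2,val2)))
    val products: List[Array[Filter]] = FiltersOptimizer(filters).build()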

object FiltersOptimizer {

def apply(filters: Array[Filter]): FiltersOptimizer = new FiltersOptimizer(filters)

private[cassandra] def dist(predL: Filter, predR: Filter): Filter = (predL, predR) match {
case (Or(l, r), p) => Or(dist(l, p), dist(r, p))
case (p, Or(l, r)) => Or(dist(p, l), dist(p, r))
case (l, r) => And(l, r)
}

/** The 'toNNF' function converts expressions to negation normal form. This
* function is total: it's defined for all expressions, not just those which
* only use negation, conjunction and disjunction, although all expressions in
* negation normal form do in fact only use those connectives.
*
* Then de Morgan's laws are applied to convert negated
* conjunctions and disjunctions into the conjunction or disjunction of the
* negation of their conjuncts: ¬(φ ∧ ψ) is converted to (¬φ ∨ ¬ψ)
* while ¬(φ ∨ ψ) becomes (¬φ ∧ ¬ψ).
*/
private[cassandra] val toNNF: Filter => Filter = {
case a@(EqualTo(_, _) | EqualNullSafe(_, _) | GreaterThan(_, _) |
GreaterThanOrEqual(_, _) | LessThan(_, _) | LessThanOrEqual(_, _) |
In(_, _) | IsNull(_) | IsNotNull(_) |
StringStartsWith(_, _) | StringEndsWith(_, _) | StringContains(_, _)) => a
case a@Not(EqualTo(_, _) | EqualNullSafe(_, _) | In(_, _) |
StringStartsWith(_, _) | StringEndsWith(_, _) | StringContains(_, _)) => a
case Not(GreaterThan(a, v)) => LessThanOrEqual(a, v)
case Not(LessThanOrEqual(a, v)) => GreaterThan(a, v)
case Not(LessThan(a, v)) => GreaterThanOrEqual(a, v)
case Not(GreaterThanOrEqual(a, v)) => LessThan(a, v)
case Not(IsNull(a)) => IsNotNull(a)
case Not(IsNotNull(a)) => IsNull(a)
case Not(Not(p)) => p
case And(l, r) => And(toNNF(l), toNNF(r))
case Not(And(l, r)) => toNNF(Or(Not(l), Not(r)))
case Or(l, r) => Or(toNNF(l), toNNF(r))
case Not(Or(l, r)) => toNNF(And(Not(l), Not(r)))
case p => p
}
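
A quick illustration of the negation rewriting (toNNF is private[cassandra], so this is callable only from within the org.apache.spark.sql.cassandra package):

    // De Morgan plus comparison flipping:
    // Not(And(x > 5, IsNull(y)))  ==>  Or(x <= 5, IsNotNull(y))
    FiltersOptimizer.toNNF(Not(And(GreaterThan("x", 5), IsNull("y"))))
    // == Or(LessThanOrEqual("x", 5), IsNotNull("y"))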

/** The 'toDNF' function converts expressions to disjunctive normal form: a
* disjunction of clauses, where a clause is a conjunction of literals
* (variables and negated variables).
*
* The conversion is carried out by first converting the expression into
* negation normal form, and then applying the distributive law.
*/
private[cassandra] val toDNF: Filter => Filter = {
case And(l, r) => dist(toDNF(l), toDNF(r))
case Or(l, r) => Or(toDNF(l), toDNF(r))
case p => p
}
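
And the distributive step in isolation, assuming the input is already in negation normal form:

    // dist pushes And over Or:
    // And(Or(a, b), c)  ==>  Or(And(a, c), And(b, c))
    FiltersOptimizer.toDNF(And(Or(IsNull("a"), IsNull("b")), IsNull("c")))
    // == Or(And(IsNull("a"), IsNull("c")), And(IsNull("b"), IsNull("c")))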

/**
* Traverse over disjunctive clauses of AST
*/
private[cassandra] val traverse: Filter => List[Filter] = {
case Or(l, r) => traverse(l) ++ traverse(r)
case a => a :: Nil
}

/**
* Group all conjunctive clauses into Array[Filter]
* f1 && f2 && ... && fn => Array(f1, f2, ... fn)
*/
private[cassandra] val andToArray: Filter => Array[Filter] = {
case And(l, r) => andToArray(l) ++ andToArray(r)
case a => Array(a)
}

private[cassandra] val groupByAnd: List[Filter] => List[Array[Filter]] = _.map(andToArray)
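
For completeness, the flattening step on a small conjunction tree (again, same-package access assumed):

    // Left-to-right flattening of nested Ands:
    FiltersOptimizer.andToArray(And(And(IsNull("f1"), IsNull("f2")), IsNull("f3")))
    // == Array(IsNull("f1"), IsNull("f2"), IsNull("f3"))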

}
@@ -0,0 +1,63 @@
package org.apache.spark.sql.cassandra

import org.apache.spark.sql.sources._

import org.scalacheck._
import org.scalatest.prop.PropertyChecks
import org.scalatest.{FlatSpec, ShouldMatchers}

class FiltersOptimizerCheck extends FlatSpec with PropertyChecks with ShouldMatchers {

// For testing purposes
case object True extends Filter
case object False extends Filter

val genFullTree = for {
size <- Gen.choose(0, 500)
tree <- genTree(size)
} yield tree

def genTree(maxDepth: Int): Gen[Filter] =
if (maxDepth == 0) leaf else Gen.oneOf(leaf, genAnd(maxDepth), genOr(maxDepth), genNot(maxDepth))

def genAnd(maxDepth: Int): Gen[Filter] = for {
depthL <- Gen.choose(0, maxDepth - 1)
depthR <- Gen.choose(0, maxDepth - 1)
left <- genTree(depthL)
right <- genTree(depthR)
} yield And(left, right)

def genOr(maxDepth: Int): Gen[Filter] = for {
depthL <- Gen.choose(0, maxDepth - 1)
depthR <- Gen.choose(0, maxDepth - 1)
left <- genTree(depthL)
right <- genTree(depthR)
} yield Or(left, right)

def genNot(maxDepth: Int): Gen[Filter] = for {
depth <- Gen.choose(0, maxDepth - 1)
expr <- genTree(depth)
} yield Not(expr)

def leaf: Gen[Filter] = Gen.oneOf(True, False)

/**
 * Evaluate a logical filter tree of And/Or/Not over the True/False test leaves
 */
private def eval(clause: Filter): Boolean = clause match {
case And(left, right) => eval(left) && eval(right)
case Or(left, right) => eval(left) || eval(right)
case Not(predicate) => !eval(predicate)
case True => true
case False => false
}

"FiltersOptimizer" should "generate equivalent disjunction normal form for arbitrary logical statement" in {
forAll(genFullTree){ expr =>
val dnf = (FiltersOptimizer.toNNF andThen FiltersOptimizer.toDNF).apply(expr)
assert(eval(dnf) == eval(expr))
}
}

}