Updating Spark Migration Rules (#658)
- `new JavaSparkContext()` enclosing node
- spy SparkSession for Scala
- java: do not touch an existing `SparkSession.builder()`
- additional tests for spark_conf
- integration tests to ensure that all rules terminate
- correct paths for tests
- Fix GradientBoostedTrees
- fix infinite loop: accessing_execution_plan with a constrained filter
- try/except BaseException for each glob file
- future annotations for Python 3.8
dvmarcilio authored May 1, 2024
1 parent 715fca9 commit 1ce871f
Showing 25 changed files with 774 additions and 103 deletions.
10 changes: 10 additions & 0 deletions plugins/spark_upgrade/accessing_execution_plan.py
@@ -14,6 +14,7 @@

from polyglot_piranha import (
Rule,
Filter,
)

class AccessingExecutionPlan(ExecutePiranha):
@@ -34,6 +35,15 @@ def get_rules(self) -> List[Rule]:
replace_node="*",
replace="@dataframe.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].initialPlan",
holes={"queryExec", "execPlan"},
filters={Filter(
enclosing_node="(var_definition) @var_def",
not_contains=["""(
(field_expression
field: (identifier) @field_id
(#eq? @field_id "initialPlan")
) @field_expr
)"""],
)}
)
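# The filter constrains the rewrite to a `var_definition` that does not already
# contain an `initialPlan` field access. Because the replacement introduces
# `.initialPlan`, a rewritten definition no longer matches, which breaks the
# infinite loop noted in the commit message. Illustrative (hypothetical) Scala:
#   val plan = df.queryExecution.executedPlan                    // still rewritten
#   val plan = df.queryExecution.executedPlan
#     .asInstanceOf[AdaptiveSparkPlanExec].initialPlan           // skipped by the filter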
return [transform_IDFModel_args]

23 changes: 23 additions & 0 deletions plugins/spark_upgrade/gradient_boost_trees.py
@@ -14,8 +14,21 @@

from polyglot_piranha import (
Rule,
Filter,
)

_INSTANCE_EXPR_QUERY = """(
(instance_expression
(type_identifier) @typ_id
arguments: (arguments
(_)
(_)
(_)
)
(#eq? @typ_id "Instance")
) @inst
)"""


class GradientBoostTrees(ExecutePiranha):
def __init__(self, paths_to_codebase: List[str]):
@@ -44,6 +57,11 @@ def get_rules(self) -> List[Rule]:
@seed,
@featureSubsetStrategy
)""",
filters={
Filter(
not_contains=[_INSTANCE_EXPR_QUERY],
)
},
holes={"gbt"},
)

@@ -62,6 +80,11 @@ def get_rules(self) -> List[Rule]:
@seed,
@featureSubsetStrategy
)""",
filters={
Filter(
not_contains=[_INSTANCE_EXPR_QUERY],
)
},
holes={"gbt"},
)
return [gradient_boost_trees, gradient_boost_trees_comment]
152 changes: 152 additions & 0 deletions plugins/spark_upgrade/java_spark_context/__init__.py
@@ -0,0 +1,152 @@
# Copyright (c) 2024 Uber Technologies, Inc.

# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0

# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
from typing import Any, List, Dict
from execute_piranha import ExecutePiranha

from polyglot_piranha import (
execute_piranha,
Filter,
OutgoingEdges,
Rule,
PiranhaOutputSummary,
Match,
PiranhaArguments,
RuleGraph,
)

_JAVASPARKCONTEXT_OCE_QUERY = """(
(object_creation_expression
type: (_) @oce_typ
(#eq? @oce_typ "JavaSparkContext")
) @oce
)"""

_NEW_SPARK_CONF_CHAIN_QUERY = """(
(argument_list
.
(method_invocation) @mi
.
(#match? @mi "^new SparkConf()\\.")
)
)""" # matches a chain of method invocations starting with `new SparkConf().`; the chain is the only argument of an argument_list (indicated by the surrounding anchors `.`).

# Note that we don't remove the now-unused `SparkConf` import; that cleanup will be automated elsewhere.
_ADD_IMPORT_RULE = Rule(
name="add_import_rule",
query="""(
(program
(import_declaration) @imp_decl
)
)""", # matches the last import
replace_node="imp_decl",
replace="@imp_decl\nimport org.apache.spark.sql.SparkSession;",
is_seed_rule=False,
filters={
Filter( # avoids infinite loop
enclosing_node="((program) @unit)",
not_contains=[("cs import org.apache.spark.sql.SparkSession;")],
),
},
)


class JavaSparkContextChange(ExecutePiranha):
def __init__(self, paths_to_codebase: List[str], language: str = "java"):
super().__init__(
paths_to_codebase=paths_to_codebase,
substitutions={
"spark_conf": "SparkConf",
},
language=language,
)

def __call__(self) -> dict[str, bool]:
if self.language != "java":
return {}

piranha_args = self.get_piranha_arguments()
summaries: list[PiranhaOutputSummary] = execute_piranha(piranha_args)
assert summaries is not None

for summary in summaries:
file_path: str = summary.path
match: tuple[str, Match]
for match in summary.matches:
if match[0] == "java_match_rule":
matched_str = match[1].matched_string

replace_str = matched_str.replace(
"new SparkConf()",
'SparkSession.builder().config("spark.sql.legacy.allowUntypedScalaUDF", "true")',
)
replace_str = replace_str.replace(".setAppName(", ".appName(")
replace_str = replace_str.replace(".setMaster(", ".master(")
replace_str = replace_str.replace(".set(", ".config(")
replace_str += ".getOrCreate().sparkContext()"
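# Net effect of the string replacements above, on a hypothetical match:
#   new SparkConf().setAppName("app").setMaster("local[*]")
# becomes
#   SparkSession.builder().config("spark.sql.legacy.allowUntypedScalaUDF", "true")
#       .appName("app").master("local[*]").getOrCreate().sparkContext()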

# assumes that there's only one match in the file
rewrite_rule = Rule(
name="rewrite_rule",
query=_NEW_SPARK_CONF_CHAIN_QUERY,
replace_node="mi",
replace=replace_str,
filters={
Filter(enclosing_node=_JAVASPARKCONTEXT_OCE_QUERY),
},
)

rule_graph = RuleGraph(
rules=[rewrite_rule, _ADD_IMPORT_RULE],
edges=[
OutgoingEdges(
"rewrite_rule",
to=["add_import_rule"],
scope="File",
)
],
)
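# The File-scoped edge means `add_import_rule` (not a seed rule) only runs in
# files where `rewrite_rule` actually changed something, so the `SparkSession`
# import is added exactly where it becomes necessary.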
execute_piranha(
PiranhaArguments(
language=self.language,
rule_graph=rule_graph,
paths_to_codebase=[file_path],
)
)

if not summaries:
return {self.step_name(): False}

return {self.step_name(): True}

def step_name(self) -> str:
return "JavaSparkContext Change"

def get_rules(self) -> List[Rule]:
if self.language != "java":
return []

java_match_rule = Rule(
name="java_match_rule",
query=_NEW_SPARK_CONF_CHAIN_QUERY,
filters={
Filter(enclosing_node=_JAVASPARKCONTEXT_OCE_QUERY),
},
)

return [java_match_rule]

def get_edges(self) -> List[OutgoingEdges]:
return []

def summaries_to_custom_dict(self, _) -> Dict[str, Any]:
return {}
67 changes: 48 additions & 19 deletions plugins/spark_upgrade/main.py
@@ -8,10 +8,12 @@
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging

import glob

from update_calendar_interval import UpdateCalendarInterval
from IDF_model_signature_change import IDFModelSignatureChange
from accessing_execution_plan import AccessingExecutionPlan
@@ -20,6 +22,9 @@
from sql_new_execution import SQLNewExecutionChange
from query_test_check_answer_change import QueryTestCheckAnswerChange
from spark_config import SparkConfigChange
from java_spark_context import JavaSparkContextChange
from scala_session_builder import ScalaSessionBuilder


def _parse_args():
parser = argparse.ArgumentParser(
@@ -43,39 +48,63 @@ def _parse_args():
logging.basicConfig(format=FORMAT)
logging.getLogger().setLevel(logging.DEBUG)


def main():
args = _parse_args()
if args.new_version == "3.3":
upgrade_to_spark_3_3(args.path_to_codebase)


def upgrade_to_spark_3_3(path_to_codebase):
update_calendar_interval = UpdateCalendarInterval([path_to_codebase])
def upgrade_to_spark_3_3(path_to_codebase: str):
"""Wraps calls to Piranha with try/except to prevent it failing on a single file.
We catch `BaseException`, as pyo3 `PanicException` extends it."""
for scala_file in glob.glob(f"{path_to_codebase}/**/*.scala", recursive=True):
try:
update_file(scala_file)
except BaseException as e:
logging.error(f"Error running for file file {scala_file}: {e}")

for java_file in glob.glob(f"{path_to_codebase}/**/*.java", recursive=True):
try:
update_file(java_file)
except BaseException as e:
logging.error(f"Error running for file file {java_file}: {e}")


def update_file(file_path: str):
update_calendar_interval = UpdateCalendarInterval([file_path])
_ = update_calendar_interval()
idf_model_signature_change = IDFModelSignatureChange([path_to_codebase])

idf_model_signature_change = IDFModelSignatureChange([file_path])
_ = idf_model_signature_change()
accessing_execution_plan = AccessingExecutionPlan([path_to_codebase])

accessing_execution_plan = AccessingExecutionPlan([file_path])
_ = accessing_execution_plan()

gradient_boost_trees = GradientBoostTrees([path_to_codebase])
gradient_boost_trees = GradientBoostTrees([file_path])
_ = gradient_boost_trees()
calculator_signature_change = CalculatorSignatureChange([path_to_codebase])

calculator_signature_change = CalculatorSignatureChange([file_path])
_ = calculator_signature_change()
sql_new_execution = SQLNewExecutionChange([path_to_codebase])

sql_new_execution = SQLNewExecutionChange([file_path])
_ = sql_new_execution()
query_test_check_answer_change = QueryTestCheckAnswerChange([path_to_codebase])

query_test_check_answer_change = QueryTestCheckAnswerChange([file_path])
_ = query_test_check_answer_change()
spark_config = SparkConfigChange([path_to_codebase])

spark_config = SparkConfigChange([file_path])
_ = spark_config()
spark_config = SparkConfigChange([path_to_codebase], language="java")

spark_config = SparkConfigChange([file_path], language="java")
_ = spark_config()


javasparkcontext = JavaSparkContextChange([file_path], language="java")
_ = javasparkcontext()

scalasessionbuilder = ScalaSessionBuilder([file_path], language="scala")
_ = scalasessionbuilder()


if __name__ == "__main__":
main()
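# Example invocation (hypothetical path; flag names assumed to mirror the
# attributes `path_to_codebase` and `new_version` read in main()):
#   python plugins/spark_upgrade/main.py --path_to_codebase /path/to/spark/project --new_version 3.3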