diff --git a/plugins/spark_upgrade/accessing_execution_plan.py b/plugins/spark_upgrade/accessing_execution_plan.py index d96ed1c12..96ddbbc7b 100644 --- a/plugins/spark_upgrade/accessing_execution_plan.py +++ b/plugins/spark_upgrade/accessing_execution_plan.py @@ -14,6 +14,7 @@ from polyglot_piranha import ( Rule, + Filter, ) class AccessingExecutionPlan(ExecutePiranha): @@ -34,6 +35,15 @@ def get_rules(self) -> List[Rule]: replace_node="*", replace="@dataframe.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].initialPlan", holes={"queryExec", "execPlan"}, + filters={Filter( + enclosing_node="(var_definition) @var_def", + not_contains=["""( + (field_expression + field: (identifier) @field_id + (#eq? @field_id "initialPlan") + ) @field_expr + )"""], + )} ) return [transform_IDFModel_args] diff --git a/plugins/spark_upgrade/gradient_boost_trees.py b/plugins/spark_upgrade/gradient_boost_trees.py index a7e5eb13d..7df6cf307 100644 --- a/plugins/spark_upgrade/gradient_boost_trees.py +++ b/plugins/spark_upgrade/gradient_boost_trees.py @@ -14,8 +14,21 @@ from polyglot_piranha import ( Rule, + Filter, ) +_INSTANCE_EXPR_QUERY = """( + (instance_expression + (type_identifier) @typ_id + arguments: (arguments + (_) + (_) + (_) + ) + (#eq? @typ_id "Instance") + ) @inst +)""" + class GradientBoostTrees(ExecutePiranha): def __init__(self, paths_to_codebase: List[str]): @@ -44,6 +57,11 @@ def get_rules(self) -> List[Rule]: @seed, @featureSubsetStrategy )""", + filters={ + Filter( + not_contains=[_INSTANCE_EXPR_QUERY], + ) + }, holes={"gbt"}, ) @@ -62,6 +80,11 @@ def get_rules(self) -> List[Rule]: @seed, @featureSubsetStrategy )""", + filters={ + Filter( + not_contains=[_INSTANCE_EXPR_QUERY], + ) + }, holes={"gbt"}, ) return [gradient_boost_trees, gradient_boost_trees_comment] diff --git a/plugins/spark_upgrade/java_spark_context/__init__.py b/plugins/spark_upgrade/java_spark_context/__init__.py new file mode 100644 index 000000000..e86e8ef42 --- /dev/null +++ b/plugins/spark_upgrade/java_spark_context/__init__.py @@ -0,0 +1,152 @@ +# Copyright (c) 2024 Uber Technologies, Inc. + +#
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file +# except in compliance with the License. You may obtain a copy of the License at +#
http://www.apache.org/licenses/LICENSE-2.0 + +#
Unless required by applicable law or agreed to in writing, software distributed under the +# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from typing import Any, List, Dict +from execute_piranha import ExecutePiranha + +from polyglot_piranha import ( + execute_piranha, + Filter, + OutgoingEdges, + Rule, + PiranhaOutputSummary, + Match, + PiranhaArguments, + RuleGraph, +) + +_JAVASPARKCONTEXT_OCE_QUERY = """( + (object_creation_expression + type: (_) @oce_typ + (#eq? @oce_typ "JavaSparkContext") + ) @oce +)""" + +_NEW_SPARK_CONF_CHAIN_QUERY = """( + (argument_list + . + (method_invocation) @mi + . + (#match? @mi "^new SparkConf()\\.") + ) +)""" # matches a chain of method invocations starting with `new SparkConf().`; the chain is the only argument of an argument_list (indicated by the surrounding anchors `.`). + +# Note that we don't remove the unused `SparkConf` import; that will be automated somewhere else. +_ADD_IMPORT_RULE = Rule( + name="add_import_rule", + query="""( + (program + (import_declaration) @imp_decl + ) + )""", # matches the last import + replace_node="imp_decl", + replace="@imp_decl\nimport org.apache.spark.sql.SparkSession;", + is_seed_rule=False, + filters={ + Filter( # avoids infinite loop + enclosing_node="((program) @unit)", + not_contains=[("cs import org.apache.spark.sql.SparkSession;")], + ), + }, +) + + +class JavaSparkContextChange(ExecutePiranha): + def __init__(self, paths_to_codebase: List[str], language: str = "java"): + super().__init__( + paths_to_codebase=paths_to_codebase, + substitutions={ + "spark_conf": "SparkConf", + }, + language=language, + ) + + def __call__(self) -> dict[str, bool]: + if self.language != "java": + return {} + + piranha_args = self.get_piranha_arguments() + summaries: list[PiranhaOutputSummary] = execute_piranha(piranha_args) + assert summaries is not None + + for summary in summaries: + file_path: str = summary.path + match: tuple[str, Match] + for match in summary.matches: + if match[0] == "java_match_rule": + matched_str = match[1].matched_string + + replace_str = matched_str.replace( + "new SparkConf()", + 'SparkSession.builder().config("spark.sql.legacy.allowUntypedScalaUDF", "true")', + ) + replace_str = replace_str.replace(".setAppName(", ".appName(") + replace_str = replace_str.replace(".setMaster(", ".master(") + replace_str = replace_str.replace(".set(", ".config(") + replace_str += ".getOrCreate().sparkContext()" + + # assumes that there's only one match on the file + rewrite_rule = Rule( + name="rewrite_rule", + query=_NEW_SPARK_CONF_CHAIN_QUERY, + replace_node="mi", + replace=replace_str, + filters={ + Filter(enclosing_node=_JAVASPARKCONTEXT_OCE_QUERY), + }, + ) + + rule_graph = RuleGraph( + rules=[rewrite_rule, _ADD_IMPORT_RULE], + edges=[ + OutgoingEdges( + "rewrite_rule", + to=["add_import_rule"], + scope="File", + ) + ], + ) + execute_piranha( + PiranhaArguments( + language=self.language, + rule_graph=rule_graph, + paths_to_codebase=[file_path], + ) + ) + + if not summaries: + return {self.step_name(): False} + + return {self.step_name(): True} + + def step_name(self) -> str: + return "JavaSparkContext Change" + + def get_rules(self) -> List[Rule]: + if self.language != "java": + return [] + + java_match_rule = Rule( + name="java_match_rule", + query=_NEW_SPARK_CONF_CHAIN_QUERY, + filters={ + 
Filter(enclosing_node=_JAVASPARKCONTEXT_OCE_QUERY), }, ) + + return [java_match_rule] + + def get_edges(self) -> List[OutgoingEdges]: + return [] + + def summaries_to_custom_dict(self, _) -> Dict[str, Any]: + return {} diff --git a/plugins/spark_upgrade/main.py b/plugins/spark_upgrade/main.py index a91c95b7c..1e39d4957 100644 --- a/plugins/spark_upgrade/main.py +++ b/plugins/spark_upgrade/main.py @@ -8,10 +8,12 @@ # License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing permissions and # limitations under the License. - + import argparse import logging +import glob + from update_calendar_interval import UpdateCalendarInterval from IDF_model_signature_change import IDFModelSignatureChange from accessing_execution_plan import AccessingExecutionPlan @@ -20,6 +22,9 @@ from sql_new_execution import SQLNewExecutionChange from query_test_check_answer_change import QueryTestCheckAnswerChange from spark_config import SparkConfigChange +from java_spark_context import JavaSparkContextChange +from scala_session_builder import ScalaSessionBuilder + def _parse_args(): parser = argparse.ArgumentParser( @@ -43,39 +48,63 @@ def _parse_args(): logging.basicConfig(format=FORMAT) logging.getLogger().setLevel(logging.DEBUG) + def main(): args = _parse_args() if args.new_version == "3.3": upgrade_to_spark_3_3(args.path_to_codebase) -def upgrade_to_spark_3_3(path_to_codebase): - update_calendar_interval = UpdateCalendarInterval([path_to_codebase]) +def upgrade_to_spark_3_3(path_to_codebase: str): + """Wraps each call to Piranha in try/except so that a failure on a single file does not abort the whole upgrade. + We catch `BaseException`, as pyo3's `PanicException` extends it.""" + for scala_file in glob.glob(f"{path_to_codebase}/**/*.scala", recursive=True): + try: + update_file(scala_file) + except BaseException as e: + logging.error(f"Error running for file {scala_file}: {e}") + + for java_file in glob.glob(f"{path_to_codebase}/**/*.java", recursive=True): + try: + update_file(java_file) + except BaseException as e: + logging.error(f"Error running for file {java_file}: {e}") + + +def update_file(file_path: str): + update_calendar_interval = UpdateCalendarInterval([file_path]) _ = update_calendar_interval() - - idf_model_signature_change = IDFModelSignatureChange([path_to_codebase]) + + idf_model_signature_change = IDFModelSignatureChange([file_path]) _ = idf_model_signature_change() - - accessing_execution_plan = AccessingExecutionPlan([path_to_codebase]) + + accessing_execution_plan = AccessingExecutionPlan([file_path]) _ = accessing_execution_plan() - gradient_boost_trees = GradientBoostTrees([path_to_codebase]) + gradient_boost_trees = GradientBoostTrees([file_path]) _ = gradient_boost_trees() - - calculator_signature_change = CalculatorSignatureChange([path_to_codebase]) + + calculator_signature_change = CalculatorSignatureChange([file_path]) _ = calculator_signature_change() - - sql_new_execution = SQLNewExecutionChange([path_to_codebase]) + + sql_new_execution = SQLNewExecutionChange([file_path]) _ = sql_new_execution() - - query_test_check_answer_change = QueryTestCheckAnswerChange([path_to_codebase]) + + query_test_check_answer_change = QueryTestCheckAnswerChange([file_path]) _ = query_test_check_answer_change() - - spark_config = SparkConfigChange([path_to_codebase]) + + spark_config = SparkConfigChange([file_path]) _ = spark_config() - - spark_config = SparkConfigChange([path_to_codebase], language="java")
+ + spark_config = SparkConfigChange([file_path], language="java") _ = spark_config() - + + javasparkcontext = JavaSparkContextChange([file_path], language="java") + _ = javasparkcontext() + + scalasessionbuilder = ScalaSessionBuilder([file_path], language="scala") + _ = scalasessionbuilder() + + if __name__ == "__main__": main() diff --git a/plugins/spark_upgrade/scala_session_builder/__init__.py b/plugins/spark_upgrade/scala_session_builder/__init__.py new file mode 100644 index 000000000..4a31f2395 --- /dev/null +++ b/plugins/spark_upgrade/scala_session_builder/__init__.py @@ -0,0 +1,164 @@ +# Copyright (c) 2024 Uber Technologies, Inc. + +#
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file +# except in compliance with the License. You may obtain a copy of the License at +#
http://www.apache.org/licenses/LICENSE-2.0 + +#
Unless required by applicable law or agreed to in writing, software distributed under the +# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from typing import Any, List, Dict +from execute_piranha import ExecutePiranha + +from polyglot_piranha import ( + execute_piranha, + Filter, + OutgoingEdges, + Rule, + PiranhaOutputSummary, + Match, + Edit, + PiranhaArguments, + RuleGraph, +) + +VAL_DEF_QUERY = """( + (val_definition + pattern: (identifier) @val_id + type: (type_identifier) @type_id + value: (call_expression + function: (identifier) @func_call + ) + (#eq? @type_id "SparkSession") + (#eq? @func_call "spy") + ) @val_def +)""" + +QUERY = f"""( + (function_definition + body: (block + {VAL_DEF_QUERY} + . + (call_expression + function: (field_expression + value: (field_expression + value: (identifier) @lhs + field: (identifier) @rhs + ) + field: (identifier) @call_name + (#eq? @lhs @val_id) + (#eq? @rhs "sqlContext") + (#eq? @call_name "setConf") + ) + )+ @calls + ) + ) @func_def +)""" + + +class ScalaSessionBuilder(ExecutePiranha): + def __init__(self, paths_to_codebase: List[str], language: str = "scala"): + super().__init__( + paths_to_codebase=paths_to_codebase, + substitutions={}, + language=language, + ) + + def __call__(self) -> dict[str, bool]: + if self.language != "scala": + return {} + + piranha_args = self.get_piranha_arguments() + summaries: list[PiranhaOutputSummary] = execute_piranha(piranha_args) + assert summaries is not None + + for summary in summaries: + file_path: str = summary.path + edit: Edit + if len(summary.rewrites) == 0: + continue + + print(f"rewrites: {len(summary.rewrites)}") + + calls_to_add_str = "" + # the rewrite's edit will have `calls` with all matches + edit = summary.rewrites[0] + if edit.matched_rule == "delete_calls_query": + match: Match = edit.p_match + val_id = match.matches["val_id"] + calls = match.matches["calls"] + print(f"calls: {calls}") + calls_to_add_str = calls.replace( + f"{val_id}.sqlContext.setConf", ".config" + ) + + match = summary.rewrites[0].p_match + val_def = match.matches["val_def"] + + assert isinstance(val_def, str) + assert "getOrCreate()" in val_def + + replace_str = calls_to_add_str + "\n.getOrCreate()" + new_val_def = val_def.replace(".getOrCreate()", replace_str) + + replace_val_def_rule = Rule( + name="replace_val_def_rule", + query=VAL_DEF_QUERY, + replace_node="val_def", + replace=new_val_def, + filters={ + Filter( + enclosing_node="(val_definition) @_vl_def", + not_contains=( + [ + """( + (identifier) @conf_id + (#eq? 
@conf_id "config") + )""" + ] + ), + ) + }, + ) + + rule_graph = RuleGraph( + rules=[replace_val_def_rule], + edges=[], + ) + execute_piranha( + PiranhaArguments( + language=self.language, + rule_graph=rule_graph, + paths_to_codebase=[file_path], + ) + ) + + if not summaries: + return {self.step_name(): False} + + return {self.step_name(): True} + + def step_name(self) -> str: + return "Spark spy SessionBuilder" + + def get_rules(self) -> List[Rule]: + if self.language != "scala": + return [] + + delete_calls_query = Rule( + name="delete_calls_query", + query=QUERY, + replace_node="calls", + replace="", + ) + + return [delete_calls_query] + + def get_edges(self) -> List[OutgoingEdges]: + return [] + + def summaries_to_custom_dict(self, _) -> Dict[str, Any]: + return {} diff --git a/plugins/spark_upgrade/spark_config/__init__.py b/plugins/spark_upgrade/spark_config/__init__.py index aaa86b6b8..b2d1b5cd6 100644 --- a/plugins/spark_upgrade/spark_config/__init__.py +++ b/plugins/spark_upgrade/spark_config/__init__.py @@ -18,6 +18,39 @@ Rule, ) +_JAVASPARKCONTEXT_OCE_QUERY = """( + (object_creation_expression + type: (_) @oce_typ + (#eq? @oce_typ "JavaSparkContext") + ) @oce +)""" + +_SPARK_SESSION_BUILDER_CHAIN_QUERY = """( + (method_invocation + object: (method_invocation + object: (identifier) @spark_session + name: (identifier) @receiver + ) + (#eq? @spark_session "SparkSession") + (#eq? @receiver "builder") + ) @mi +)""" + +_EXPR_STMT_CHAIN_ENDS_WITH_GETORCREATE_QUERY = """( + (expression_statement + (method_invocation + name: (identifier) @last + ) + (#match? @last "getOrCreate") + ) @expr_stmt +)""" + +_SCALA_CHAIN_ENDS_WITH_GETORCREATE_QUERY = """( + (field_expression + field: (identifier) @last_field + (#eq? @last_field "getOrCreate") + ) @field_expr +)""" class SparkConfigChange(ExecutePiranha): def __init__(self, paths_to_codebase: List[str], language: str = "scala"): @@ -33,28 +66,48 @@ def step_name(self) -> str: return "Spark Config Change" def get_rules(self) -> List[Rule]: + # filters cannot be added without reinstantiating Rule(), so we create the full filter set before + fs = { + Filter( + not_enclosing_node='cs new SparkConf().set("spark.sql.legacy.timeParserPolicy","LEGACY").set("spark.sql.legacy.allowUntypedScalaUDF", "true")' + ), + } + if self.language == "java": + fs.add(Filter(not_enclosing_node=_JAVASPARKCONTEXT_OCE_QUERY)) + fs.add( + Filter(not_enclosing_node=_EXPR_STMT_CHAIN_ENDS_WITH_GETORCREATE_QUERY) + ) + elif self.language == "scala": + fs.add(Filter(not_enclosing_node=_SCALA_CHAIN_ENDS_WITH_GETORCREATE_QUERY)) + update_spark_conf_init = Rule( name="update_spark_conf_init", query="cs new SparkConf()", replace_node="*", replace='new SparkConf().set("spark.sql.legacy.timeParserPolicy","LEGACY").set("spark.sql.legacy.allowUntypedScalaUDF", "true")', - filters={ - Filter( - not_enclosing_node='cs new SparkConf().set("spark.sql.legacy.timeParserPolicy","LEGACY").set("spark.sql.legacy.allowUntypedScalaUDF", "true")' - ) - }, + filters=fs, ) + fs2 = { + Filter( + not_enclosing_node='cs SparkSession.builder().config("spark.sql.legacy.timeParserPolicy","LEGACY").config("spark.sql.legacy.allowUntypedScalaUDF", "true")' + ) + } + if self.language == "java": + fs2.add(Filter(not_enclosing_node=_JAVASPARKCONTEXT_OCE_QUERY)) + fs2.add(Filter(not_enclosing_node=_SPARK_SESSION_BUILDER_CHAIN_QUERY)) + fs2.add( + Filter(not_enclosing_node=_EXPR_STMT_CHAIN_ENDS_WITH_GETORCREATE_QUERY) + ) + elif self.language == "scala": + 
fs2.add(Filter(not_enclosing_node=_SCALA_CHAIN_ENDS_WITH_GETORCREATE_QUERY)) + update_spark_session_builder_init = Rule( name="update_spark_conf_init", query="cs SparkSession.builder()", replace_node="*", replace='SparkSession.builder().config("spark.sql.legacy.timeParserPolicy","LEGACY").config("spark.sql.legacy.allowUntypedScalaUDF", "true")', - filters={ - Filter( - not_enclosing_node='cs SparkSession.builder().config("spark.sql.legacy.timeParserPolicy","LEGACY").config("spark.sql.legacy.allowUntypedScalaUDF", "true")' - ) - }, + filters=fs2, ) update_import_array_queue = Rule( diff --git a/plugins/spark_upgrade/tests/resources/accessing_execution_plan copy/expected/sample.scala b/plugins/spark_upgrade/tests/resources/accessing_execution_plan copy/expected/sample.scala deleted file mode 100644 index ecfc47ba6..000000000 --- a/plugins/spark_upgrade/tests/resources/accessing_execution_plan copy/expected/sample.scala +++ /dev/null @@ -1,11 +0,0 @@ -package org.piranha - -object AccessingExecutionPlan { - def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { - QueryTest.checkAnswer(actual, expectedAnswer) - } - - def checkAnswer1(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { - QueryTest.checkAnswer(actual, expectedAnswer) - } -} diff --git a/plugins/spark_upgrade/tests/resources/accessing_execution_plan copy/input/sample.scala b/plugins/spark_upgrade/tests/resources/accessing_execution_plan copy/input/sample.scala deleted file mode 100644 index 9d999728a..000000000 --- a/plugins/spark_upgrade/tests/resources/accessing_execution_plan copy/input/sample.scala +++ /dev/null @@ -1,17 +0,0 @@ -package org.piranha - -object AccessingExecutionPlan { - def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { - QueryTest.checkAnswer(actual, expectedAnswer) match { - case Some(errorMessage) => fail(errorMessage) - case None => - } - } - - def checkAnswer1(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { - QueryTest.checkAnswer(actual, expectedAnswer) match { - case Some(errorMessage) => fail(errorMessage) - case None => // ignore - } - } -} diff --git a/plugins/spark_upgrade/tests/resources/gradient_boost_trees/expected/sample.scala b/plugins/spark_upgrade/tests/resources/gradient_boost_trees/expected/sample.scala index f6552f3c1..1c2047a86 100644 --- a/plugins/spark_upgrade/tests/resources/gradient_boost_trees/expected/sample.scala +++ b/plugins/spark_upgrade/tests/resources/gradient_boost_trees/expected/sample.scala @@ -1,12 +1,13 @@ package org.piranha +// this rules removes the comment on line 11 object GradientBoostTressExample { def main(args: Array[String]): Unit = { - val (baseLearners: Array[DecisionTreeRegressionModel], learnerWeights) = + val (a, b) = GradientBoostedTrees.run( oldDataset.map(data => new Instance(data.label, 1.0, data.features)), boostingStrategy, - $(seed), - "auto" /* featureSubsetStrategy */ ) + seed, + "auto") } } diff --git a/plugins/spark_upgrade/tests/resources/gradient_boost_trees/expected/sample2.scala b/plugins/spark_upgrade/tests/resources/gradient_boost_trees/expected/sample2.scala new file mode 100644 index 000000000..82a996e07 --- /dev/null +++ b/plugins/spark_upgrade/tests/resources/gradient_boost_trees/expected/sample2.scala @@ -0,0 +1,13 @@ +package org.piranha + +object GradientBoostTressExample { + def main(args: Array[String]): Unit = { + val (c, d) = + GradientBoostedTrees.run( + oldDataset.map(data => new Instance(data.label, 1.0, data.features)), + boostingStrategy, + seed, + "auto" + ) + } +} 
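Illustrative sketch, not part of the patch: the `not_contains=[_INSTANCE_EXPR_QUERY]` filters added to the two gradient-boost rules earlier in this change are what should stop the rewrite from firing again once a `GradientBoostedTrees.run(...)` call already wraps its dataset in a three-argument `new Instance(...)`, which is exactly the shape the expected fixtures above end up in. A minimal pytest-style check of that property could look like the sketch below; the function name is invented for illustration, and it assumes the `copy_dir` and `is_as_expected_files` helpers from tests/test_spark_upgrade.py are in scope and that the plugin tolerates a second run that finds nothing to rewrite.

from tempfile import TemporaryDirectory
from gradient_boost_trees import GradientBoostTrees

def check_gradient_boost_trees_idempotent():
    resources = "plugins/spark_upgrade/tests/resources/gradient_boost_trees/"
    with TemporaryDirectory() as tp:
        copy_dir(resources + "input/", tp)
        # First run performs the rewrite (wraps the dataset in new Instance(...)).
        summary = GradientBoostTrees([tp])()
        assert summary is not None
        # Second run should find no match: the rewritten call now contains a
        # three-argument new Instance(...), so the not_contains filter blocks it.
        GradientBoostTrees([tp])()
        assert is_as_expected_files(resources + "expected/", tp)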
diff --git a/plugins/spark_upgrade/tests/resources/gradient_boost_trees/input/sample.scala b/plugins/spark_upgrade/tests/resources/gradient_boost_trees/input/sample.scala index e460a8d49..2f827bbe5 100644 --- a/plugins/spark_upgrade/tests/resources/gradient_boost_trees/input/sample.scala +++ b/plugins/spark_upgrade/tests/resources/gradient_boost_trees/input/sample.scala @@ -1,20 +1,13 @@ package org.piranha +// this rules removes the comment on line 11 object GradientBoostTressExample { def main(args: Array[String]): Unit = { val (a, b) = GradientBoostedTrees.run( oldDataset, boostingStrategy, - $(seed), - "auto" /* featureSubsetStrategy */ ) - } - - val (x, y) = - GradientBoostedTrees.run( - another_dataset, - boostingStrategy, - $(seed), - "auto" /* featureSubsetStrategy */ ) + seed, + "auto" /* featureSubsetStrategy */) } } diff --git a/plugins/spark_upgrade/tests/resources/gradient_boost_trees/input/sample2.scala b/plugins/spark_upgrade/tests/resources/gradient_boost_trees/input/sample2.scala new file mode 100644 index 000000000..29ab15eb5 --- /dev/null +++ b/plugins/spark_upgrade/tests/resources/gradient_boost_trees/input/sample2.scala @@ -0,0 +1,12 @@ +package org.piranha + +object GradientBoostTressExample { + def main(args: Array[String]): Unit = { + val (c, d) = + GradientBoostedTrees.run( + oldDataset, + boostingStrategy, + seed, + "auto") + } +} diff --git a/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample.java b/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample.java index e207352ee..a030f6289 100644 --- a/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample.java +++ b/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample.java @@ -29,9 +29,8 @@ public static void main(String[] args) { conf2.setExecutorEnv("spark.executor.extraClassPath", "test"); + // Should not touch existent SparkSession.builder() SparkSession sparkSession = SparkSession.builder() - .config("spark.sql.legacy.timeParserPolicy","LEGACY") - .config("spark.sql.legacy.allowUntypedScalaUDF", "true") .master(master) .appName(appName) .getOrCreate(); diff --git a/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample.scala b/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample.scala index a21792221..5fcdf56e3 100644 --- a/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample.scala +++ b/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample.scala @@ -14,21 +14,19 @@ class Sample { .set("spark.driver.allowMultipleContexts", "true") val sc = new SparkContext(conf) val sqlContext = new TestHiveContext(sc).sparkSession - + val conf2 = new SparkConf() .set("spark.sql.legacy.timeParserPolicy","LEGACY") .set("spark.sql.legacy.allowUntypedScalaUDF", "true") - + conf2.setSparkHome(sparkHome) conf2.setExecutorEnv("spark.executor.extraClassPath", "test") val sparkSession = SparkSession.builder() - .config("spark.sql.legacy.timeParserPolicy","LEGACY") - .config("spark.sql.legacy.allowUntypedScalaUDF", "true") .master(master) .appName(appName) - .getOrCreate() + .getOrCreate } diff --git a/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample2.java b/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample2.java new file mode 100644 index 000000000..72c141627 --- /dev/null +++ b/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample2.java @@ -0,0 +1,21 @@ +import org.apache.spark.SparkConf; +import org.apache.spark.sql.SparkSession; + +public class Sample2 { + + private static JavaSparkContext jsc; + + 
@BeforeClass + public static void startSpark() { + jsc = + new JavaSparkContext( + SparkSession.builder() + .config("spark.sql.legacy.allowUntypedScalaUDF", "true") + .appName(Sample2.class.getName()) + .master("master") + .config("spark.driver.allowMultipleContexts", "true") + .getOrCreate() + .sparkContext()); + } + +} diff --git a/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample2.scala b/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample2.scala new file mode 100644 index 000000000..9830fabfc --- /dev/null +++ b/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample2.scala @@ -0,0 +1,18 @@ +package com.piranha; + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession + +class Sample { + + def main(argv: Array[String]): Unit = { + val spark = SparkSession.builder + .appName("appName") + .config("conf", "package.conf") + .enableHiveSupport + .getOrCreate + + run(spark, config) + } + +} diff --git a/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample3.java b/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample3.java new file mode 100644 index 000000000..0209c6274 --- /dev/null +++ b/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample3.java @@ -0,0 +1,25 @@ +package com.piranha; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.spark.sql.SparkSession; + +public class Sample { + public static void main(String[] args) { + // Should not touch existent SparkSession.builder() + SparkSession session = SparkSession.builder().config(sc.getConf()).getOrCreate(); + + SparkSession session2 = + SparkSession.builder().appName("appName").config("config", "local").getOrCreate(); + + SparkSession.builder().appName("appName").config("config", "local").getOrCreate(); + } + + @Test + public void test() { + SparkSession session = + SparkSession.builder().appName("appName").config("config", "local").getOrCreate(); + SparkContext sc = session.sparkContext(); + JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc()); + } +} diff --git a/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample3.scala b/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample3.scala new file mode 100644 index 000000000..6643b8b10 --- /dev/null +++ b/plugins/spark_upgrade/tests/resources/spark_conf/expected/sample3.scala @@ -0,0 +1,21 @@ +package com.piranha; + +import org.apache.spark.sql.SparkSession + +import org.mockito.Mockito.spy + +class Sample { + + @Test + def testMain(): Unit = { + lazy val spark: SparkSession = spy( + SparkSession + .builder() + .master("master") + .appName("AppName") + .config("spark.ui.enabled", "false") + .config("spark.driver.host", "localhost") + .getOrCreate()) + } + +} diff --git a/plugins/spark_upgrade/tests/resources/spark_conf/input/sample.java b/plugins/spark_upgrade/tests/resources/spark_conf/input/sample.java index 77b8fd13d..25f475c3f 100644 --- a/plugins/spark_upgrade/tests/resources/spark_conf/input/sample.java +++ b/plugins/spark_upgrade/tests/resources/spark_conf/input/sample.java @@ -25,6 +25,7 @@ public static void main(String[] args) { conf2.setExecutorEnv("spark.executor.extraClassPath", "test"); + // Should not touch existent SparkSession.builder() SparkSession sparkSession = SparkSession.builder() .master(master) .appName(appName) diff --git a/plugins/spark_upgrade/tests/resources/spark_conf/input/sample.scala b/plugins/spark_upgrade/tests/resources/spark_conf/input/sample.scala index 57ee241ad..6342ec806 100644 --- 
a/plugins/spark_upgrade/tests/resources/spark_conf/input/sample.scala +++ b/plugins/spark_upgrade/tests/resources/spark_conf/input/sample.scala @@ -12,9 +12,9 @@ class Sample { .set("spark.driver.allowMultipleContexts", "true") val sc = new SparkContext(conf) val sqlContext = new TestHiveContext(sc).sparkSession - + val conf2 = new SparkConf() - + conf2.setSparkHome(sparkHome) conf2.setExecutorEnv("spark.executor.extraClassPath", "test") @@ -22,9 +22,7 @@ class Sample { val sparkSession = SparkSession.builder() .master(master) .appName(appName) - .getOrCreate() - - + .getOrCreate } } diff --git a/plugins/spark_upgrade/tests/resources/spark_conf/input/sample2.java b/plugins/spark_upgrade/tests/resources/spark_conf/input/sample2.java new file mode 100644 index 000000000..b4ed513d4 --- /dev/null +++ b/plugins/spark_upgrade/tests/resources/spark_conf/input/sample2.java @@ -0,0 +1,17 @@ +import org.apache.spark.SparkConf; + +public class Sample2 { + + private static JavaSparkContext jsc; + + @BeforeClass + public static void startSpark() { + jsc = + new JavaSparkContext( + new SparkConf() + .setAppName(Sample2.class.getName()) + .setMaster("master") + .set("spark.driver.allowMultipleContexts", "true")); + } + +} diff --git a/plugins/spark_upgrade/tests/resources/spark_conf/input/sample2.scala b/plugins/spark_upgrade/tests/resources/spark_conf/input/sample2.scala new file mode 100644 index 000000000..9830fabfc --- /dev/null +++ b/plugins/spark_upgrade/tests/resources/spark_conf/input/sample2.scala @@ -0,0 +1,18 @@ +package com.piranha; + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession + +class Sample { + + def main(argv: Array[String]): Unit = { + val spark = SparkSession.builder + .appName("appName") + .config("conf", "package.conf") + .enableHiveSupport + .getOrCreate + + run(spark, config) + } + +} diff --git a/plugins/spark_upgrade/tests/resources/spark_conf/input/sample3.java b/plugins/spark_upgrade/tests/resources/spark_conf/input/sample3.java new file mode 100644 index 000000000..0bf205918 --- /dev/null +++ b/plugins/spark_upgrade/tests/resources/spark_conf/input/sample3.java @@ -0,0 +1,26 @@ +package com.piranha; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.spark.sql.SparkSession; + +public class Sample { + public static void main(String[] args) { + // Should not touch existent SparkSession.builder() + SparkSession session = SparkSession.builder().config(sc.getConf()).getOrCreate(); + + SparkSession session2 = + SparkSession.builder().appName("appName").config("config", "local").getOrCreate(); + + SparkSession.builder().appName("appName").config("config", "local").getOrCreate(); + } + + @Test + public void test() { + SparkSession session = + SparkSession.builder().appName("appName").config("config", "local").getOrCreate(); + SparkContext sc = session.sparkContext(); + JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc()); + } + +} diff --git a/plugins/spark_upgrade/tests/resources/spark_conf/input/sample3.scala b/plugins/spark_upgrade/tests/resources/spark_conf/input/sample3.scala new file mode 100644 index 000000000..6b3ae5747 --- /dev/null +++ b/plugins/spark_upgrade/tests/resources/spark_conf/input/sample3.scala @@ -0,0 +1,21 @@ +package com.piranha; + +import org.apache.spark.sql.SparkSession + +import org.mockito.Mockito.spy + +class Sample { + + @Test + def testMain(): Unit = { + lazy val spark: SparkSession = spy( + SparkSession + .builder() + .master("master") + .appName("AppName") + 
.getOrCreate()) + spark.sqlContext.setConf("spark.ui.enabled", "false") + spark.sqlContext.setConf("spark.driver.host", "localhost") + } + +} diff --git a/plugins/spark_upgrade/tests/test_spark_upgrade.py b/plugins/spark_upgrade/tests/test_spark_upgrade.py index 0eca150f1..e8d8e24b0 100644 --- a/plugins/spark_upgrade/tests/test_spark_upgrade.py +++ b/plugins/spark_upgrade/tests/test_spark_upgrade.py @@ -21,6 +21,8 @@ from sql_new_execution import SQLNewExecutionChange from query_test_check_answer_change import QueryTestCheckAnswerChange from spark_config import SparkConfigChange +from java_spark_context import JavaSparkContextChange +from scala_session_builder import ScalaSessionBuilder FORMAT = "%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s" logging.basicConfig(format=FORMAT) @@ -28,8 +30,12 @@ def test_update_CalendarInterval(): - input_codebase = "plugins/spark_upgrade/tests/resources/input/update_calendar_interval/" - expected_codebase = "plugins/spark_upgrade/tests/resources/expected/update_calendar_interval/" + input_codebase = ( + "plugins/spark_upgrade/tests/resources/update_calendar_interval/input/" + ) + expected_codebase = ( + "plugins/spark_upgrade/tests/resources/update_calendar_interval/expected/" + ) with TemporaryDirectory() as temp_dir: tp = temp_dir copy_dir(input_codebase, tp) @@ -38,32 +44,44 @@ def test_update_CalendarInterval(): assert summary is not None assert is_as_expected_files(expected_codebase, tp) - + def test_update_IDFModelSignatureChange(): - input_codebase = "plugins/spark_upgrade/tests/resources/input/idf_model_signature_change/" - expected_codebase = "plugins/spark_upgrade/tests/resources/expected/idf_model_signature_change/" + input_codebase = ( + "plugins/spark_upgrade/tests/resources/idf_model_signature_change/input/" + ) + expected_codebase = ( + "plugins/spark_upgrade/tests/resources/idf_model_signature_change/expected/" + ) with TemporaryDirectory() as temp_dir: tp = temp_dir copy_dir(input_codebase, tp) idf_model_signature_change = IDFModelSignatureChange([tp]) summary = idf_model_signature_change() assert summary is not None - assert is_as_expected_files(expected_codebase, tp) - + assert is_as_expected_files(expected_codebase, tp) + + def test_update_accessing_execution_plan(): - input_codebase = "plugins/spark_upgrade/tests/resources/input/accessing_execution_plan/" - expected_codebase = "plugins/spark_upgrade/tests/resources/expected/accessing_execution_plan/" + input_codebase = ( + "plugins/spark_upgrade/tests/resources/accessing_execution_plan/input/" + ) + expected_codebase = ( + "plugins/spark_upgrade/tests/resources/accessing_execution_plan/expected/" + ) with TemporaryDirectory() as temp_dir: tp = temp_dir copy_dir(input_codebase, tp) accessing_execution_plan = AccessingExecutionPlan([tp]) summary = accessing_execution_plan() assert summary is not None - assert is_as_expected_files(expected_codebase, tp) - + assert is_as_expected_files(expected_codebase, tp) + + def test_update_gradient_boost_trees(): - input_codebase = "plugins/spark_upgrade/tests/resources/input/gradient_boost_trees/" - expected_codebase = "plugins/spark_upgrade/tests/resources/expected/gradient_boost_trees/" + input_codebase = "plugins/spark_upgrade/tests/resources/gradient_boost_trees/input/" + expected_codebase = ( + "plugins/spark_upgrade/tests/resources/gradient_boost_trees/expected/" + ) with TemporaryDirectory() as temp_dir: tp = temp_dir copy_dir(input_codebase, tp) @@ -72,9 +90,14 @@ def test_update_gradient_boost_trees(): assert 
summary is not None assert is_as_expected_files(expected_codebase, tp) + def test_update_calculator_signature_change(): - input_codebase = "plugins/spark_upgrade/tests/resources/input/calculator_signature_change/" - expected_codebase = "plugins/spark_upgrade/tests/resources/expected/calculator_signature_change/" + input_codebase = ( + "plugins/spark_upgrade/tests/resources/calculator_signature_change/input/" + ) + expected_codebase = ( + "plugins/spark_upgrade/tests/resources/calculator_signature_change/expected/" + ) with TemporaryDirectory() as temp_dir: tp = temp_dir copy_dir(input_codebase, tp) @@ -83,9 +106,12 @@ def test_update_calculator_signature_change(): assert summary is not None assert is_as_expected_files(expected_codebase, tp) + def test_sql_new_execution(): - input_codebase = "plugins/spark_upgrade/tests/resources/input/sql_new_execution/" - expected_codebase = "plugins/spark_upgrade/tests/resources/expected/sql_new_execution/" + input_codebase = "plugins/spark_upgrade/tests/resources/sql_new_execution/input/" + expected_codebase = ( + "plugins/spark_upgrade/tests/resources/sql_new_execution/expected/" + ) with TemporaryDirectory() as temp_dir: tp = temp_dir copy_dir(input_codebase, tp) @@ -93,10 +119,15 @@ def test_sql_new_execution(): summary = sql_new_execution() assert summary is not None assert is_as_expected_files(expected_codebase, tp) - + + def test_query_test_check_answer_change(): - input_codebase = "plugins/spark_upgrade/tests/resources/input/query_test_check_answer_change/" - expected_codebase = "plugins/spark_upgrade/tests/resources/expected/query_test_check_answer_change/" + input_codebase = ( + "plugins/spark_upgrade/tests/resources/query_test_check_answer_change/input/" + ) + expected_codebase = ( + "plugins/spark_upgrade/tests/resources/query_test_check_answer_change/expected/" + ) with TemporaryDirectory() as temp_dir: tp = temp_dir copy_dir(input_codebase, tp) @@ -104,7 +135,8 @@ def test_query_test_check_answer_change(): summary = query_test_check_answer_change() assert summary is not None assert is_as_expected_files(expected_codebase, tp) - + + def test_spark_config_change(): input_codebase = "plugins/spark_upgrade/tests/resources/spark_conf/input/" expected_codebase = "plugins/spark_upgrade/tests/resources/spark_conf/expected/" @@ -117,6 +149,21 @@ def test_spark_config_change(): spark_config_change = SparkConfigChange([tp], "java") summary = spark_config_change() assert summary is not None + + javasparkcontext = JavaSparkContextChange([tp], language="java") + summary = javasparkcontext() + assert summary is not None + javasparkcontext = JavaSparkContextChange([tp]) + summary = javasparkcontext() + assert summary is not None + + scalasessionbuilder = ScalaSessionBuilder([tp], language="scala") + summary = scalasessionbuilder() + assert summary is not None + scalasessionbuilder = ScalaSessionBuilder([tp], language="java") + summary = scalasessionbuilder() + assert summary is not None + assert is_as_expected_files(expected_codebase, tp) @@ -130,6 +177,45 @@ def remove_whitespace(input_str): return "".join(input_str.split()).strip() +def test_integration(): + """Test that the integration of all plugins terminate correctly.""" + with TemporaryDirectory() as input_temp_dir: + for input_dir in [ + "plugins/spark_upgrade/tests/resources/accessing_execution_plan/input/", + "plugins/spark_upgrade/tests/resources/calculator_signature_change/input/", + "plugins/spark_upgrade/tests/resources/gradient_boost_trees/input/", + 
"plugins/spark_upgrade/tests/resources/idf_model_signature_change/input/", + "plugins/spark_upgrade/tests/resources/spark_conf/input/", + "plugins/spark_upgrade/tests/resources/sql_new_execution/input/", + "plugins/spark_upgrade/tests/resources/update_calendar_interval/input/", + ]: + dir_name = input_dir.split("/")[-3] + copy_dir(input_dir, input_temp_dir + "/" + dir_name + "/") + + update_calendar_interval = UpdateCalendarInterval([input_temp_dir]) + _ = update_calendar_interval() + idf_model_signature_change = IDFModelSignatureChange([input_temp_dir]) + _ = idf_model_signature_change() + accessing_execution_plan = AccessingExecutionPlan([input_temp_dir]) + _ = accessing_execution_plan() + gradient_boost_trees = GradientBoostTrees([input_temp_dir]) + _ = gradient_boost_trees() + calculator_signature_change = CalculatorSignatureChange([input_temp_dir]) + _ = calculator_signature_change() + sql_new_execution = SQLNewExecutionChange([input_temp_dir]) + _ = sql_new_execution() + query_test_check_answer_change = QueryTestCheckAnswerChange([input_temp_dir]) + _ = query_test_check_answer_change() + spark_config = SparkConfigChange([input_temp_dir]) + _ = spark_config() + spark_config = SparkConfigChange([input_temp_dir], language="java") + _ = spark_config() + javasparkcontext = JavaSparkContextChange([input_temp_dir], language="java") + _ = javasparkcontext() + scalasessionbuilder = ScalaSessionBuilder([input_temp_dir], language="scala") + _ = scalasessionbuilder() + + def copy_dir(source_dir, dest_dir): """Copy files in {source_dir} to {dest_dir} Properties to note: