[SPARK-50909][PYTHON] Setup faulthandler in PythonPlannerRunners
### What changes were proposed in this pull request?

Sets up `faulthandler` in the `PythonPlannerRunner`s.

It can be enabled by the same configs as UDFs:

- SQL conf: `spark.sql.execution.pyspark.udf.faulthandler.enabled`
- It falls back to the Spark conf `spark.python.worker.faulthandler.enabled`
- `false` by default; see the sketch below for how to enable it
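
For illustration, a minimal sketch (not part of this PR) of enabling it from PySpark; the session-scoped SQL conf takes precedence, and the Spark conf is the cluster-wide fallback:

```python
from pyspark.sql import SparkSession

# Cluster-wide fallback, read when the Python workers are launched.
spark = (
    SparkSession.builder
    .config("spark.python.worker.faulthandler.enabled", "true")
    .getOrCreate()
)

# Session-scoped SQL conf; when set, it overrides the Spark conf fallback above.
spark.conf.set("spark.sql.execution.pyspark.udf.faulthandler.enabled", "true")
```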

### Why are the changes needed?

`faulthandler` is not yet set up in the `PythonPlannerRunner`s, so a crash in one of those Python workers surfaces without a Python traceback.

### Does this PR introduce _any_ user-facing change?

When enabled, if a Python worker crashes, the error message may include a thread dump of the Python process, on a best-effort basis.
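
For context, a minimal standalone sketch (plain CPython, not Spark-specific) of the underlying mechanism: `faulthandler` writes the Python-level stack of each thread to a log file when the process receives a fatal signal, and the workers touched by this change capture that dump (written to a per-PID file under `PYTHON_FAULTHANDLER_DIR`) into the error message:

```python
import ctypes
import faulthandler
import tempfile

# Dump Python-level thread stacks to this file on fatal signals
# (SIGSEGV, SIGFPE, SIGABRT, SIGBUS, SIGILL).
log_file = tempfile.NamedTemporaryFile(mode="w", prefix="faulthandler-", delete=False)
faulthandler.enable(file=log_file)

# Dereferencing a NULL pointer crashes the interpreter;
# the stack dump is left in log_file.name.
ctypes.string_at(0)
```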

### How was this patch tested?

Added related tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#49592 from ueshin/issues/SPARK-50909/faulthandler.

Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Takuya Ueshin <[email protected]>
ueshin committed Jan 24, 2025
1 parent f2765f4 commit 100105b
Showing 11 changed files with 275 additions and 14 deletions.
@@ -33,7 +33,7 @@ import org.apache.spark._
import org.apache.spark.api.python.PythonFunction.PythonAccumulator
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.internal.LogKeys.TASK_NAME
import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES, Python}
import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
import org.apache.spark.internal.config.Python._
import org.apache.spark.rdd.InputFileBlockHolder
import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY}
@@ -90,11 +90,11 @@ private[spark] object PythonEvalType {
}
}

private object BasePythonRunner {
private[spark] object BasePythonRunner {

private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")
private[spark] lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")

private def faultHandlerLogPath(pid: Int): Path = {
private[spark] def faultHandlerLogPath(pid: Int): Path = {
new File(faultHandlerLogDir, pid.toString).toPath
}
}
@@ -574,15 +574,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
JavaFiles.deleteIfExists(path)
throw new SparkException(s"Python worker exited unexpectedly (crashed): $error", e)

case eof: EOFException if !faultHandlerEnabled =>
case e: IOException if !faultHandlerEnabled =>
throw new SparkException(
s"Python worker exited unexpectedly (crashed). " +
"Consider setting 'spark.sql.execution.pyspark.udf.faulthandler.enabled' or" +
s"'${Python.PYTHON_WORKER_FAULTHANLDER_ENABLED.key}' configuration to 'true' for" +
"the better Python traceback.", eof)
s"'${PYTHON_WORKER_FAULTHANLDER_ENABLED.key}' configuration to 'true' for " +
"the better Python traceback.", e)

case eof: EOFException =>
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
case e: IOException =>
throw new SparkException("Python worker exited unexpectedly (crashed)", e)
}
}

120 changes: 120 additions & 0 deletions python/pyspark/sql/tests/test_python_datasource.py
@@ -508,6 +508,126 @@ def write(self, iterator):
):
df.write.format("test").mode("append").saveAsTable("test_table")

def test_data_source_segfault(self):
import ctypes

for enabled, expected in [
(True, "Segmentation fault"),
(False, "Consider setting .* for the better Python traceback."),
]:
with self.subTest(enabled=enabled), self.sql_conf(
{"spark.sql.execution.pyspark.udf.faulthandler.enabled": enabled}
):
with self.subTest(worker="pyspark.sql.worker.create_data_source"):

class TestDataSource(DataSource):
@classmethod
def name(cls):
return "test"

def schema(self):
return ctypes.string_at(0)

self.spark.dataSource.register(TestDataSource)

with self.assertRaisesRegex(Exception, expected):
self.spark.read.format("test").load().show()

with self.subTest(worker="pyspark.sql.worker.plan_data_source_read"):

class TestDataSource(DataSource):
@classmethod
def name(cls):
return "test"

def schema(self):
return "x string"

def reader(self, schema):
return TestReader()

class TestReader(DataSourceReader):
def partitions(self):
ctypes.string_at(0)
return []

def read(self, partition):
return []

self.spark.dataSource.register(TestDataSource)

with self.assertRaisesRegex(Exception, expected):
self.spark.read.format("test").load().show()

with self.subTest(worker="pyspark.worker"):

class TestDataSource(DataSource):
@classmethod
def name(cls):
return "test"

def schema(self):
return "x string"

def reader(self, schema):
return TestReader()

class TestReader(DataSourceReader):
def read(self, partition):
ctypes.string_at(0)
yield "x",

self.spark.dataSource.register(TestDataSource)

with self.assertRaisesRegex(Exception, expected):
self.spark.read.format("test").load().show()

with self.subTest(worker="pyspark.sql.worker.write_into_data_source"):

class TestDataSource(DataSource):
@classmethod
def name(cls):
return "test"

def writer(self, schema, overwrite):
return TestWriter()

class TestWriter(DataSourceWriter):
def write(self, iterator):
ctypes.string_at(0)
return WriterCommitMessage()

self.spark.dataSource.register(TestDataSource)

with self.assertRaisesRegex(Exception, expected):
self.spark.range(10).write.format("test").mode("append").saveAsTable(
"test_table"
)

with self.subTest(worker="pyspark.sql.worker.commit_data_source_write"):

class TestDataSource(DataSource):
@classmethod
def name(cls):
return "test"

def writer(self, schema, overwrite):
return TestWriter()

class TestWriter(DataSourceWriter):
def write(self, iterator):
return WriterCommitMessage()

def commit(self, messages):
ctypes.string_at(0)

self.spark.dataSource.register(TestDataSource)

with self.assertRaisesRegex(Exception, expected):
self.spark.range(10).write.format("test").mode("append").saveAsTable(
"test_table"
)


class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
...
37 changes: 37 additions & 0 deletions python/pyspark/sql/tests/test_udtf.py
@@ -2761,6 +2761,43 @@ def eval(self, n):
res = self.spark.sql("select i, to_json(v['v1']) from test_udtf_struct(8)")
assertDataFrameEqual(res, [Row(i=n, s=f'{{"a":"{chr(99 + n)}"}}') for n in range(8)])

def test_udtf_segfault(self):
for enabled, expected in [
(True, "Segmentation fault"),
(False, "Consider setting .* for the better Python traceback."),
]:
with self.subTest(enabled=enabled), self.sql_conf(
{"spark.sql.execution.pyspark.udf.faulthandler.enabled": enabled}
):
with self.subTest(method="eval"):

class TestUDTF:
def eval(self):
import ctypes

yield ctypes.string_at(0),

self._check_result_or_exception(
TestUDTF, "x: string", expected, err_type=Exception
)

with self.subTest(method="analyze"):

class TestUDTFWithAnalyze:
@staticmethod
def analyze():
import ctypes

ctypes.string_at(0)
return AnalyzeResult(StructType().add("x", StringType()))

def eval(self):
yield "x",

self._check_result_or_exception(
TestUDTFWithAnalyze, None, expected, err_type=Exception
)


class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
@classmethod
12 changes: 12 additions & 0 deletions python/pyspark/sql/worker/analyze_udtf.py
@@ -15,6 +15,7 @@
# limitations under the License.
#

import faulthandler
import inspect
import os
import sys
@@ -106,7 +107,13 @@ def main(infile: IO, outfile: IO) -> None:
in JVM and receive the Python UDTF and its arguments for the `analyze` static method,
and call the `analyze` static method, and send back a AnalyzeResult as a result of the method.
"""
faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
try:
if faulthandler_log_path:
faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
faulthandler_log_file = open(faulthandler_log_path, "w")
faulthandler.enable(file=faulthandler_log_file)

check_python_version(infile)

memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
@@ -247,6 +254,11 @@ def invalid_analyze_result_field(field_name: str, expected_field: str) -> PySpar
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
finally:
if faulthandler_log_path:
faulthandler.disable()
faulthandler_log_file.close()
os.remove(faulthandler_log_path)

send_accumulator_updates(outfile)

12 changes: 12 additions & 0 deletions python/pyspark/sql/worker/commit_data_source_write.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import faulthandler
import os
import sys
from typing import IO
@@ -47,7 +48,13 @@ def main(infile: IO, outfile: IO) -> None:
responsible for invoking either the `commit` or the `abort` method on a data source
writer instance, given a list of commit messages.
"""
faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
try:
if faulthandler_log_path:
faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
faulthandler_log_file = open(faulthandler_log_path, "w")
faulthandler.enable(file=faulthandler_log_file)

check_python_version(infile)

memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
@@ -93,6 +100,11 @@ def main(infile: IO, outfile: IO) -> None:
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
finally:
if faulthandler_log_path:
faulthandler.disable()
faulthandler_log_file.close()
os.remove(faulthandler_log_path)

send_accumulator_updates(outfile)

12 changes: 12 additions & 0 deletions python/pyspark/sql/worker/create_data_source.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import faulthandler
import inspect
import os
import sys
@@ -60,7 +61,13 @@ def main(infile: IO, outfile: IO) -> None:
This process then creates a `DataSource` instance using the above information and
sends the pickled instance as well as the schema back to the JVM.
"""
faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
try:
if faulthandler_log_path:
faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
faulthandler_log_file = open(faulthandler_log_path, "w")
faulthandler.enable(file=faulthandler_log_file)

check_python_version(infile)

memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
@@ -158,6 +165,11 @@ def main(infile: IO, outfile: IO) -> None:
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
finally:
if faulthandler_log_path:
faulthandler.disable()
faulthandler_log_file.close()
os.remove(faulthandler_log_path)

send_accumulator_updates(outfile)

12 changes: 12 additions & 0 deletions python/pyspark/sql/worker/lookup_data_sources.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import faulthandler
from importlib import import_module
from pkgutil import iter_modules
import os
@@ -50,7 +51,13 @@ def main(infile: IO, outfile: IO) -> None:
This is responsible for searching the available Python Data Sources so they can be
statically registered automatically.
"""
faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
try:
if faulthandler_log_path:
faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
faulthandler_log_file = open(faulthandler_log_path, "w")
faulthandler.enable(file=faulthandler_log_file)

check_python_version(infile)

memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
@@ -78,6 +85,11 @@ def main(infile: IO, outfile: IO) -> None:
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
finally:
if faulthandler_log_path:
faulthandler.disable()
faulthandler_log_file.close()
os.remove(faulthandler_log_path)

send_accumulator_updates(outfile)

12 changes: 12 additions & 0 deletions python/pyspark/sql/worker/plan_data_source_read.py
@@ -15,6 +15,7 @@
# limitations under the License.
#

import faulthandler
import os
import sys
import functools
@@ -187,7 +188,13 @@ def main(infile: IO, outfile: IO) -> None:
The partition values and the Arrow Batch are then serialized and sent back to the JVM
via the socket.
"""
faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
try:
if faulthandler_log_path:
faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
faulthandler_log_file = open(faulthandler_log_path, "w")
faulthandler.enable(file=faulthandler_log_file)

check_python_version(infile)

memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
@@ -351,6 +358,11 @@ def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.Rec
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
finally:
if faulthandler_log_path:
faulthandler.disable()
faulthandler_log_file.close()
os.remove(faulthandler_log_path)

send_accumulator_updates(outfile)

