diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json
index c2928442971d..5e4f28f68239 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -941,6 +941,11 @@
       "Argument <arg_name> must be a numerical column for plotting, got <arg_type>."
     ]
   },
+  "PROTOCOL_ERROR": {
+    "message": [
+      "<failure>. This usually indicates that the message does not conform to the protocol."
+    ]
+  },
   "PYTHON_HASH_SEED_NOT_SET": {
     "message": [
       "Randomness of hash of string should be disabled via PYTHONHASHSEED."
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 50e71fb6da9d..9506912ce923 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -119,7 +119,19 @@ def __init__(self, infile=None):
 
     def load(self, infile):
         num_conf = read_int(infile)
-        for i in range(num_conf):
+        # We do a sanity check here to reduce the possibility of getting stuck
+        # indefinitely due to an invalid message. If the number of configurations
+        # is obviously wrong, we just raise an error directly.
+        # We hand-pick the configurations to send to the worker, so the number
+        # should be very small (less than 100).
+        if num_conf < 0 or num_conf > 10000:
+            raise PySparkRuntimeError(
+                errorClass="PROTOCOL_ERROR",
+                messageParameters={
+                    "failure": f"Invalid number of configurations: {num_conf}",
+                },
+            )
+        for _ in range(num_conf):
             k = utf8_deserializer.loads(infile)
             v = utf8_deserializer.loads(infile)
             self._conf[k] = v
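
For context, here is a minimal standalone sketch of the sanity-check pattern this patch applies: validate a count read from an untrusted stream before looping on it, so that a corrupted or misaligned message fails fast instead of leaving the worker waiting for data that never arrives. This is not PySpark's actual worker code; read_utf8, encode_utf8, load_conf, and the max_conf bound are illustrative stand-ins, and only the big-endian 4-byte int framing mirrors what pyspark.serializers.read_int does.

import struct
from io import BytesIO

def read_int(stream):
    # Big-endian 4-byte signed int, the framing assumed for this sketch.
    return struct.unpack(">i", stream.read(4))[0]

def read_utf8(stream):
    # Length-prefixed UTF-8 string: a 4-byte length followed by the payload.
    length = read_int(stream)
    return stream.read(length).decode("utf-8")

def load_conf(stream, max_conf=10000):
    num_conf = read_int(stream)
    # Reject obviously invalid counts up front. If the message is corrupted
    # or the stream is misaligned, the loop below would otherwise try to read
    # a huge number of entries; on a real socket that means blocking forever
    # on bytes that never arrive.
    if num_conf < 0 or num_conf > max_conf:
        raise RuntimeError(f"Invalid number of configurations: {num_conf}")
    conf = {}
    for _ in range(num_conf):
        k = read_utf8(stream)
        v = read_utf8(stream)
        conf[k] = v
    return conf

def encode_utf8(s):
    # Inverse of read_utf8, used here only to build test messages.
    data = s.encode("utf-8")
    return struct.pack(">i", len(data)) + data

# A well-formed message with one key/value pair decodes normally.
msg = struct.pack(">i", 1) + encode_utf8("spark.sql.shuffle.partitions") + encode_utf8("200")
print(load_conf(BytesIO(msg)))  # {'spark.sql.shuffle.partitions': '200'}

# A corrupted count fails fast instead of hanging.
bad = struct.pack(">i", 2_000_000_000)
try:
    load_conf(BytesIO(bad))
except RuntimeError as e:
    print(e)  # Invalid number of configurations: 2000000000

The 10000 bound in the patch is deliberately generous: the comment notes fewer than 100 configurations are expected, so the check only rejects values that cannot possibly be legitimate while leaving ample headroom for growth.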