
Commit 6b9e9ae

Prototype of runtime profiler
1 parent 6739e4f commit 6b9e9ae

7 files changed: +324 -32 lines changed

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -94,7 +94,7 @@ class SparkEnv (
    */
   private case class PythonWorkersKey(
       pythonExec: String, workerModule: String, daemonModule: String, envVars: Map[String, String])
-  private val pythonWorkers = mutable.HashMap[PythonWorkersKey, PythonWorkerFactory]()
+  private[sql] val pythonWorkers = mutable.HashMap[PythonWorkersKey, PythonWorkerFactory]()
 
   // A general, soft-reference map for metadata needed during HadoopRDD split computation
   // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
```

core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala

Lines changed: 53 additions & 28 deletions

```diff
@@ -38,36 +38,43 @@ import org.apache.spark.internal.config.Python.PYTHON_FACTORY_IDLE_WORKER_MAX_PO
 import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.util.{RedirectThread, Utils}
 
-case class PythonWorker(channel: SocketChannel) {
-
-  private[this] var selectorOpt: Option[Selector] = None
-  private[this] var selectionKeyOpt: Option[SelectionKey] = None
-
-  def selector: Selector = selectorOpt.orNull
-  def selectionKey: SelectionKey = selectionKeyOpt.orNull
-
-  private def closeSelector(): Unit = {
-    selectionKeyOpt.foreach(_.cancel())
-    selectorOpt.foreach(_.close())
+case class PythonWorker(
+    channel: SocketChannel,
+    extraChannel: Option[SocketChannel] = None) {
+
+  private[this] var selectors: Seq[Selector] = Seq.empty
+  private[this] var selectionKeys: Seq[SelectionKey] = Seq.empty
+
+  private def closeSelectors(): Unit = {
+    selectionKeys.foreach(_.cancel())
+    selectors.foreach(_.close())
+    selectors = Seq.empty
+    selectionKeys = Seq.empty
   }
 
   def refresh(): this.type = synchronized {
-    closeSelector()
-    if (channel.isBlocking) {
-      selectorOpt = None
-      selectionKeyOpt = None
-    } else {
-      val selector = Selector.open()
-      selectorOpt = Some(selector)
-      selectionKeyOpt =
-        Some(channel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE))
-    }
+    closeSelectors()
+
+    val channels = Seq(Some(channel), extraChannel).flatten
+    val (selList, keyList) = channels.map { ch =>
+      if (ch.isBlocking) {
+        (None, None)
+      } else {
+        val selector = Selector.open()
+        val key = ch.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE)
+        (Some(selector), Some(key))
+      }
+    }.unzip
+
+    selectors = selList.flatten
+    selectionKeys = keyList.flatten
     this
   }
 
   def stop(): Unit = synchronized {
-    closeSelector()
+    closeSelectors()
     Option(channel).foreach(_.close())
+    extraChannel.foreach(_.close())
   }
 }
 
@@ -129,6 +136,10 @@ private[spark] class PythonWorkerFactory(
     envVars.getOrElse("PYTHONPATH", ""),
     sys.env.getOrElse("PYTHONPATH", ""))
 
+  def getAllDaemonWorkers: Seq[(PythonWorker, ProcessHandle)] = self.synchronized {
+    daemonWorkers.filter { case (_, handle) => handle.isAlive}.toSeq
+  }
+
   def create(): (PythonWorker, Option[ProcessHandle]) = {
     if (useDaemon) {
       self.synchronized {
@@ -163,22 +174,36 @@ private[spark] class PythonWorkerFactory(
   private def createThroughDaemon(): (PythonWorker, Option[ProcessHandle]) = {
 
     def createWorker(): (PythonWorker, Option[ProcessHandle]) = {
-      val socketChannel = if (isUnixDomainSock) {
+      val mainChannel = if (isUnixDomainSock) {
         SocketChannel.open(UnixDomainSocketAddress.of(daemonSockPath))
       } else {
        SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort))
       }
+
+      val extraChannel = if (envVars.getOrElse("PYSPARK_RUNTIME_PROFILE", "false").toBoolean) {
+        if (isUnixDomainSock) {
+          Some(SocketChannel.open(UnixDomainSocketAddress.of(daemonSockPath)))
+        } else {
+          Some(SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort)))
+        }
+      } else {
+        None
+      }
+
       // These calls are blocking.
-      val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt()
+      val pid = new DataInputStream(Channels.newInputStream(mainChannel)).readInt()
       if (pid < 0) {
         throw new IllegalStateException("Python daemon failed to launch worker with code " + pid)
       }
       val processHandle = ProcessHandle.of(pid).orElseThrow(
         () => new IllegalStateException("Python daemon failed to launch worker.")
       )
-      authHelper.authToServer(socketChannel)
-      socketChannel.configureBlocking(false)
-      val worker = PythonWorker(socketChannel)
+
+      authHelper.authToServer(mainChannel)
+      mainChannel.configureBlocking(false)
+      extraChannel.foreach(_.configureBlocking(false))
+
+      val worker = PythonWorker(mainChannel, extraChannel)
       daemonWorkers.put(worker, processHandle)
       (worker.refresh(), Some(processHandle))
     }
@@ -271,7 +296,7 @@ private[spark] class PythonWorkerFactory(
       if (!blockingMode) {
        socketChannel.configureBlocking(false)
       }
-      val worker = PythonWorker(socketChannel)
+      val worker = PythonWorker(socketChannel, None)
       self.synchronized {
        simpleWorkers.put(worker, workerProcess)
       }
```
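
For orientation, here is a rough Python rendering of the client-side handshake that createWorker() performs above: open the main channel, optionally open a second channel when PYSPARK_RUNTIME_PROFILE is enabled, then read the forked worker's pid as a 4-byte big-endian int from the main channel. This is an illustrative sketch, not code from the commit; connect_to_daemon, daemon_host, and daemon_port are placeholder names, and the SocketAuthHelper authentication step is omitted.

```python
import socket
import struct

def connect_to_daemon(daemon_host, daemon_port, runtime_profile=False):
    # Main channel: carries the regular worker protocol.
    main = socket.create_connection((daemon_host, daemon_port))

    # Extra channel: only opened when runtime profiling is enabled; the daemon
    # accepts it right after the main connection and hands it to the same
    # forked worker (see the daemon.py changes below).
    extra = socket.create_connection((daemon_host, daemon_port)) if runtime_profile else None

    # The daemon parent writes the worker's pid as a 4-byte big-endian int
    # on the main channel; a negative value signals a launch failure.
    pid = struct.unpack("!i", main.makefile("rb").read(4))[0]
    if pid < 0:
        raise RuntimeError("Python daemon failed to launch worker with code %d" % pid)
    return main, extra, pid
```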

python/pyspark/daemon.py

Lines changed: 30 additions & 3 deletions

```diff
@@ -47,7 +47,7 @@ def compute_real_exit_code(exit_code):
         return 1
 
 
-def worker(sock, authenticated):
+def worker(sock, sock2, authenticated):
     """
     Called by a worker process after the fork().
     """
@@ -64,6 +64,9 @@ def worker(sock, authenticated):
     buffer_size = int(os.environ.get("SPARK_BUFFER_SIZE", 65536))
     infile = os.fdopen(os.dup(sock.fileno()), "rb", buffer_size)
     outfile = os.fdopen(os.dup(sock.fileno()), "wb", buffer_size)
+    outfile2 = None
+    if sock2 is not None:
+        outfile2 = os.fdopen(os.dup(sock2.fileno()), "wb", buffer_size)
 
     if not authenticated:
         client_secret = UTF8Deserializer().loads(infile)
@@ -74,11 +77,16 @@ def worker(sock, authenticated):
             write_with_length("err".encode("utf-8"), outfile)
             outfile.flush()
             sock.close()
+            if sock2 is not None:
+                sock2.close()
             return 1
 
     exit_code = 0
     try:
-        worker_main(infile, outfile)
+        if sock2 is not None:
+            worker_main(infile, (outfile, outfile2))
+        else:
+            worker_main(infile, outfile)
     except SystemExit as exc:
         exit_code = compute_real_exit_code(exc.code)
     finally:
@@ -94,6 +102,7 @@ def manager():
     os.setpgid(0, 0)
 
     is_unix_domain_sock = os.environ.get("PYTHON_UNIX_DOMAIN_ENABLED", "false").lower() == "true"
+    is_python_runtime_profile = os.environ.get("PYSPARK_RUNTIME_PROFILE", "false").lower() == "true"
     socket_path = None
 
     # Create a listening socket on the loopback interface
@@ -173,6 +182,15 @@ def handle_sigterm(*args):
                         continue
                     raise
 
+                sock2 = None
+                if is_python_runtime_profile:
+                    try:
+                        sock2, _ = listen_sock.accept()
+                    except OSError as e:
+                        if e.errno == EINTR:
+                            continue
+                        raise
+
                 # Launch a worker process
                 try:
                     pid = os.fork()
@@ -186,6 +204,13 @@ def handle_sigterm(*args):
                         outfile.flush()
                         outfile.close()
                         sock.close()
+
+                        if sock2 is not None:
+                            outfile = sock2.makefile(mode="wb")
+                            write_int(e.errno, outfile)  # Signal that the fork failed
+                            outfile.flush()
+                            outfile.close()
+                            sock2.close()
                         continue
 
                 if pid == 0:
@@ -217,14 +242,16 @@ def handle_sigterm(*args):
                         or False
                     )
                     while True:
-                        code = worker(sock, authenticated)
+                        code = worker(sock, sock2, authenticated)
                         if code == 0:
                             authenticated = True
                         if not reuse or code:
                             # wait for closing
                             try:
                                 while sock.recv(1024):
                                     pass
+                                while sock2 is not None and sock2.recv(1024):
+                                    pass
                             except Exception:
                                 pass
                             break
```
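
The daemon only accepts the second connection when PYSPARK_RUNTIME_PROFILE is set in its environment, and the Scala factory checks the same flag in its envVars. As a sketch, one plausible way to turn the prototype on from a PySpark application is spark.executorEnv.*, which exports the variable into the executor environment; whether that is the intended switch for this prototype is an assumption, not something the commit shows.

```python
from pyspark.sql import SparkSession

# Sketch under the assumption above: export PYSPARK_RUNTIME_PROFILE into the
# executor (and hence the Python daemon) environment via spark.executorEnv.*.
spark = (
    SparkSession.builder
    .appName("runtime-profiler-demo")
    .config("spark.executorEnv.PYSPARK_RUNTIME_PROFILE", "true")
    .getOrCreate()
)
```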

python/pyspark/worker.py

Lines changed: 23 additions & 0 deletions

```diff
@@ -18,6 +18,8 @@
 """
 Worker that receives input from Piped RDD.
 """
+import pickle
+import threading
 import itertools
 import os
 import sys
@@ -45,6 +47,7 @@
     read_bool,
     write_long,
     read_int,
+    write_with_length,
     SpecialLengths,
     CPickleSerializer,
     BatchedSerializer,
@@ -3167,7 +3170,27 @@ def func(_, it):
     return func, None, ser, ser
 
 
+def write_profile(outfile):
+    import yappi
+
+    while True:
+        stats = []
+        for thread in yappi.get_thread_stats():
+            data = list(yappi.get_func_stats(ctx_id=thread.id))
+            stats.extend([{str(k): str(v) for k, v in d.items()} for d in data])
+        pickled = pickle.dumps(stats)
+        write_with_length(pickled, outfile)
+        time.sleep(1)
+
+
 def main(infile, outfile):
+    if isinstance(outfile, tuple):
+        import yappi
+
+        outfile, outfile2 = outfile
+        yappi.start()
+        threading.Thread(target=write_profile, args=(outfile2,), daemon=True).start()
+
     faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
     tracebackDumpIntervalSeconds = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None)
     try:
```
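
write_profile() frames each snapshot with write_with_length, i.e. a 4-byte big-endian length prefix followed by a pickle of the collected yappi stats, and emits one snapshot roughly every second on the extra channel. Below is a minimal sketch of the consuming side of that stream; read_profile_snapshots and profile_file are illustrative names, not part of the commit.

```python
import pickle
import struct

def read_profile_snapshots(profile_file):
    # profile_file is any binary file-like object wrapping the extra socket.
    while True:
        header = profile_file.read(4)
        if len(header) < 4:
            break  # stream closed
        (length,) = struct.unpack("!i", header)  # write_with_length's 4-byte big-endian prefix
        payload = profile_file.read(length)
        # Each snapshot is a list of {str: str} dicts of yappi function stats,
        # as produced by write_profile() above.
        yield pickle.loads(payload)
```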

sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister

Lines changed: 1 addition & 0 deletions

```diff
@@ -26,6 +26,7 @@ org.apache.spark.sql.execution.datasources.xml.XmlFileFormat
 org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
 org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
 org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
+org.apache.spark.sql.execution.streaming.sources.PythonProfileSourceProvider
 org.apache.spark.sql.execution.datasources.binaryfile.BinaryFileFormat
 org.apache.spark.sql.execution.streaming.sources.RatePerMicroBatchProvider
 org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataSource
```
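
Registering PythonProfileSourceProvider in this service file makes the new source loadable by its short name through Spark's usual DataSource lookup. The provider's implementation is not among the hunks shown here, so its short name is unknown; the sketch below assumes a hypothetical name "python-profile" purely for illustration.

```python
# Hypothetical usage sketch; "python-profile" is an assumed short name,
# not one defined by this commit.
profiles = (
    spark.readStream
    .format("python-profile")
    .load()
)
query = profiles.writeStream.format("console").start()
```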
