Compute Spark SQL plan for each stage
paul-laffon-dd committed Oct 30, 2023
1 parent 61ab1df commit d90a497
Showing 5 changed files with 201 additions and 2 deletions.
File 1 of 5: DatadogSpark212Listener (Scala 2.12)
@@ -1,8 +1,12 @@
package datadog.trace.instrumentation.spark;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import org.apache.spark.SparkConf;
import org.apache.spark.scheduler.SparkListenerEvent;
import org.apache.spark.scheduler.SparkListenerJobStart;
import org.apache.spark.sql.execution.SparkPlanInfo;

/**
* DatadogSparkListener compiled for Scala 2.12
@@ -36,4 +40,23 @@ protected String getSparkJobName(SparkListenerJobStart jobStart) {
  protected int getStageCount(SparkListenerJobStart jobStart) {
    return jobStart.stageInfos().length();
  }

  @Override
  protected void updateSqlPlanInfo(SparkListenerEvent event) {
    try {
      // Loaded reflectively so the instrumentation degrades gracefully when the
      // Spark SQL module is absent or the class is shaped differently
      Class<?> adaptiveExecutionUpdateClass =
          Class.forName("org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate");

      if (adaptiveExecutionUpdateClass.isInstance(event)) {
        Method executionIdMethod = adaptiveExecutionUpdateClass.getDeclaredMethod("executionId");
        Method sparkPlanInfoMethod = adaptiveExecutionUpdateClass.getDeclaredMethod("sparkPlanInfo");

        long queryId = (long) executionIdMethod.invoke(event);
        SparkPlanInfo sparkPlanInfo = (SparkPlanInfo) sparkPlanInfoMethod.invoke(event);

        sqlPlans.put(queryId, sparkPlanInfo);
      }
    } catch (ClassNotFoundException
        | NoSuchMethodException
        | IllegalAccessException
        | InvocationTargetException ignored) {
      // Best effort: keep the listener running even if reflection fails
    }
  }
}
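For context, Spark posts a SparkListenerSQLAdaptiveExecutionUpdate each time adaptive query execution (AQE) re-plans a running query, so sqlPlans always holds the latest plan for a given execution id. A minimal driver sketch (not part of this commit; assumes a local Spark session with AQE enabled) that would produce such events:

import org.apache.spark.sql.SparkSession;

public class AdaptivePlanDemo {
  public static void main(String[] args) {
    // Each AQE re-optimization is broadcast to listeners as a
    // SparkListenerSQLAdaptiveExecutionUpdate carrying the new SparkPlanInfo
    SparkSession spark =
        SparkSession.builder()
            .master("local[*]")
            .appName("adaptive-plan-demo")
            .config("spark.sql.adaptive.enabled", "true")
            .getOrCreate();

    // A shuffle introduces stage boundaries, giving AQE a chance to re-plan
    spark.range(0, 1_000_000).groupBy("id").count().collect();

    spark.stop();
  }
}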
File 2 of 5: Spark instrumentation module (helper class names)
@@ -18,6 +18,8 @@ public String[] helperClassNames() {
      packageName + ".DatadogSpark212Listener",
      packageName + ".SparkAggregatedTaskMetrics",
      packageName + ".SparkConfAllowList",
      packageName + ".SparkSQLUtils",
      packageName + ".SparkSQLUtils$SparkPlanInfoForStage",
    };

  }
File 3 of 5: DatadogSpark213Listener (Scala 2.13)
@@ -2,7 +2,11 @@

import java.util.ArrayList;
import org.apache.spark.SparkConf;
import org.apache.spark.scheduler.SparkListenerEvent;
import org.apache.spark.scheduler.SparkListenerJobStart;
import org.apache.spark.sql.execution.SparkPlanInfo;
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate;
import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd;

/**
* DatadogSparkListener compiled for Scala 2.13
@@ -36,4 +40,16 @@ protected String getSparkJobName(SparkListenerJobStart jobStart) {
  protected int getStageCount(SparkListenerJobStart jobStart) {
    return jobStart.stageInfos().length();
  }

  @Override
  protected void updateSqlPlanInfo(SparkListenerEvent event) {
    if (event instanceof SparkListenerSQLAdaptiveExecutionUpdate) {
      SparkListenerSQLAdaptiveExecutionUpdate update =
          (SparkListenerSQLAdaptiveExecutionUpdate) event;

      long queryId = update.executionId();
      SparkPlanInfo info = update.sparkPlanInfo();

      sqlPlans.put(queryId, info);
    }
  }
}
File 4 of 5: AbstractDatadogSparkListener
@@ -22,6 +22,7 @@
import org.apache.spark.TaskFailedReason;
import org.apache.spark.scheduler.*;
import org.apache.spark.sql.execution.SQLExecution;
import org.apache.spark.sql.execution.SparkPlanInfo;
import org.apache.spark.sql.execution.streaming.MicroBatchExecution;
import org.apache.spark.sql.execution.streaming.StreamExecution;
import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd;
@@ -31,6 +32,7 @@
import org.apache.spark.sql.streaming.StreamingQueryListener;
import org.apache.spark.sql.streaming.StreamingQueryProgress;
import scala.Tuple2;
import scala.collection.JavaConverters;

/**
 * Implementation of the {@link SparkListener} to generate spans from the execution of
@@ -72,7 +74,9 @@ public abstract class AbstractDatadogSparkListener extends SparkListener {
  private final HashMap<UUID, StreamingQueryListener.QueryStartedEvent> streamingQueries =
      new HashMap<>();
  private final HashMap<Long, SparkListenerSQLExecutionStart> sqlQueries = new HashMap<>();
  protected final HashMap<Long, SparkPlanInfo> sqlPlans = new HashMap<>();
  private final HashMap<String, SparkListenerExecutorAdded> liveExecutors = new HashMap<>();
  private final HashMap<Long, Integer> accumulatorToStage = new HashMap<>();

  private final boolean isRunningOnDatabricks;
  private final String databricksClusterName;
@@ -109,6 +113,8 @@ public AbstractDatadogSparkListener(SparkConf sparkConf, String appId, String sp
  /** Stage count of the Spark job. Provide an implementation based on a specific Scala version */
  protected abstract int getStageCount(SparkListenerJobStart jobStart);

  protected abstract void updateSqlPlanInfo(SparkListenerEvent event);

  @Override
  public synchronized void onApplicationStart(SparkListenerApplicationStart applicationStart) {
    this.applicationStart = applicationStart;
@@ -445,7 +451,13 @@ public synchronized void onStageCompleted(SparkListenerStageCompleted stageCompl
      metric.allocateAvailableExecutorTime(currentAvailableExecutorTime);
    }

    // Record which stage each metric accumulator was updated in, so that SQL plan
    // nodes can later be attributed to stages through their accumulator ids
    for (AccumulableInfo info :
        JavaConverters.asJavaCollection(stageInfo.accumulables().values())) {
      accumulatorToStage.put(info.id(), stageId);
    }

    SparkAggregatedTaskMetrics stageMetric = stageMetrics.remove(stageSpanKey);
    Properties prop = stageProperties.remove(stageSpanKey);
    if (stageMetric != null) {
      stageMetric.computeSkew();
      stageMetric.setSpanMetrics(span);
@@ -455,9 +467,7 @@ public synchronized void onStageCompleted(SparkListenerStageCompleted stageCompl
          .computeIfAbsent(jobId, k -> new SparkAggregatedTaskMetrics())
          .accumulateStageMetrics(stageMetric);

      String batchKey = getStreamingBatchKey(prop);
      if (batchKey != null) {
        streamingBatchMetrics
            .computeIfAbsent(batchKey, k -> new SparkAggregatedTaskMetrics())
@@ -472,6 +482,12 @@ public synchronized void onStageCompleted(SparkListenerStageCompleted stageCompl
      }
    }

    // Attach the subtree of the SQL plan executed by this stage to its span
    Long sqlQueryId = getSqlExecutionId(prop);
    SparkPlanInfo sqlPlan = sqlPlans.get(sqlQueryId);
    if (sqlPlan != null) {
      SparkSQLUtils.addSQLPlanToStageSpan(span, sqlPlan, accumulatorToStage, stageId);
    }

    span.finish(completionTimeMs * 1000);
  }
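Taken together, the two additions above wire plan attribution end to end: the accumulables loop records which stage each metric accumulator ran in, and when the stage span finishes, the plan registered for its SQL execution is pruned to that stage's subtree. A schematic sketch of the lookup chain (illustrative only, with made-up ids; not the commit's code):

HashMap<Long, Integer> accumulatorToStage = new HashMap<>();
accumulatorToStage.put(42L, 7); // accumulator 42 was updated while stage 7 ran

// Later, a plan node whose SQL metrics reference accumulator 42 resolves to
// stage 7, so its subtree is serialized onto stage 7's span
Integer owningStage = accumulatorToStage.get(42L); // -> 7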

@@ -588,16 +604,20 @@ public void onOtherEvent(SparkListenerEvent event) {
    } else if (event instanceof SparkListenerSQLExecutionEnd) {
      onSQLExecutionEnd((SparkListenerSQLExecutionEnd) event);
    }

    updateSqlPlanInfo(event);
  }

  private synchronized void onSQLExecutionStart(SparkListenerSQLExecutionStart sqlStart) {
    sqlPlans.put(sqlStart.executionId(), sqlStart.sparkPlanInfo());
    sqlQueries.put(sqlStart.executionId(), sqlStart);
  }

  private synchronized void onSQLExecutionEnd(SparkListenerSQLExecutionEnd sqlEnd) {
    AgentSpan span = sqlSpans.remove(sqlEnd.executionId());
    SparkAggregatedTaskMetrics metrics = sqlMetrics.remove(sqlEnd.executionId());
    sqlQueries.remove(sqlEnd.executionId());
    sqlPlans.remove(sqlEnd.executionId());

    if (span != null) {
      if (metrics != null) {
File 5 of 5: SparkSQLUtils (new file)
@@ -0,0 +1,138 @@
package datadog.trace.instrumentation.spark;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.spark.sql.execution.SparkPlanInfo;
import org.apache.spark.sql.execution.metric.SQLMetricInfo;
import scala.collection.JavaConverters;

public class SparkSQLUtils {

  public static void addSQLPlanToStageSpan(
      AgentSpan span,
      SparkPlanInfo sparkPlanInfo,
      HashMap<Long, Integer> accumulatorToStage,
      int stageId) {
    // Serialize the subtree of the plan executed by this stage and tag the span with it
    SparkPlanInfoForStage planForStage =
        computeStageInfoForStage(sparkPlanInfo, accumulatorToStage, stageId, false);

    if (planForStage != null) {
      String json = planForStage.toJson();
      span.setTag("_dd.spark.sql_plan", json);
    }
  }

  public static SparkPlanInfoForStage computeStageInfoForStage(
      SparkPlanInfo info,
      HashMap<Long, Integer> accumulatorToStage,
      int stageId,
      boolean alreadyStarted) {
    Set<Integer> stageIds = stageIdsForPlan(info, accumulatorToStage);

    boolean hasStageInfo = !stageIds.isEmpty();
    boolean isForStage = stageIds.contains(stageId);

    // Once inside the target stage, stop at the first node attributed to another stage
    if (alreadyStarted && hasStageInfo && !isForStage) {
      return null;
    }

    if (alreadyStarted || isForStage) {
      // The node belongs to the target stage: keep it and recurse into its children
      List<SparkPlanInfoForStage> children = new ArrayList<>();
      for (SparkPlanInfo child : JavaConverters.asJavaCollection(info.children())) {
        SparkPlanInfoForStage infoForStage =
            computeStageInfoForStage(child, accumulatorToStage, stageId, true);

        if (infoForStage != null) {
          children.add(infoForStage);
        }
      }

      return new SparkPlanInfoForStage(info.nodeName(), info.simpleString(), children);
    } else {
      // The target stage has not been reached yet: look for it among the children
      for (SparkPlanInfo child : JavaConverters.asJavaCollection(info.children())) {
        SparkPlanInfoForStage infoForStage =
            computeStageInfoForStage(child, accumulatorToStage, stageId, false);

        if (infoForStage != null) {
          return infoForStage;
        }
      }
    }

    return null;
  }

  /** Stages a plan node belongs to, resolved through the accumulator ids of its SQL metrics */
  public static Set<Integer> stageIdsForPlan(
      SparkPlanInfo info, HashMap<Long, Integer> accumulatorToStage) {
    Set<Integer> stageIds = new HashSet<>();

    for (SQLMetricInfo metric : JavaConverters.asJavaCollection(info.metrics())) {
      Integer stageId = accumulatorToStage.get(metric.accumulatorId());

      if (stageId != null) {
        stageIds.add(stageId);
      }
    }

    return stageIds;
  }

  public static class SparkPlanInfoForStage {
    private final String nodeName;
    private final String simpleString;
    private final List<SparkPlanInfoForStage> children;

    public SparkPlanInfoForStage(
        String nodeName, String simpleString, List<SparkPlanInfoForStage> children) {
      this.nodeName = nodeName;
      this.simpleString = simpleString;
      this.children = children;
    }

    public String toJson() {
      ByteArrayOutputStream baos = new ByteArrayOutputStream();
      ObjectMapper mapper =
          new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
      try {
        JsonGenerator generator = mapper.getFactory().createGenerator(baos);
        this.toJson(generator);
        generator.close();
        baos.close();
        return new String(baos.toByteArray(), StandardCharsets.UTF_8);
      } catch (IOException e) {
        return null;
      }
    }

    private void toJson(JsonGenerator generator) throws IOException {
      generator.writeStartObject();
      generator.writeStringField("node_name", nodeName);
      generator.writeStringField("simple_string", simpleString);

      // Writing child nodes
      if (!children.isEmpty()) {
        generator.writeFieldName("children");
        generator.writeStartArray();
        for (SparkPlanInfoForStage child : children) {
          child.toJson(generator);
        }
        generator.writeEndArray();
      }

      generator.writeEndObject();
    }
  }
}
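To illustrate the shape of the _dd.spark.sql_plan tag, a minimal usage sketch of SparkPlanInfoForStage (assumed to run in the same package; node names and strings are made up for the example):

import java.util.Arrays;
import java.util.Collections;

public class SqlPlanJsonDemo {
  public static void main(String[] args) {
    // Hand-build the two-node fragment that computeStageInfoForStage would
    // return for a stage covering an aggregate feeding a shuffle exchange
    SparkSQLUtils.SparkPlanInfoForStage exchange =
        new SparkSQLUtils.SparkPlanInfoForStage(
            "Exchange",
            "Exchange hashpartitioning(id#0L, 200)",
            Collections.emptyList());
    SparkSQLUtils.SparkPlanInfoForStage aggregate =
        new SparkSQLUtils.SparkPlanInfoForStage(
            "HashAggregate",
            "HashAggregate(keys=[id#0L], functions=[partial_count(1)])",
            Arrays.asList(exchange));

    System.out.println(aggregate.toJson());
    // {"node_name":"HashAggregate","simple_string":"HashAggregate(...)",
    //  "children":[{"node_name":"Exchange","simple_string":"Exchange ..."}]}
  }
}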
