Skip to content

Commit f4addd2

Browse files
#HLODiff Continue doing hlo diff when value tracing errors.
PiperOrigin-RevId: 839923406
1 parent 34cfecb commit f4addd2

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.cc

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,11 @@ absl::Status HloGumgraph::PrecomputeComputationFingerprint() {
321321

322322
void HloGumgraph::PrecomputeInstructionDependencies() {
323323
LOG(INFO) << "Precomputing instruction dependencies";
324+
if (hlo_value_tracing_ == nullptr) {
325+
LOG(WARNING) << "Skipping PrecomputeInstructionDependencies because "
326+
"HloValueTracing failed to initialize.";
327+
return;
328+
}
324329
for (auto* computation : hlo_module_.MakeComputationPostOrder()) {
325330
for (auto* instruction : computation->MakeInstructionPostOrder()) {
326331
HloInstructionNode* node = GetNode(instruction);
@@ -370,11 +375,23 @@ absl::StatusOr<std::unique_ptr<const HloGumgraph>> HloGumgraph::Create(
370375
<< "Expected a non-null entry computation";
371376

372377
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(hlo_module);
373-
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloValueTracing> hlo_value_tracing,
374-
HloValueTracing::Run(*hlo_module));
378+
precompute_instruction_dependencies = true;
379+
std::unique_ptr<HloValueTracing> hlo_value_tracing_ptr = nullptr;
380+
if (precompute_instruction_dependencies) {
381+
absl::StatusOr<std::unique_ptr<HloValueTracing>> hlo_value_tracing =
382+
HloValueTracing::Run(*hlo_module);
383+
if (hlo_value_tracing.ok()) {
384+
hlo_value_tracing_ptr = *std::move(hlo_value_tracing);
385+
} else {
386+
LOG(WARNING) << "Failed to run HloValueTracing: "
387+
<< hlo_value_tracing.status();
388+
// hlo_value_tracing_ptrs is left as nullptr.
389+
}
390+
}
391+
375392
auto graph = absl::WrapUnique(
376393
new HloGumgraph(*hlo_module, fingerprint_options, std::move(call_graph),
377-
std::move(hlo_value_tracing)));
394+
std::move(hlo_value_tracing_ptr)));
378395

379396
TF_RETURN_IF_ERROR(graph->ConstructGraph(*hlo_module));
380397
TF_ASSIGN_OR_RETURN(std::vector<HloInstructionNode*> zero_indegree_nodes,

0 commit comments

Comments
 (0)