From 968aea7a2b69c57d633ede7ef10e834d8f385c61 Mon Sep 17 00:00:00 2001 From: Buddh Prakash Date: Fri, 3 Oct 2025 17:10:56 -0700 Subject: [PATCH] Add `FlattenDebugPayloadIntoMessage` to XLA error utilities. PiperOrigin-RevId: 814883674 --- xla/error/BUILD | 6 ++ xla/error/debug_me_context_util.cc | 54 ++++++++++++++---- xla/error/debug_me_context_util.h | 16 +++++- xla/error/debug_me_context_util_test.cc | 75 +++++++++++++++++++++++++ 4 files changed, 140 insertions(+), 11 deletions(-) diff --git a/xla/error/BUILD b/xla/error/BUILD index ef8418e6ffcda..b3af6fe6a1c3e 100644 --- a/xla/error/BUILD +++ b/xla/error/BUILD @@ -41,8 +41,10 @@ cc_library( deps = [ "//xla/tsl/platform:debug_me_context", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", + "@tsl//tsl/platform", ], ) @@ -52,8 +54,12 @@ xla_cc_test( deps = [ ":debug_me_context_util", "//xla/tsl/platform:debug_me_context", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform", ], ) diff --git a/xla/error/debug_me_context_util.cc b/xla/error/debug_me_context_util.cc index 1d432b0bfa49b..005f8fb6bb73f 100644 --- a/xla/error/debug_me_context_util.cc +++ b/xla/error/debug_me_context_util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/error/debug_me_context_util.h" +#include #include #include @@ -22,20 +23,12 @@ limitations under the License. #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "xla/tsl/platform/debug_me_context.h" +#include "tsl/platform/platform.h" namespace xla::error { -void AttachDebugMeContextPayload(absl::Status& status) { - if (!status.ok()) { - std::string error_message_string = DebugMeContextToErrorMessageString(); - if (!error_message_string.empty()) { - status.SetPayload(kDebugContextPayloadUrl, - absl::Cord(error_message_string)); - } - } -} - std::string DebugMeContextToErrorMessageString() { if (!tsl::DebugMeContext::HasAnyValues()) { return ""; @@ -73,4 +66,45 @@ std::string DebugMeContextToErrorMessageString() { return error_message; } +void AttachDebugMeContextPayload(absl::Status& status) { + if (!status.ok()) { + std::string error_message_string = DebugMeContextToErrorMessageString(); + if (!error_message_string.empty()) { + status.SetPayload(kDebugContextPayloadUrl, + absl::Cord(error_message_string)); + } + } +} + +absl::Status FlattenDebugPayloadIntoMessage(const absl::Status& status) { + if (status.ok()) { + return status; + } + + std::optional debug_context_payload = + status.GetPayload(kDebugContextPayloadUrl); + if (!debug_context_payload.has_value()) { + return status; + } + + std::string new_message = + absl::StrCat(status.message(), "\n", debug_context_payload.value()); +#if defined(PLATFORM_GOOGLE) + absl::Status new_status(status.code(), new_message, + status.GetSourceLocations().front()); +#else // ndef PLATFORM_GOOGLE + absl::Status new_status(status.code(), new_message); +#endif // ndef PLATFORM_GOOGLE + + // Copy all other payloads from the old status to the new one. + status.ForEachPayload([&](absl::string_view type_url, const absl::Cord& p) { + if (type_url != kDebugContextPayloadUrl) { + new_status.SetPayload(type_url, p); + } + }); + + // Replace the original status with our new, updated one. + return new_status; +} + } // namespace xla::error diff --git a/xla/error/debug_me_context_util.h b/xla/error/debug_me_context_util.h index 7e62007f0041a..96b9dada5d618 100644 --- a/xla/error/debug_me_context_util.h +++ b/xla/error/debug_me_context_util.h @@ -20,8 +20,8 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "xla/tsl/platform/debug_me_context.h" // This file provides XLA-specific specializations and utilities for the // thread-local debugging context system. @@ -59,6 +59,20 @@ std::string DebugMeContextToErrorMessageString(); // status, if the context is not empty and the status is not OK. void AttachDebugMeContextPayload(absl::Status& status); +// If the status contains a DebugMeContext payload, this function will add it to +// the status's message and remove the payload. Otherwise, do nothing. +absl::Status FlattenDebugPayloadIntoMessage(const absl::Status& status); + +template +inline absl::StatusOr FlattenDebugPayloadIntoMessage( + const absl::StatusOr& status_or) { + if (status_or.ok()) { + return status_or; + } + + return FlattenDebugPayloadIntoMessage(status_or.status()); +} + } // namespace error } // namespace xla diff --git a/xla/error/debug_me_context_util_test.cc b/xla/error/debug_me_context_util_test.cc index ef87f5a32f0d0..6e6f176b189c3 100644 --- a/xla/error/debug_me_context_util_test.cc +++ b/xla/error/debug_me_context_util_test.cc @@ -15,16 +15,24 @@ limitations under the License. #include "xla/error/debug_me_context_util.h" +#include #include +#include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "xla/tsl/platform/debug_me_context.h" +#include "tsl/platform/platform.h" namespace xla { namespace { +using ::testing::EndsWith; + TEST(DebugMeContextUtil, StringCheck) { constexpr absl::string_view kCompilerName{"MyCompiler"}; @@ -37,5 +45,72 @@ TEST(DebugMeContextUtil, StringCheck) { EXPECT_TRUE(absl::StrContains(error_message, kCompilerName)); } +TEST(FlattenDebugPayloadIntoMessage, StatusWithPayloadIsFlattened) { + absl::Status status = absl::InternalError("Original message."); + status.SetPayload(error::kDebugContextPayloadUrl, absl::Cord("Debug info.")); + + status = error::FlattenDebugPayloadIntoMessage(status); + + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_TRUE( + absl::StrContains(status.message(), "Original message.\nDebug info.")); + EXPECT_FALSE(status.GetPayload(error::kDebugContextPayloadUrl).has_value()); +#if defined(PLATFORM_GOOGLE) + EXPECT_THAT(status.GetSourceLocations().front().file_name(), + EndsWith("debug_me_context_util_test.cc")); +#endif // defined(PLATFORM_GOOGLE) +} + +TEST(FlattenDebugPayloadIntoMessage, StatusWithOtherPayloadsIsPreserved) { + constexpr absl::string_view kOtherPayloadUrl = "other_payload"; + constexpr absl::string_view kOtherPayloadContent = "preserved"; + absl::Status status = absl::InternalError("Original message."); + status.SetPayload(error::kDebugContextPayloadUrl, absl::Cord("Debug info.")); + status.SetPayload(kOtherPayloadUrl, absl::Cord(kOtherPayloadContent)); + + status = error::FlattenDebugPayloadIntoMessage(status); + + // Assert the debug payload was flattened. + EXPECT_TRUE( + absl::StrContains(status.message(), "Original message.\nDebug info.")); + EXPECT_FALSE(status.GetPayload(error::kDebugContextPayloadUrl).has_value()); + std::optional other_payload = status.GetPayload(kOtherPayloadUrl); + ASSERT_TRUE(other_payload.has_value()); + EXPECT_EQ(other_payload.value(), kOtherPayloadContent); +} + +TEST(FlattenDebugPayloadIntoMessage, StatusWithoutPayloadIsUnchanged) { + absl::Status status = absl::InternalError("Original message."); + absl::Status original_status = status; + + status = error::FlattenDebugPayloadIntoMessage(status); + + EXPECT_EQ(status, original_status); +} + +TEST(FlattenDebugPayloadIntoMessage, OkStatusIsUnchanged) { + absl::Status status = absl::OkStatus(); + + status = error::FlattenDebugPayloadIntoMessage(status); + + EXPECT_TRUE(status.ok()); +} + +TEST(FlattenDebugPayloadIntoMessage, StatusOrWithPayloadIsFlattened) { + absl::Status status = absl::InternalError("Original message."); + status.SetPayload(error::kDebugContextPayloadUrl, absl::Cord("Debug info.")); + absl::StatusOr status_or = status; + + status_or = error::FlattenDebugPayloadIntoMessage(status_or); + + EXPECT_FALSE(status_or.ok()); + EXPECT_EQ(status_or.status().code(), absl::StatusCode::kInternal); + EXPECT_TRUE(absl::StrContains(status_or.status().message(), + "Original message.\nDebug info.")); + EXPECT_FALSE(status_or.status() + .GetPayload(error::kDebugContextPayloadUrl) + .has_value()); +} + } // namespace } // namespace xla