Skip to content

Commit 9813877

Browse files
[XLA:Original Value] Propagate OriginalValue information when hoisting loop-invariant instructions.
When `HoistLoopInvariantInstructions` creates a new while loop, the `OriginalValue` for the new while instruction is constructed. Elements from the original while tuple retain their original values. For hoisted instructions, their original values are copied, and the original instruction name is prefixed to indicate they are loop-invariant values hoisted out of the while loop. Also, introduced `GetOriginalCallInstructionName` and use it in `CallInliner`. PiperOrigin-RevId: 814772914
1 parent 21a7c86 commit 9813877

10 files changed

+327
-15
lines changed

xla/hlo/ir/hlo_original_value.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,4 +309,24 @@ bool OriginalValue::IsCompatibleWith(const Shape& shape) const {
309309
return tree().IsStructurallyCompatible(shape);
310310
}
311311

312+
std::optional<std::string> OriginalValue::GetOriginalCallLikeInstructions()
313+
const {
314+
if (is_synthetic_call()) {
315+
// Synthetic call are transparent and hence resulting in empty call
316+
// instructions.
317+
return "";
318+
}
319+
if (IsEmpty()) {
320+
// Currently we don't track original call information separately and rely
321+
// on the first leaf to find the original call information. So if there are
322+
// no leaves we return std::nullopt.
323+
return std::nullopt;
324+
}
325+
auto original_array = original_arrays().begin()->second;
326+
if (!original_array.has_value()) {
327+
return std::nullopt;
328+
}
329+
return original_array->instruction_name;
330+
}
331+
312332
} // namespace xla

xla/hlo/ir/hlo_original_value.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,12 @@ class OriginalValue {
135135
return !(*this == other);
136136
}
137137

138+
// Gets the (partial) call hierarchy string of the original call instructions
139+
// that this OriginalValue is associated with. Returns std::nullopt if this
140+
// OriginalValue is not associated with a call instruction or the call
141+
// hierarchy is lost (e.g., after complicated optimizations).
142+
std::optional<std::string> GetOriginalCallLikeInstructions() const;
143+
138144
template <typename H>
139145
friend H AbslHashValue(H h, const OriginalValue& value) {
140146
h = H::combine(std::move(h), value.is_synthetic_call());

xla/service/BUILD

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4707,10 +4707,10 @@ xla_cc_test(
47074707
"//xla/hlo/utils:hlo_matchers",
47084708
"//xla/tests:xla_internal_test_main",
47094709
"//xla/tsl/lib/core:status_test_util",
4710+
"//xla/tsl/platform:statusor",
47104711
"@com_google_absl//absl/algorithm:container",
47114712
"@com_google_absl//absl/status:statusor",
47124713
"@com_google_absl//absl/strings:string_view",
4713-
"@tsl//tsl/platform:statusor",
47144714
],
47154715
)
47164716

@@ -4829,6 +4829,8 @@ cc_library(
48294829
"//xla/hlo/pass:hlo_pass",
48304830
"//xla/hlo/transforms/simplifiers:hlo_dce",
48314831
"//xla/hlo/transforms/simplifiers:tuple_simplifier",
4832+
"//xla/tsl/platform:errors",
4833+
"//xla/tsl/platform:statusor",
48324834
"@com_google_absl//absl/algorithm:container",
48334835
"@com_google_absl//absl/container:flat_hash_map",
48344836
"@com_google_absl//absl/container:flat_hash_set",
@@ -4837,8 +4839,6 @@ cc_library(
48374839
"@com_google_absl//absl/log:check",
48384840
"@com_google_absl//absl/status:statusor",
48394841
"@com_google_absl//absl/strings",
4840-
"@tsl//tsl/platform:errors",
4841-
"@tsl//tsl/platform:statusor",
48424842
],
48434843
)
48444844

@@ -4857,8 +4857,8 @@ xla_cc_test(
48574857
"//xla/hlo/utils:hlo_matchers",
48584858
"//xla/tests:xla_internal_test_main",
48594859
"//xla/tsl/lib/core:status_test_util",
4860+
"//xla/tsl/platform:statusor",
48604861
"@com_google_absl//absl/log",
4861-
"@tsl//tsl/platform:statusor",
48624862
],
48634863
)
48644864

@@ -4873,13 +4873,16 @@ cc_library(
48734873
"//xla/hlo/analysis:while_loop_analysis",
48744874
"//xla/hlo/ir:hlo",
48754875
"//xla/hlo/pass:hlo_pass",
4876+
"//xla/tsl/platform:errors",
4877+
"//xla/tsl/platform:statusor",
48764878
"@com_google_absl//absl/algorithm:container",
48774879
"@com_google_absl//absl/container:flat_hash_map",
48784880
"@com_google_absl//absl/container:flat_hash_set",
48794881
"@com_google_absl//absl/container:inlined_vector",
48804882
"@com_google_absl//absl/log",
48814883
"@com_google_absl//absl/log:check",
48824884
"@com_google_absl//absl/status:statusor",
4885+
"@com_google_absl//absl/strings",
48834886
"@com_google_absl//absl/strings:string_view",
48844887
],
48854888
)
@@ -4895,6 +4898,7 @@ xla_cc_test(
48954898
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
48964899
"//xla/hlo/utils:hlo_matchers",
48974900
"//xla/tests:xla_internal_test_main",
4901+
"//xla/tsl/platform:statusor",
48984902
"@com_google_googletest//:gtest",
48994903
"@tsl//tsl/platform:statusor",
49004904
],

xla/service/call_inliner.cc

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -214,9 +214,21 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
214214
new_hlo_pointer->set_original_value(nullptr);
215215
return;
216216
}
217+
std::optional<std::string> call_instructions =
218+
call_original_value->GetOriginalCallLikeInstructions();
219+
if (!call_instructions.has_value()) {
220+
// If the call instruction is lost, we must drop the original values
221+
// on the inlined instructions because the call hierarchy is lost.
222+
new_hlo_pointer->set_original_value(nullptr);
223+
return;
224+
}
217225
new_hlo_pointer->CopyOriginalValue(hlo, /*clone=*/true,
218226
/*issue_warning=*/true);
219-
if (call_original_value->is_synthetic_call()) {
227+
if (call_instructions->empty()) {
228+
// Empty call instructions means the call is synthetic and hence the
229+
// inlined instruction do not need to be prefixed with the call
230+
// instructions. Hence we can just return here to have the copied original
231+
// value to be used.
220232
return;
221233
}
222234
std::shared_ptr<OriginalValue> original_value =
@@ -227,12 +239,8 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
227239
for (auto& pair : original_value->mutable_original_arrays()) {
228240
std::optional<OriginalArray>& original_array = pair.second;
229241
if (original_array.has_value()) {
230-
std::string call_instruction_name =
231-
call_original_value->original_arrays()
232-
.begin()
233-
->second->instruction_name;
234242
original_array->instruction_name = absl::StrCat(
235-
call_instruction_name, "/", original_array->instruction_name);
243+
*call_instructions, "/", original_array->instruction_name);
236244
}
237245
}
238246
}

xla/service/while_loop_expensive_invariant_code_motion.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ limitations under the License.
1717

1818
#include <cstdint>
1919
#include <iterator>
20+
#include <memory>
21+
#include <optional>
2022
#include <string>
23+
#include <utility>
2124
#include <vector>
2225

2326
#include "absl/algorithm/container.h"
@@ -27,10 +30,20 @@ limitations under the License.
2730
#include "absl/log/check.h"
2831
#include "absl/log/log.h"
2932
#include "absl/status/statusor.h"
33+
#include "absl/strings/str_cat.h"
3034
#include "absl/strings/string_view.h"
3135
#include "xla/hlo/analysis/while_loop_analysis.h"
36+
#include "xla/hlo/ir/hlo_computation.h"
37+
#include "xla/hlo/ir/hlo_instruction.h"
38+
#include "xla/hlo/ir/hlo_opcode.h"
39+
#include "xla/hlo/ir/hlo_original_value.h"
40+
#include "xla/hlo/ir/hlo_print_options.h"
41+
#include "xla/map_util.h"
3242
#include "xla/service/while_util.h"
43+
#include "xla/shape.h"
3344
#include "xla/shape_util.h"
45+
#include "xla/tsl/platform/errors.h"
46+
#include "xla/tsl/platform/statusor.h"
3447
#include "xla/util.h"
3548

3649
namespace xla {
@@ -101,6 +114,32 @@ static void CreateLoopInvariantCopy(
101114
old_instruction->CloneWithNewOperands(old_instruction->shape(),
102115
new_operands));
103116

117+
std::optional<std::string> original_call_instructions;
118+
if (while_instr->original_value() != nullptr) {
119+
original_call_instructions =
120+
while_instr->original_value()->GetOriginalCallLikeInstructions();
121+
}
122+
if (original_call_instructions.has_value() &&
123+
old_instruction->original_value() != nullptr) {
124+
std::string original_call_prefix;
125+
if (!original_call_instructions->empty()) {
126+
// We only add the wildcard iteration count if the call-like
127+
// instruction is available.
128+
original_call_prefix =
129+
absl::StrCat(*original_call_instructions, "#*/");
130+
}
131+
132+
auto new_original_value = std::make_shared<OriginalValue>(
133+
*old_instruction->original_value());
134+
for (auto& [shape_index, original_array] :
135+
new_original_value->mutable_original_arrays()) {
136+
if (original_array) {
137+
original_array->instruction_name = absl::StrCat(
138+
original_call_prefix, original_array->instruction_name);
139+
}
140+
}
141+
new_instruction->set_original_value(std::move(new_original_value));
142+
}
104143
info.hoisted_copy = new_instruction;
105144
}
106145

xla/service/while_loop_expensive_invariant_code_motion_test.cc

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ limitations under the License.
2323
#include "xla/hlo/parser/hlo_parser.h"
2424
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
2525
#include "xla/hlo/utils/hlo_matchers.h"
26+
#include "xla/tsl/platform/statusor.h"
2627
#include "xla/util.h"
27-
#include "tsl/platform/statusor.h"
2828

2929
namespace xla {
3030
namespace {
@@ -268,5 +268,75 @@ ENTRY entry {
268268
EXPECT_FALSE(simplified_loop);
269269
}
270270

271+
TEST_F(WhileLoopExpensiveInvariantCodeMotionTest, HoistWithOriginalValue) {
272+
const char* const hlo_string = R"(
273+
HloModule licm_ov_test
274+
275+
body {
276+
p_body = (f32[8,8], f32[16, 8]) parameter(0)
277+
b = f32[16, 8] get-tuple-element(p_body), index=1
278+
const = f32[] constant(1.0)
279+
lhs = f32[8, 16] broadcast(const), dimensions={}, origin={{"lhs.1"}}
280+
dot = f32[8,8] dot(lhs, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}, origin={{"dot.1"}}
281+
a = f32[8,8] get-tuple-element(p_body), index=0
282+
add = f32[8,8] add(a, dot)
283+
ROOT root = (f32[8,8], f32[16,8]) tuple(add, b)
284+
}
285+
286+
condition {
287+
p_cond = (f32[8,8], f32[16, 8]) parameter(0)
288+
ROOT result = pred[] constant(true)
289+
}
290+
291+
ENTRY entry {
292+
param0 = f32[8,8] parameter(0)
293+
param1 = f32[16, 8] parameter(1)
294+
while_init = (f32[8,8], f32[16,8]) tuple(param0, param1)
295+
ROOT while0 = (f32[8,8], f32[16, 8]) while(while_init), condition=condition, body=body, origin={({"while.5" {0}},{"while.5" {1}})}
296+
}
297+
)";
298+
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
299+
HloComputation* body = m->GetComputationWithName("body");
300+
HloInstruction* dot = body->GetInstructionWithName("dot");
301+
HloInstruction* lhs = body->GetInstructionWithName("lhs");
302+
303+
TF_ASSERT_OK_AND_ASSIGN(
304+
bool simplified_loop,
305+
WhileLoopExpensiveInvariantCodeMotion(
306+
/*worth_hoisting_individually=*/HloPredicateIsOp<HloOpcode::kDot>)
307+
.Run(m.get()));
308+
EXPECT_TRUE(simplified_loop);
309+
310+
HloInstruction* transformed_while = nullptr;
311+
for (auto* instr : m->entry_computation()->instructions()) {
312+
if (instr->opcode() == HloOpcode::kWhile) {
313+
transformed_while = instr;
314+
break;
315+
}
316+
}
317+
ASSERT_NE(transformed_while, nullptr);
318+
319+
HloInstruction* hoisted_dot = nullptr;
320+
HloInstruction* hoisted_lhs = nullptr;
321+
for (auto* instr : m->entry_computation()->instructions()) {
322+
if (instr->opcode() == HloOpcode::kDot && instr->shape() == dot->shape()) {
323+
hoisted_dot = instr;
324+
}
325+
if (instr->opcode() == HloOpcode::kBroadcast &&
326+
instr->shape() == lhs->shape()) {
327+
hoisted_lhs = instr;
328+
}
329+
}
330+
ASSERT_NE(hoisted_dot, nullptr);
331+
ASSERT_NE(hoisted_lhs, nullptr);
332+
ASSERT_NE(hoisted_dot->original_value(), nullptr);
333+
EXPECT_EQ(hoisted_dot->original_value()->ToString(), "{\"while.5#*/dot.1\"}");
334+
ASSERT_NE(hoisted_lhs->original_value(), nullptr);
335+
EXPECT_EQ(hoisted_lhs->original_value()->ToString(), "{\"while.5#*/lhs.1\"}");
336+
ASSERT_NE(transformed_while->original_value(), nullptr);
337+
EXPECT_EQ(transformed_while->original_value()->ToString(),
338+
"({\"while.5\" {0}}, {\"while.5\" {1}}, {\"while.5#*/dot.1\"})");
339+
}
340+
271341
} // namespace
272342
} // namespace xla

xla/service/while_loop_invariant_code_motion.cc

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ limitations under the License.
1717

1818
#include <cstdint>
1919
#include <iterator>
20+
#include <memory>
21+
#include <optional>
2022
#include <string>
23+
#include <utility>
2124
#include <vector>
2225

2326
#include "absl/algorithm/container.h"
@@ -26,21 +29,24 @@ limitations under the License.
2629
#include "absl/container/inlined_vector.h"
2730
#include "absl/log/check.h"
2831
#include "absl/log/log.h"
32+
#include "absl/strings/str_cat.h"
2933
#include "absl/strings/string_view.h"
3034
#include "xla/hlo/analysis/while_loop_analysis.h"
3135
#include "xla/hlo/ir/hlo_computation.h"
3236
#include "xla/hlo/ir/hlo_instruction.h"
3337
#include "xla/hlo/ir/hlo_opcode.h"
38+
#include "xla/hlo/ir/hlo_original_value.h"
39+
#include "xla/hlo/ir/hlo_print_options.h"
3440
#include "xla/hlo/transforms/simplifiers/hlo_dce.h"
3541
#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h"
3642
#include "xla/map_util.h"
3743
#include "xla/service/compile_time_cap.h"
3844
#include "xla/service/while_util.h"
3945
#include "xla/shape.h"
4046
#include "xla/shape_util.h"
47+
#include "xla/tsl/platform/errors.h"
48+
#include "xla/tsl/platform/statusor.h"
4149
#include "xla/util.h"
42-
#include "tsl/platform/errors.h"
43-
#include "tsl/platform/statusor.h"
4450

4551
namespace xla {
4652

@@ -93,6 +99,33 @@ static void CreateLoopInvariantCopy(
9399
parent_of_while->AddInstruction(old_instruction->CloneWithNewOperands(
94100
old_instruction->shape(), new_operands));
95101

102+
std::optional<std::string> original_call_instructions;
103+
if (while_instr->original_value() != nullptr) {
104+
original_call_instructions =
105+
while_instr->original_value()->GetOriginalCallLikeInstructions();
106+
}
107+
if (original_call_instructions.has_value() &&
108+
old_instruction->original_value() != nullptr) {
109+
std::string original_call_prefix;
110+
if (!original_call_instructions->empty()) {
111+
// We only add the wildcard iteration count if the call-like
112+
// instruction is available.
113+
original_call_prefix =
114+
absl::StrCat(*original_call_instructions, "#*/");
115+
}
116+
117+
auto new_original_value =
118+
std::make_shared<OriginalValue>(*old_instruction->original_value());
119+
for (auto& [shape_index, original_array] :
120+
new_original_value->mutable_original_arrays()) {
121+
if (original_array) {
122+
original_array->instruction_name = absl::StrCat(
123+
original_call_prefix, original_array->instruction_name);
124+
}
125+
}
126+
new_instruction->set_original_value(std::move(new_original_value));
127+
}
128+
96129
InsertOrDie(hoisted_instructions, old_instruction, new_instruction);
97130

98131
// Approximately half of the instructions that would normally be present

0 commit comments

Comments
 (0)