Skip to content

Commit

Permalink
[XLA:MSA] Fixes an issue that leads to sync mem op replacements not r…
Browse files Browse the repository at this point in the history
…especting the auxiliary control dependencies while converting sync mem ops to async

MSA assumes all control dependencies (predecessors/successors) are met in earlier passes by the scheduler. Therefore, it traditionally only respects the data flow dependencies (i.e., instruction.uses()). This became problematic when we replaced sync copies with async ones: the async copies were not aware of indirect control dependencies when we were determining their latest allowed copy-done time.

PiperOrigin-RevId: 732288966
  • Loading branch information
mehrdadkhani authored and Google-ML-Automation committed Mar 1, 2025
1 parent 5aa1c50 commit 4528217
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
14 changes: 14 additions & 0 deletions xla/service/memory_space_assignment/algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2857,6 +2857,20 @@ AllocationRequest MsaAlgorithm::CreateAllocationRequest(
}
required_copy_allocation_latest_time =
std::min(earliest_use_time, earliest_position_time);
// We need to make sure that the copy allocation is scheduled before the
// controlled successor of the sync mem op.
for (const HloInstruction* control_successor :
required_copy_allocation_for->control_successors()) {
int64_t successor_time = instruction_schedule.at(control_successor);
if (successor_time < required_copy_allocation_latest_time) {
VLOG(3) << "Updating the required replacement async mem op allocation "
"latest time from "
<< required_copy_allocation_latest_time << " to "
<< successor_time << ", because of control successor "
<< control_successor->ToString();
required_copy_allocation_latest_time = successor_time;
}
}
}
int64_t use_time = instruction_schedule.at(hlo_use.instruction);
bool allow_no_copy_alternate_mem_allocation = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,55 @@ ENTRY entry {
op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, p0));
}

// Checks that sync-copy replacement honors control dependencies: p0_copy is
// listed as a control predecessor of negate9, so once the sync copy feeding
// add0 is replaced by an async copy, the instruction that add0 consumes in its
// place must still appear in the schedule before negate9.
TEST_F(MemorySpaceAssignmentTest, SyncCopyReplacementWithControlPredecessor) {
  absl::string_view module_text = R"(
HloModule module, is_scheduled=true

ENTRY entry {
p0 = f32[2,3]{1,0} parameter(0)
p1 = f32[2,3]{1,0} parameter(1)
negate0 = f32[2,3]{1,0} negate(p1)
negate1 = f32[2,3]{1,0} negate(negate0)
negate2 = f32[2,3]{1,0} negate(negate1)
negate3 = f32[2,3]{1,0} negate(negate2)
negate4 = f32[2,3]{1,0} negate(negate3)
negate5 = f32[2,3]{1,0} negate(negate4)
negate6 = f32[2,3]{1,0} negate(negate5)
negate7 = f32[2,3]{1,0} negate(negate6)
p0_copy = f32[2,3]{1,0} copy(p0)
negate8 = f32[2,3]{1,0} negate(negate7)
negate9 = f32[2,3]{1,0} negate(negate8), control-predecessors={p0_copy}
ROOT add0 = f32[2,3]{1,0} add(p0_copy, negate9)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(module_text));

  // Run MSA with sync copy replacement switched on.
  Options msa_options = DefaultMemorySpaceOptions();
  msa_options.enable_sync_copy_replacement = true;
  AssignMemorySpace(module.get(), msa_options);

  HloInstruction* root_add = FindInstruction(module.get(), "add0");
  ASSERT_NE(root_add, nullptr);
  HloInstruction* param0 = FindInstruction(module.get(), "p0");
  ASSERT_NE(param0, nullptr);
  // add0's first operand must now match an async copy of p0.
  EXPECT_THAT(
      root_add->operand(0),
      op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, param0));

  HloInstruction* guarded_negate = FindInstruction(module.get(), "negate9");
  ASSERT_NE(guarded_negate, nullptr);
  const HloInstruction* async_copy_done = root_add->operand(0);

  const HloInstructionSequence& entry_sequence =
      module->schedule().sequence(module->entry_computation());
  const auto& scheduled = entry_sequence.instructions();
  // Returns the position of `target` in the entry schedule, or the schedule
  // length when `target` is not present.
  auto position_of = [&scheduled](const HloInstruction* target) -> int64_t {
    int64_t position = 0;
    for (const HloInstruction* candidate : scheduled) {
      if (candidate == target) {
        break;
      }
      ++position;
    }
    return position;
  };

  const int64_t copy_done_time = position_of(async_copy_done);
  const int64_t negate9_time = position_of(guarded_negate);
  // The control dependency requires the copy done to come strictly before
  // negate9 in the schedule.
  EXPECT_LT(copy_done_time, negate9_time);
}

// This is a case where p0_copy uses and p0 uses after copy(p0) are not
// allowed to use the same async CopyAllocation. While p0 can be prefetched at
// p0_copy, but we may clobber the data if we use the same async copy used for
Expand Down

0 comments on commit 4528217

Please sign in to comment.