Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:MSA] Fixes an issue that leads to sync mem op replacements not respecting the auxiliary control dependencies while converting sync mem ops to async #23276

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions xla/service/memory_space_assignment/algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2857,6 +2857,20 @@ AllocationRequest MsaAlgorithm::CreateAllocationRequest(
}
required_copy_allocation_latest_time =
std::min(earliest_use_time, earliest_position_time);
// We need to make sure that the copy allocation is scheduled before the
// controlled successor of the sync mem op.
for (const HloInstruction* control_successor :
required_copy_allocation_for->control_successors()) {
int64_t successor_time = instruction_schedule.at(control_successor);
if (successor_time < required_copy_allocation_latest_time) {
VLOG(3) << "Updating the required replacement async mem op allocation "
"latest time from "
<< required_copy_allocation_latest_time << " to "
<< successor_time << ", because of control successor "
<< control_successor->ToString();
required_copy_allocation_latest_time = successor_time;
}
}
}
int64_t use_time = instruction_schedule.at(hlo_use.instruction);
bool allow_no_copy_alternate_mem_allocation = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,55 @@ ENTRY entry {
op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, p0));
}

TEST_F(MemorySpaceAssignmentTest, SyncCopyReplacementWithControlPredecessor) {
absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true

ENTRY entry {
p0 = f32[2,3]{1,0} parameter(0)
p1 = f32[2,3]{1,0} parameter(1)
negate0 = f32[2,3]{1,0} negate(p1)
negate1 = f32[2,3]{1,0} negate(negate0)
negate2 = f32[2,3]{1,0} negate(negate1)
negate3 = f32[2,3]{1,0} negate(negate2)
negate4 = f32[2,3]{1,0} negate(negate3)
negate5 = f32[2,3]{1,0} negate(negate4)
negate6 = f32[2,3]{1,0} negate(negate5)
negate7 = f32[2,3]{1,0} negate(negate6)
p0_copy = f32[2,3]{1,0} copy(p0)
negate8 = f32[2,3]{1,0} negate(negate7)
negate9 = f32[2,3]{1,0} negate(negate8), control-predecessors={p0_copy}
ROOT add0 = f32[2,3]{1,0} add(p0_copy, negate9)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
Options options = DefaultMemorySpaceOptions();
options.enable_sync_copy_replacement = true;
AssignMemorySpace(module.get(), options);
HloInstruction* add0 = FindInstruction(module.get(), "add0");
ASSERT_NE(add0, nullptr);
HloInstruction* p0 = FindInstruction(module.get(), "p0");
ASSERT_NE(p0, nullptr);
EXPECT_THAT(add0->operand(0),
op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, p0));
HloInstruction* negate9 = FindInstruction(module.get(), "negate9");
ASSERT_NE(negate9, nullptr);
const HloInstruction* copy_done = add0->operand(0);
const HloInstructionSequence& sequence =
module->schedule().sequence(module->entry_computation());
auto find_index = [&](const HloInstruction* instruction) {
return std::distance(sequence.instructions().begin(),
std::find(sequence.instructions().begin(),
sequence.instructions().end(), instruction));
};
int64_t copy_done_time = find_index(copy_done);
int64_t negate9_time = find_index(negate9);
// The negate9 instruction should be scheduled after the copy done, because of
// the control dependency constraint.
EXPECT_GT(negate9_time, copy_done_time);
}

// This is a case where we p0_copy uses and and p0 uses after copy(p0) are not
// allowed to use the same async CopyAllocation. While p0 can be prefetched at
// p0_copy, but we may clobber the data if we use the same async copy used for
Expand Down