Skip to content

Commit

Permalink
[XLA] GPU Memory: Improve memory limit handling and shape size calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenying-liu committed Feb 28, 2025
1 parent 33a054b commit 0eb5512
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 32 deletions.
6 changes: 6 additions & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2189,11 +2189,16 @@ xla_test(
"//xla:shape_util",
"//xla/hlo/analysis:hlo_ordering",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass_pipeline",
"//xla/hlo/testlib:filecheck",
"//xla/hlo/testlib:verified_hlo_module",
"//xla/hlo/utils:hlo_query",
"//xla/service:backend",
"//xla/service:hlo_module_config",
"//xla/service:legalize_scheduling_annotations",
"//xla/service/gpu:gpu_latency_hiding_scheduler",
"//xla/service/gpu/transforms:schedule_postprocessing",
"//xla/service/gpu/transforms:scheduling_instruction_annotator",
"//xla/stream_executor:device_description",
"//xla/tests:hlo_test_base",
"//xla/tests:test_utils",
Expand Down Expand Up @@ -3052,6 +3057,7 @@ cc_library(
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_query",
"//xla/service:hlo_cost_analysis",
"//xla/service:collective_ops_utils",
"//xla/service:collective_permute_decomposer",
"//xla/service:latency_hiding_scheduler",
Expand Down
4 changes: 1 addition & 3 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2496,9 +2496,7 @@ GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,

HloCostAnalysis::ShapeSizeFunction GpuCompiler::ShapeSizeBytesFunction() const {
// Capture just the pointer size, not the entire GpuCompiler object.
return [pointer_size = pointer_size_](const Shape& shape) {
return GetSizeOfShape(shape, pointer_size);
};
return gpu::ShapeSizeBytesFunction(pointer_size_);
}

absl::StatusOr<std::unique_ptr<AotCompilationResult>> GpuCompiler::Export(
Expand Down
51 changes: 28 additions & 23 deletions xla/service/gpu/gpu_hlo_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ HloInstructionSequence PostprocessorToScheduleSyncCollectives(
return result;
}

SchedulerConfig MakeGPUSchedulerConfig(int64_t memory_limit,
SchedulerConfig MakeGPUSchedulerConfig(uint64_t memory_limit,
int64_t overlap_limit) {
SchedulerConfig config;
config.all_reduce_overlap_limit = 1;
Expand Down Expand Up @@ -510,19 +510,15 @@ std::unique_ptr<LatencyEstimator> GetLatencyEstimator(
LOG(INFO) << "Using analytical latency estimator";
return std::make_unique<AnalyticalLatencyEstimator>(
config, std::move(gpu_latency_estimator), gpu_device_info,
[input_pointer_size = pointer_size](const Shape& shape) {
return GetSizeOfShape(shape, input_pointer_size);
},
ShapeSizeBytesFunction(pointer_size),
module.entry_computation());
}

if (options.xla_gpu_enable_analytical_sol_latency_estimator()) {
LOG(INFO) << "Using Speed-of-Light (SoL) analytical latency estimator";
return std::make_unique<SolLatencyEstimator>(
config, std::move(gpu_latency_estimator), gpu_device_info,
[input_pointer_size = pointer_size](const Shape& shape) {
return GetSizeOfShape(shape, input_pointer_size);
},
ShapeSizeBytesFunction(pointer_size),
module.entry_computation());
}
return gpu_latency_estimator;
Expand Down Expand Up @@ -569,7 +565,7 @@ LegalizeSchedulingAnnotations::Config SchedulingAnnotationsConfig() {
// `pipeline`.
absl::Status RunLatencyHidingSchedulerPasses(
HloModule* module, int pointer_size, absl::string_view fingerprint,
int64_t memory_limit, const se::DeviceDescription& gpu_device_info) {
uint64_t memory_limit, const se::DeviceDescription& gpu_device_info) {
HloPassPipeline pipeline("latency-hiding-scheduler");
const DebugOptions& options = module->config().debug_options();
pipeline.AddPass<LegalizeSchedulingAnnotations>(
Expand All @@ -579,9 +575,7 @@ absl::Status RunLatencyHidingSchedulerPasses(
memory_limit,
options.xla_gpu_experimental_parallel_collective_overlap_limit());

auto shape_size_in_bytes = [pointer_size](const Shape& shape) {
return GetSizeOfShape(shape, pointer_size);
};
auto shape_size_in_bytes = ShapeSizeBytesFunction(pointer_size);

std::unique_ptr<LatencyEstimator> estimator = GetLatencyEstimator(
*module, pointer_size, gpu_device_info, fingerprint, config);
Expand Down Expand Up @@ -611,9 +605,9 @@ absl::Status RunLatencyHidingSchedulerPasses(

// Compute the device memory limit to be used by passes like scheduler and
// HLO rematerialization.
int64_t GetSchedulerMemoryLimit(const HloModule& module,
const se::DeviceDescription& gpu_device_info,
int pointer_size) {
uint64_t GetSchedulerMemoryLimit(const HloModule& module,
const se::DeviceDescription& gpu_device_info,
int pointer_size) {
// There is a "base" value which is either specified in HloModuleConfig
// (this value should take into account the fact that we need to leave some
// memory free for allocations that happen outside of XLA's allocator) or
Expand All @@ -622,25 +616,29 @@ int64_t GetSchedulerMemoryLimit(const HloModule& module,
//
// From that base value, subtract any input and output sizes (assuming they
// are live throughout the execution) and then apply a slop factor.
const int64_t base_limit =
const uint64_t base_limit =
module.config().device_memory_size() != 0
? module.config().device_memory_size()
: gpu_device_info.device_memory_size() * 80 / 100;

// Create size function that only counts device memory
auto get_device_shape_size = gpu::ShapeSizeBytesFunction(pointer_size,
/*memory_space=*/Layout::kDefaultMemorySpace);

// Find the total size of inputs and outputs.
int64_t total_io_size = 0;
uint64_t total_io_size = 0;
for (HloInstruction* param :
module.entry_computation()->parameter_instructions()) {
ShapeUtil::ForEachSubshape(
param->shape(),
[&](const Shape& subshape, const ShapeIndex& /*index*/) {
total_io_size += GetSizeOfShape(subshape, pointer_size);
total_io_size += get_device_shape_size(subshape);
});
}
ShapeUtil::ForEachSubshape(
module.result_shape(),
[&](const Shape& subshape, const ShapeIndex& /*index*/) {
total_io_size += GetSizeOfShape(subshape, pointer_size);
total_io_size += get_device_shape_size(subshape);
});

// If any inputs and outputs are aliased, do not double count them.
Expand All @@ -649,12 +647,19 @@ int64_t GetSchedulerMemoryLimit(const HloModule& module,
const HloInputOutputAliasConfig::Alias&) {
const Shape& subshape =
ShapeUtil::GetSubshape(module.result_shape(), output_index);
total_io_size -= GetSizeOfShape(subshape, pointer_size);
total_io_size -= get_device_shape_size(subshape);
});

int64_t limit =
(base_limit - total_io_size) *
module.config().debug_options().xla_gpu_memory_limit_slop_factor() / 100;
uint64_t limit = 0;
if (total_io_size > base_limit) {
LOG(ERROR) << "The byte size of input/output arguments (" << total_io_size
<< ") exceeds the base limit (" << base_limit
<< "). This indicates an error in the calculation!";
} else {
limit = (base_limit - total_io_size) *
module.config().debug_options().xla_gpu_memory_limit_slop_factor() /
100;
}
return limit;
}

Expand Down Expand Up @@ -747,7 +752,7 @@ absl::StatusOr<ScheduleMetadata> ScheduleGpuModule(
// Tag the module with its 128 bit fingerprint. The fingerprint should include
// instruction name with ids.
std::string fingerprint = TagWithFingerprint(module);
int64_t memory_limit =
uint64_t memory_limit =
GetSchedulerMemoryLimit(*module, gpu_device_info, pointer_size);

// Module already has a schedule, do nothing.
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/gpu_hlo_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace gpu {
absl::Status RunAsyncCollectivesConversionPasses(HloModule* module);

struct ScheduleMetadata {
int64_t scheduler_mem_limit;
uint64_t scheduler_mem_limit;
};

// Determines the schedule of HLO instructions for a module run on the GPU.
Expand Down
126 changes: 126 additions & 0 deletions xla/service/gpu/gpu_hlo_schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,18 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/hlo/pass/hlo_pass_interface.h"
#include "xla/hlo/testlib/filecheck.h"
#include "xla/hlo/testlib/verified_hlo_module.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/service/backend.h"
#include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
#include "xla/service/gpu/gpu_compiler.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/latency_hiding_scheduler.h"
#include "xla/service/legalize_scheduling_annotations.h"
#include "xla/service/gpu/transforms/schedule_postprocessing.h"
#include "xla/service/gpu/transforms/scheduling_instruction_annotator.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_description.h"
Expand All @@ -57,6 +63,7 @@ namespace xla {
namespace gpu {

using ::testing::ElementsAre;
using ::testing::HasSubstr;
using ::tsl::testing::StatusIs;

class GpuHloScheduleTest : public HloTestBase {
Expand Down Expand Up @@ -110,6 +117,53 @@ class GpuHloScheduleTest : public HloTestBase {
// The fingerprint is 128 bits stored as a hex string (128/4 hex digits).
return it != attrs.map().end() && it->second.size() == 128 / 4;
}

// Run the gpu hlo scheduler and latency hiding scheduler
// Runs the memory-bound GPU module scheduler followed by the latency hiding
// scheduler (LHS) on `module`, honoring `memory_limit` (in bytes).
// Returns whether LHS changed the module, or an Internal error when
// `memory_limit` is zero (meaning the I/O footprint already consumed the
// whole device budget). The limit is unsigned, so it can never be negative.
absl::StatusOr<bool> RunGpuLatencyHidingScheduler(HloModule* module,
                                                  uint64_t memory_limit) {
  // With no memory headroom there is nothing to schedule within; report the
  // error before doing any scheduling work. NOTE: this message is matched
  // verbatim by the NegativeTestMemoryLimit test below.
  if (memory_limit == 0) {
    return Internal(
        "The byte size of input/output arguments exceeds the "
        "base limit. This indicates an error in the calculation!");
  }

  auto* gpu_compiler = dynamic_cast<GpuCompiler*>(backend().compiler());
  EXPECT_NE(gpu_compiler, nullptr);
  const int64_t pointer_size = gpu_compiler->GetPointerSize();

  auto shape_size_in_bytes = ShapeSizeBytesFunction(pointer_size);

  // Produce an initial memory-aware schedule that LHS will then refine.
  int64_t initial_peak_memory = -1;
  TF_ASSIGN_OR_RETURN(HloSchedule initial_schedule,
                      ScheduleGpuModuleWithMemoryScheduler(
                          module, pointer_size, &initial_peak_memory));
  TF_CHECK_OK(module->set_schedule(std::move(initial_schedule)));

  SchedulerConfig config;
  config.memory_limit = memory_limit;

  auto estimator = std::make_unique<ApproximateLatencyEstimator>();
  auto async_tracker = std::make_unique<GpuAsyncTracker>(config);
  // DefaultSchedulerCore keeps non-owning pointers to the tracker and the
  // estimator; the pipeline below takes ownership of all three objects, so
  // they stay alive for the duration of the run.
  auto tracker_ptr = async_tracker.get();
  auto scheduler_core = std::make_unique<DefaultSchedulerCore>(
      shape_size_in_bytes, tracker_ptr, estimator.get(), config,
      /*target_scheduling_rule=*/nullptr,
      /*early_target_scheduling_rule=*/nullptr,
      /*post_processing_fn=*/nullptr,
      /*scheduling_instruction_crosses_overlap_limit=*/
      GpuScheduleCrossesOverlapLimit);

  HloPassPipeline pipeline("latency-hiding-scheduler");
  pipeline.AddPass<LatencyHidingScheduler>(
      std::move(estimator), std::move(async_tracker),
      std::move(scheduler_core), shape_size_in_bytes);
  return pipeline.Run(module);
}
};

// Test of a single stream, where data dependencies fully determine the
Expand Down Expand Up @@ -1742,5 +1796,77 @@ TEST_F(GpuHloScheduleTest, DiscountCPUMemoryFromGPUPeakMemoryUsage) {
)"));
}

// Textual HLO with copy-start/copy-done pairs moving f32[1024] buffers to
// and from memory space S(5), interleaved with elementwise compute.
// NOTE(review): S(5) presumably denotes the host/offload memory space —
// confirm against the memory-space constants used by the GPU backend.
constexpr absl::string_view kCopyStartOverlap = R"(
HloModule conv_offloading
ENTRY %main (param_0: f32[1024], param_1: f32[1024]) -> f32[1024] {
%param_1 = f32[1024]{0} parameter(1)
%param_0 = f32[1024]{0} parameter(0)
%res_3 = f32[1024]{0} add(f32[1024]{0} %param_0, f32[1024]{0} %param_1)
%copy-start = (f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) copy-start(f32[1024]{0} %res_3)
%copy-done = f32[1024]{0:S(5)} copy-done((f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) %copy-start)
%res_4 = f32[1024]{0} tanh(f32[1024]{0} %res_3)
%copy-start.2 = (f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) copy-start(f32[1024]{0} %res_4)
%copy-done.2 = f32[1024]{0:S(5)} copy-done((f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) %copy-start.2)
%res_5 = f32[1024]{0} tanh(f32[1024]{0} %res_4)
%res_6 = f32[1024]{0} tanh(f32[1024]{0} %res_5)
%res_7 = f32[1024]{0} add(f32[1024]{0} %res_6, f32[1024]{0} %res_6)
%copy-start.1 = (f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) copy-start(f32[1024]{0:S(5)} %copy-done)
%copy-done.1 = f32[1024]{0} copy-done((f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) %copy-start.1)
%res_8 = f32[1024]{0} add(f32[1024]{0} %res_7, f32[1024]{0} %res_5)
%copy-start.3 = (f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) copy-start(f32[1024]{0:S(5)} %copy-done.2)
%copy-done.3 = f32[1024]{0} copy-done((f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) %copy-start.3)
%res_9 = f32[1024]{0} add(f32[1024]{0} %res_8, f32[1024]{0} %copy-done.3)
%res_10 = f32[1024]{0} add(f32[1024]{0} %res_9, f32[1024]{0} %copy-done.1)
ROOT %res_11 = f32[1024]{0} tanh(f32[1024]{0} %res_10)
})";

// This test ensures that the GPU scheduler applies latency hiding scheduling
// while adhering to a specified memory limit.
TEST_F(GpuHloScheduleTest, RunLHSToBeWithinMemoryLimit) {
TF_ASSERT_OK_AND_ASSIGN(
auto module,
ParseAndReturnVerifiedModule(kCopyStartOverlap, GetModuleConfig({})));

// Define a large memory limit for the scheduler.
// NOTE(review): 22000 bytes presumably leaves headroom above this module's
// working set (f32[1024] buffers) — confirm if the module changes.
constexpr uint64_t kMemoryLimitLarge = 22000;

// Run the latency hiding scheduler with the specified memory limit.
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunGpuLatencyHidingScheduler(
module.get(), kMemoryLimitLarge));

// The scheduler is expected to reorder instructions, i.e. change the module.
EXPECT_TRUE(changed);

// Verify that each copy-start is hoisted ahead of its matching copy-done so
// the copies can overlap with the intervening compute.
EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
// CHECK: ENTRY
// CHECK: %copy-start.2 = (f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) copy-start
// CHECK: %copy-start = (f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) copy-start
// CHECK: %copy-done.2 = f32[1024]{0:S(5)} copy-done
// CHECK: %copy-start.3 = (f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) copy-start
// CHECK: %copy-done = f32[1024]{0:S(5)} copy-done
// CHECK: %copy-start.1 = (f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) copy-start
// CHECK: %copy-done.3 = f32[1024]{0} copy-done
// CHECK: %copy-done.1 = f32[1024]{0} copy-done
)"));
}

// This test verifies that the GPU scheduler doesn't run latency hiding
// scheduling if the given (unsigned) memory limit is zero.
TEST_F(GpuHloScheduleTest, NegativeTestMemoryLimit) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      ParseAndReturnVerifiedModule(kCopyStartOverlap, GetModuleConfig({})));

  // The memory limit is unsigned, so it can never actually be negative;
  // zero models the case where the I/O byte size exceeded the base limit.
  constexpr uint64_t kMemoryLimitZero = 0;

  // Running the latency hiding scheduler with no memory headroom must fail.
  auto status =
      RunGpuLatencyHidingScheduler(module.get(), kMemoryLimitZero).status();
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(
      status.message(),
      HasSubstr("The byte size of input/output arguments exceeds the "
                "base limit. This indicates an error in the calculation!"));
}

} // namespace gpu
} // namespace xla
12 changes: 10 additions & 2 deletions xla/service/gpu/gpu_latency_hiding_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,22 @@ size_t CountOverlappingRanks(const std::vector<ReplicaGroup>& group,

} // namespace

int64_t GetSizeOfShape(const Shape& shape, int pointer_size) {
// Builds a shape-size callback parameterized by `pointer_size` and an
// optional `memory_space` filter. When a memory space is given, shapes whose
// layout lives in a different space are counted as zero bytes; dynamic
// (non-tuple, non-static) shapes additionally pay one S32 of metadata per
// dimension on top of their byte size.
HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction(
    int64_t pointer_size, std::optional<int64_t> memory_space) {
  return [pointer_size, memory_space](const Shape& shape) -> int64_t {
    // Shapes outside the requested memory space contribute nothing.
    const bool space_mismatch =
        memory_space.has_value() && shape.has_layout() &&
        shape.layout().memory_space() != *memory_space;
    if (space_mismatch) {
      return 0;
    }
    const int64_t base_size = ShapeUtil::ByteSizeOf(shape, pointer_size);
    if (!shape.IsTuple() && !shape.is_static()) {
      // Each dynamic dimension size is stored alongside the data as an S32.
      return base_size +
             static_cast<int64_t>(sizeof(int32_t)) * shape.dimensions_size();
    }
    return base_size;
  };
}

CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo) {
Expand Down Expand Up @@ -540,7 +548,7 @@ ApproximateLatencyEstimator::TimeCost GpuLatencyEstimator::GetLatencyBetween(
.xla_gpu_enable_approx_costly_collectives();
bool is_all_reduce = from.GetInstr().opcode() == HloOpcode::kAllReduceStart;
bool collective_size_exceeds_threshold =
GetSizeOfShape(from.GetInstr().shape(), pointer_size_) >
ShapeSizeBytesFunction(pointer_size_)(from.GetInstr().shape()) >
kCostlyAllReduceThreshold;
if (enable_approx_collectives && is_all_reduce &&
collective_size_exceeds_threshold) {
Expand Down
7 changes: 5 additions & 2 deletions xla/service/gpu/gpu_latency_hiding_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.

#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/service/latency_hiding_scheduler.h"
#include "xla/service/profile_guided_latency_estimator.h"
#include "xla/shape.h"
Expand All @@ -31,8 +32,10 @@ namespace gpu {
// E.g. AllReduceStart is broken down into Reduce + AsyncStart.
CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo);

// Returns size of the `shape` given the `pointer_size`.
int64_t GetSizeOfShape(const Shape& shape, int pointer_size);
// The shape size function depending on the pointer size and
// memory space.
HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction(
int64_t pointer_size, std::optional<int64_t> memory_space = std::nullopt);

// GPU overlap limit rule rule for scheduling candidate.
// On top of the default rule, we do not allow collectives with more than 1
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/nvptx_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class NVPTXCompilerTest : public HloTestBase {

auto buffer_size_bytes_function =
[](const BufferValue& buffer_value) -> int64_t {
return GetSizeOfShape(buffer_value.shape(), pointer_size);
return ShapeSizeBytesFunction(pointer_size)(buffer_value.shape());
};

return BufferAssigner::Run(
Expand Down

0 comments on commit 0eb5512

Please sign in to comment.