Support PyTorch CUDACachingAllocator #12

Merged · 2 commits · Sep 12, 2024

13 changes: 13 additions & 0 deletions torch_xla/csrc/runtime/BUILD
@@ -204,6 +204,7 @@ cc_library(
":profiler",
":sys_util",
":tf_logging",
":torch_allocator",
":xla_coordinator",
"@xla//xla/service:gpu_plugin",
"@xla//xla/pjrt/gpu:se_gpu_pjrt_client",
@@ -371,6 +372,18 @@ cc_library(
],
)

cc_library(
name = "torch_allocator",
srcs = ["torch_allocator.cc"],
hdrs = ["torch_allocator.h"],
deps = [
":tf_logging",
"@tsl//tsl/framework:allocator",
"@torch//:headers",
"@xla//xla/stream_executor/gpu:gpu_types_header",
],
)

cc_library(
name = "tensor_source",
hdrs = ["tensor_source.h"],
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/env_vars.cc
@@ -24,6 +24,7 @@ const char* const kEnvPjrtAllocatorCudaAsync = "PJRT_ALLOCATOR_CUDA_ASYNC";
const char* const kEnvPjrtAllocatorPreallocate = "PJRT_ALLOCATOR_PREALLOCATE";
const char* const kEnvPjrtAllocatorFraction = "PJRT_ALLOCATOR_FRACTION";
const char* const kEnvPjrtDynamicPlugins = "PJRT_DYNAMIC_PLUGINS";
const char* const kEnvPjrtUseTorchAllocator = "PJRT_USE_TORCH_ALLOCATOR";

} // namespace env
} // namespace runtime
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/env_vars.h
@@ -34,6 +34,7 @@ extern const char* const kEnvPjrtAllocatorCudaAsync;
extern const char* const kEnvPjrtAllocatorPreallocate;
extern const char* const kEnvPjrtAllocatorFraction;
extern const char* const kEnvPjrtDynamicPlugins;
extern const char* const kEnvPjrtUseTorchAllocator;

} // namespace env
} // namespace runtime
128 changes: 127 additions & 1 deletion torch_xla/csrc/runtime/pjrt_registry.cc
@@ -5,6 +5,7 @@
#include "torch_xla/csrc/runtime/profiler.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/torch_allocator.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/distributed/client.h"
@@ -14,6 +15,14 @@
#include "xla/pjrt/pjrt_api.h"
#include "xla/pjrt/pjrt_c_api_client.h"
#include "xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/integrations/device_mem_allocator.h"
#include "xla/stream_executor/integrations/tf_allocator_adapter.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"

namespace torch_xla {
namespace runtime {
@@ -54,6 +63,117 @@ void RegisterPjRtPlugin(std::string name,
pjrt_plugins_[name] = plugin;
}

// Copied from openxla's
// xla/pjrt/gpu/se_gpu_pjrt_client.cc::BuildLocalDeviceStates
absl::StatusOr<std::map<int, std::unique_ptr<xla::LocalDeviceState>>>
BuildLocalDeviceStates(xla::LocalClient* xla_client) {
std::map<int, std::unique_ptr<xla::LocalDeviceState>> addressable_devices;
for (stream_executor::StreamExecutor* executor :
xla_client->backend().stream_executors()) {
addressable_devices.emplace(
executor->device_ordinal(),
std::make_unique<xla::LocalDeviceState>(
executor, xla_client, xla::LocalDeviceState::kComputeSynchronized,
/*max_inflight_computations=*/32,
/*allow_event_reuse=*/true, /*use_callback_stream=*/true));
}
return std::move(addressable_devices);
}

// Modified from openxla's
// xla/pjrt/gpu/se_gpu_pjrt_client.cc::GetStreamExecutorGpuDeviceAllocator,
// changed to use the PyTorch CUDACachingAllocator.
absl::StatusOr<std::unique_ptr<stream_executor::DeviceMemoryAllocator>>
GetTorchAllocator(stream_executor::Platform* platform,
const xla::GpuAllocatorConfig& allocator_config,
const std::map<int, std::unique_ptr<xla::LocalDeviceState>>&
addressable_devices) {
std::vector<stream_executor::MultiDeviceAdapter::AllocatorInfo> allocators;
LOG(INFO) << "Using PyTorch CUDACachingAllocator.";
for (const auto& ordinal_and_device : addressable_devices) {
stream_executor::StreamExecutor* executor =
ordinal_and_device.second->executor();
int device_ordinal = executor->device_ordinal();
auto allocator =
std::make_unique<TorchCUDACachingAllocator>(device_ordinal);
allocator->SetStreamAndPreallocateMemory(
ordinal_and_device.second->compute_stream()
->platform_specific_handle()
.stream);
allocators.emplace_back(std::move(allocator),
ordinal_and_device.second->compute_stream(),
/*memory_space=*/0);
}

// Add any additional allocators for alternate memory spaces.
for (const auto& ordinal_and_device : addressable_devices) {
TF_ASSIGN_OR_RETURN(
auto collective_bfc_allocator,
xla::CreateCollectiveBFCAllocator(
ordinal_and_device.second->executor(),
/*memory_fraction=*/1.0 - allocator_config.memory_fraction,
allocator_config.collective_memory_size));
allocators.emplace_back(std::move(collective_bfc_allocator),
ordinal_and_device.second->compute_stream(),
/*memory_space=*/1);
}

for (const auto& ordinal_and_device : addressable_devices) {
auto host_allocator =
xla::GetGpuHostAllocator(ordinal_and_device.second->executor());
allocators.emplace_back(
std::move(host_allocator), ordinal_and_device.second->compute_stream(),
/*memory_space=*/
static_cast<int>(stream_executor::MemoryType::kHost));
}

return std::make_unique<stream_executor::MultiDeviceAdapter>(
platform, std::move(allocators));
}

// Modified from xla::GetStreamExecutorGpuClient, changed to use the
// PyTorch CUDACachingAllocator.
absl::StatusOr<std::unique_ptr<xla::PjRtClient>>
GetPjRtClientWithTorchAllocator(const xla::GpuClientOptions& options) {
auto pjrt_platform_name = xla::CudaName();

TF_ASSIGN_OR_RETURN(
xla::LocalClient * xla_client,
xla::GetGpuXlaClient(options.platform_name, options.allowed_devices));
std::map<int, std::unique_ptr<xla::LocalDeviceState>> local_device_states;
TF_ASSIGN_OR_RETURN(local_device_states, BuildLocalDeviceStates(xla_client));
xla::EnablePeerAccess(xla_client->backend().stream_executors());

TF_ASSIGN_OR_RETURN(
auto allocator,
GetTorchAllocator(xla_client->platform(), options.allocator_config,
local_device_states));

auto host_memory_allocator =
xla::GetGpuHostAllocator(local_device_states.begin()->second->executor());

std::vector<std::unique_ptr<xla::PjRtStreamExecutorDevice>> devices;
auto gpu_run_options = std::make_unique<xla::gpu::GpuExecutableRunOptions>();
if (options.enable_mock_nccl) {
gpu_run_options->set_enable_mock_nccl_collectives();
}
std::shared_ptr<xla::KeyValueStoreInterface> kv_store = options.kv_store;
if (options.enable_mock_nccl) {
kv_store = std::make_shared<xla::InMemoryKeyValueStore>();
}
TF_RET_CHECK(options.num_nodes == 1 || kv_store != nullptr);
TF_RETURN_IF_ERROR(xla::BuildDistributedDevices(
pjrt_platform_name, std::move(local_device_states), options.node_id,
options.num_nodes, &devices, gpu_run_options.get(), kv_store,
options.enable_mock_nccl));

return std::unique_ptr<xla::PjRtClient>(
std::make_unique<xla::StreamExecutorGpuClient>(
pjrt_platform_name, xla_client, std::move(devices), options.node_id,
std::move(allocator), std::move(host_memory_allocator),
options.should_stage_host_to_device_transfers,
std::move(gpu_run_options)));
}

std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
InitializePjRt(const std::string& device_type) {
std::unique_ptr<xla::PjRtClient> client;
@@ -167,7 +287,13 @@ InitializePjRt(const std::string& device_type) {
options.platform_name = "gpu";
options.should_stage_host_to_device_transfers = true;
options.kv_store = kv_store;
client = std::move(xla::GetStreamExecutorGpuClient(options).value());
bool use_torch_allocator =
sys_util::GetEnvBool(env::kEnvPjrtUseTorchAllocator, false);
if (use_torch_allocator) {
client = std::move(GetPjRtClientWithTorchAllocator(options).value());
} else {
client = std::move(xla::GetStreamExecutorGpuClient(options).value());
}
} else if (device_type == "XPU") {
TF_VLOG(1) << "Initializing PjRt XPU client...";
XLA_CHECK_OK(
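Not part of the diff — a minimal usage sketch of the new flag, assuming a CUDA build of torch_xla that includes this change. Only `PJRT_USE_TORCH_ALLOCATOR` comes from this PR; the `PJRT_DEVICE` setting and the torch_xla calls are ordinary usage shown here just for illustration.

```python
# Minimal usage sketch (not part of this PR): route XLA:CUDA device memory
# through PyTorch's CUDACachingAllocator via the env var added above.
import os

# Must be set before the PJRT client is initialized.
os.environ["PJRT_USE_TORCH_ALLOCATOR"] = "1"
os.environ["PJRT_DEVICE"] = "CUDA"

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(1024, 1024, device=device)
y = (x @ x).sum()
xm.mark_step()  # materialize; device buffers now come from the caching allocator
print(y.item())
```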
48 changes: 48 additions & 0 deletions torch_xla/csrc/runtime/torch_allocator.cc
@@ -0,0 +1,48 @@
#include "torch_xla/csrc/runtime/torch_allocator.h"

#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>

#include "torch_xla/csrc/runtime/tf_logging.h"

namespace torch_xla {
namespace runtime {

TorchCUDACachingAllocator::TorchCUDACachingAllocator(int device_ordinal) {
VLOG(3) << "Creating TorchCUDACachingAllocator for device " << device_ordinal;
name_ = c10::cuda::CUDACachingAllocator::name();
cuda_stream_ = nullptr;
device_index_ = static_cast<c10::DeviceIndex>(device_ordinal);
}

void* TorchCUDACachingAllocator::AllocateRaw(size_t alignment,
size_t num_bytes) {
CHECK(cuda_stream_ != nullptr)
<< "A stream must be added to the TorchCUDACachingAllocator allocator";
if (num_bytes == 0) {
return nullptr;
}
at::cuda::CUDAGuard device_guard{device_index_};
auto ptr = c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(
num_bytes, cuda_stream_);
VLOG(3) << "Alloc num_bytes " << num_bytes << " with ptr " << ptr
<< " for device " << static_cast<int>(device_index_);
return ptr;
}

void TorchCUDACachingAllocator::DeallocateRaw(void* ptr) {
VLOG(3) << "Dealloc ptr " << ptr << " for device "
<< static_cast<int>(device_index_);
c10::cuda::CUDACachingAllocator::raw_delete(ptr);
}

void TorchCUDACachingAllocator::SetStreamAndPreallocateMemory(void* stream) {
auto new_cuda_stream = static_cast<cudaStream_t>(stream);
VLOG(3) << "Setting cuda stream " << stream
<< " for TorchCUDACachingAllocator on device "
<< static_cast<int>(device_index_);
cuda_stream_ = new_cuda_stream;
}

} // namespace runtime
} // namespace torch_xla
37 changes: 37 additions & 0 deletions torch_xla/csrc/runtime/torch_allocator.h
@@ -0,0 +1,37 @@
#ifndef XLA_CLIENT_TORCH_ALLOCATOR_H_
#define XLA_CLIENT_TORCH_ALLOCATOR_H_

#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime_api.h>

#include "tsl/framework/allocator.h"

namespace torch_xla {
namespace runtime {

class TorchCUDACachingAllocator : public tsl::Allocator {
public:
TorchCUDACachingAllocator(int device_ordinal);
~TorchCUDACachingAllocator() override {}

std::string Name() override { return name_; }

void* AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void* ptr) override;

void SetStreamAndPreallocateMemory(void* stream) override;

tsl::AllocatorMemoryType GetMemoryType() const override {
return tsl::AllocatorMemoryType::kDevice;
}

private:
std::string name_;
cudaStream_t cuda_stream_;
c10::DeviceIndex device_index_;
};

} // namespace runtime
} // namespace torch_xla

#endif // XLA_CLIENT_TORCH_ALLOCATOR_H_
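
A rough way to observe the effect (my assumption, not something stated in the PR): because XLA's GPU buffers are now served by c10::cuda::CUDACachingAllocator, they should appear in torch.cuda memory statistics alongside ordinary PyTorch allocations (single-GPU case assumed).

```python
# Verification sketch (assumption, not from the PR): with the flag enabled,
# memory reserved by the shared caching allocator should grow when an XLA
# tensor is materialized on the GPU.
import os
os.environ["PJRT_USE_TORCH_ALLOCATOR"] = "1"
os.environ["PJRT_DEVICE"] = "CUDA"

import torch
import torch_xla.core.xla_model as xm

before = torch.cuda.memory_reserved()
t = torch.randn(4096, 4096, device=xm.xla_device())
xm.mark_step()
after = torch.cuda.memory_reserved()
print(f"caching-allocator reserved memory grew by {after - before} bytes")
```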