33 changes: 16 additions & 17 deletions paddle/fluid/pybind/tensor.cc
@@ -989,7 +989,7 @@ void BindTensor(pybind11::module &m) { // NOLINT
tensor dims, lod information, device index.

)DOC")
.def("_share_device_ipc",
.def("_share_cuda",
[](phi::DenseTensor self) {
if (!self.IsInitialized() || self.numel() == 0)
throw std::runtime_error(
@@ -1053,9 +1053,9 @@ void BindTensor(pybind11::module &m) { // NOLINT
>>> import paddle

>>> tensor = paddle.ones([3,3])
- >>> metainfo = tensor.value().get_tensor()._share_device_ipc()
+ >>> metainfo = tensor.value().get_tensor()._share_cuda()
)DOC")
.def("_new_from_ipc",
.def("_new_shared_cuda",
[](py::tuple t) {
if (FLAGS_use_virtual_memory_auto_growth && t.size() == 5) {
return RebuildTensorFromVmmMeta(t);
@@ -1102,8 +1102,8 @@ void BindTensor(pybind11::module &m) { // NOLINT
>>> import paddle

>>> tensor = paddle.ones([3,3])
- >>> metainfo = tensor.value().get_tensor()._share_device_ipc()
- >>> tensor_from_shared = paddle.to_tensor(paddle.base.core.DenseTensor._new_from_ipc(metainfo))
+ >>> metainfo = tensor.value().get_tensor()._share_cuda()
+ >>> tensor_from_shared = paddle.to_tensor(paddle.base.core.DenseTensor._new_shared_cuda(metainfo))
)DOC")
#endif
#ifdef PADDLE_WITH_XPU
@@ -1157,7 +1157,7 @@ void BindTensor(pybind11::module &m) { // NOLINT
tuple: contains data size, data type, tensor dims, lod
information, device index.
)DOC")
.def("_share_device_ipc",
.def("_share_xpu",
[](phi::DenseTensor &self) {
if (!self.IsInitialized() || self.numel() == 0)
throw std::runtime_error(
@@ -1167,7 +1167,7 @@ void BindTensor(pybind11::module &m) { // NOLINT
// Get the current device ID.
int dev_id = platform::GetXPUCurrentDeviceId();
paddle::platform::SetXPUDeviceId(dev_id);
- VLOG(6) << "[DEBUG XPU] _share_device_ipc: current XPU device = "
+ VLOG(6) << "[DEBUG XPU] _share_xpu: current XPU device = "
<< dev_id;

auto *holder = dynamic_cast<memory::allocation::Allocation *>(
@@ -1180,18 +1180,18 @@ void BindTensor(pybind11::module &m) { // NOLINT
void *base_ptr = holder->base_ptr();
ptrdiff_t offset_bytes = reinterpret_cast<char *>(holder->ptr()) -
reinterpret_cast<char *>(base_ptr);
- VLOG(6) << "[DEBUG XPU] _share_device_ipc: base_ptr = " << base_ptr
+ VLOG(6) << "[DEBUG XPU] _share_xpu: base_ptr = " << base_ptr
<< ", offset_bytes = " << offset_bytes;
cudaIpcMemHandle_t handle;
int ret = cudaIpcGetMemHandle(&handle, base_ptr);
- VLOG(6) << "[DEBUG XPU] _share_device_ipc: "
-        << "cudaIpcGetMemHandle returned: " << ret;
+ VLOG(6) << "[DEBUG XPU] _share_xpu: cudaIpcGetMemHandle returned: "
+        << ret;
PADDLE_ENFORCE_XPU_SUCCESS(ret);
// Use the correct size for the IPC handle.
auto _handle = py::bytes(
reinterpret_cast<char *>(&handle),
(py::ssize_t)sizeof(cudaIpcMemHandle_t));
- VLOG(6) << "[DEBUG XPU] _share_device_ipc: IPC handle (bytes) = "
+ VLOG(6) << "[DEBUG XPU] _share_xpu: IPC handle (bytes) = "
<< _handle;
const auto &device_id =
paddle::platform::GetXPUCurrentDeviceId();
@@ -1201,8 +1201,7 @@ void BindTensor(pybind11::module &m) { // NOLINT
size_t data_size = self.numel() *
framework::SizeOfType(
framework::TransToProtoVarType(self.type()));
- VLOG(6) << "[DEBUG XPU] _share_device_ipc: data_size = "
-        << data_size;
+ VLOG(6) << "[DEBUG XPU] _share_xpu: data_size = " << data_size;
return py::make_tuple(_handle,
(py::size_t)offset_bytes,
data_size,
@@ -1218,7 +1217,7 @@ void BindTensor(pybind11::module &m) { // NOLINT
tuple: contains handle, offset, data size, data type,
tensor dims, lod information, and device id.
)DOC")
.def("_new_from_ipc",
.def("_new_shared_xpu",
[](py::tuple t) {
if (t.size() != 7)
throw std::runtime_error(
@@ -1227,14 +1226,14 @@ void BindTensor(pybind11::module &m) { // NOLINT
// Get the current device ID.
int dev_id = platform::GetXPUCurrentDeviceId();
paddle::platform::SetXPUDeviceId(dev_id);
- VLOG(6) << "[DEBUG XPU] _new_from_ipc: current XPU device = "
+ VLOG(6) << "[DEBUG XPU] _new_shared_xpu: current XPU device = "
<< dev_id;

phi::DenseTensor tensor;
const std::string &handle = t[0].cast<std::string>();
ptrdiff_t offset_bytes = (ptrdiff_t)t[1].cast<int64_t>();
auto device_id = t[6].cast<int>();
- VLOG(6) << "[DEBUG XPU] _new_from_ipc: handle = " << handle
+ VLOG(6) << "[DEBUG XPU] _new_shared_xpu: handle = " << handle
<< ", offset_bytes = " << offset_bytes;
auto base_ptr = memory::allocation::GetIpcBasePtr(handle);
size_t size = t[2].cast<size_t>();
@@ -1248,7 +1247,7 @@ void BindTensor(pybind11::module &m) { // NOLINT
static_cast<phi::DataType>(t[3].cast<int>()));
tensor.Resize(common::make_ddim(
t[4].cast<std::vector<int>>()));
- VLOG(6) << "[DEBUG XPU] _new_from_ipc: Reshape tensor dims: "
+ VLOG(6) << "[DEBUG XPU] _new_shared_xpu: Reshape tensor dims: "
<< tensor.dims();
return tensor;
},
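For readers following the rename: a minimal round-trip sketch using the bindings above, assuming a CUDA build of Paddle. In practice the metadata tuple is consumed in a different process; a single process is shown for brevity.

    import paddle

    paddle.device.set_device("gpu")
    tensor = paddle.ones([3, 3])                     # allocated on the GPU
    # Export IPC metadata describing the allocation, as in the DOC example.
    metainfo = tensor.value().get_tensor()._share_cuda()
    # Rebuild a DenseTensor that aliases the same device memory; no copy.
    rebuilt = paddle.base.core.DenseTensor._new_shared_cuda(metainfo)
    shared = paddle.to_tensor(rebuilt)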
6 changes: 2 additions & 4 deletions python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
@@ -34,10 +34,8 @@
def _share_tensor_ipc_meta(tensor):
if tensor is None:
return None
- if (
-     core.is_compiled_with_cuda() or core.is_compiled_with_xpu()
- ) and not core.is_compiled_with_rocm():
-     return tensor.value().get_tensor()._share_device_ipc()
+ if core.is_compiled_with_cuda() and not core.is_compiled_with_rocm():
+     return tensor.value().get_tensor()._share_cuda()
return None


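A short sketch of the guard this helper now implements, assuming a CUDA (non-ROCm) build; on any other build the helper simply returns None, and callers must tolerate that:

    import paddle
    from paddle.base import core

    meta = None
    if core.is_compiled_with_cuda() and not core.is_compiled_with_rocm():
        t = paddle.ones([4])
        # Only CUDA builds can export CUDA IPC metadata on this path.
        meta = t.value().get_tensor()._share_cuda()
    # On ROCm or CPU-only builds, meta stays None.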
37 changes: 30 additions & 7 deletions python/paddle/incubate/multiprocessing/reductions.py
@@ -95,7 +95,7 @@ def _rebuild_tensor(cls, lodtensor, metadata):
def _rebuild_vmm_tensor(
cls, blob: bytes, dtype_idx: int, dims: list[int], lod, device: int
):
- lodtensor = cls._new_from_ipc((blob, dtype_idx, dims, lod, device))
+ lodtensor = cls._new_shared_cuda((blob, dtype_idx, dims, lod, device))
return lodtensor


@@ -176,12 +176,35 @@ def _rebuild_lodtensor_filedescriptor(
return lodtensor


- def _rebuild_device_tensor(
+ def _rebuild_cuda_tensor(
cls, handle, offset_bytes, size, type_idx, dims, lod, device_idx
):
cache_tensor = _cuda_from_cache((handle, offset_bytes))
if cache_tensor is None:
- lodtensor = cls._new_from_ipc(
+ lodtensor = cls._new_shared_cuda(
(handle, offset_bytes, size, type_idx, dims, lod, device_idx)
)
# We only cache CUDA shared tensors here, because opening a
# cudaIpcMemHandle is very expensive. Since we cache the received
# tensor directly, the sender may reallocate the underlying memory,
# so the lifecycle of an IPC tensor must be managed manually.
shared_cache[(handle, offset_bytes)] = lodtensor
else:
lodtensor = paddle.base.core.DenseTensor()
lodtensor._share_buffer_with(
cache_tensor, (size, type_idx, dims, lod, device_idx)
)

return lodtensor


+ def _rebuild_xpu_tensor(
+     cls, handle, offset_bytes, size, type_idx, dims, lod, device_idx
+ ):
+     cache_tensor = _cuda_from_cache((handle, offset_bytes))
+     if cache_tensor is None:
+         lodtensor = cls._new_shared_xpu(
+             (handle, offset_bytes, size, type_idx, dims, lod, device_idx)
+         )
+         # We only cache cuda shared tensor here.
@@ -237,17 +260,17 @@ def _reduce_lodtensor(lodtensor):
if prev_id != cur_id:
paddle.base.core.set_cuda_current_device_id(cur_id)
try:
- metadata = lodtensor._share_device_ipc()
+ metadata = lodtensor._share_cuda()
if len(metadata) == 5:
rebuild = _rebuild_vmm_tensor
else:
- rebuild = _rebuild_device_tensor
+ rebuild = _rebuild_cuda_tensor
finally:
if prev_id != cur_id:
paddle.base.core.set_cuda_current_device_id(prev_id)
elif lodtensor._place().is_xpu_place():
- metadata = lodtensor._share_device_ipc()
- rebuild = _rebuild_device_tensor
+ metadata = lodtensor._share_xpu()
+ rebuild = _rebuild_xpu_tensor
else:
raise RuntimeError(
"We only support pass cpu/gpu/xpu lodtensor for now!"
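A hedged end-to-end sketch of what these reducers enable: sending a GPU tensor to a spawned worker through paddle.incubate.multiprocessing. It assumes a CUDA build and that the spawn context mirrors the stdlib multiprocessing API (the tests below use the same get_context/Process pattern); only the IPC metadata crosses the process boundary.

    import paddle
    import paddle.incubate.multiprocessing as mp

    def worker(q):
        shared = q.get()   # rebuilt on this side via _rebuild_cuda_tensor
        assert shared.numpy().tolist() == [1, 2, 3]

    if __name__ == "__main__":
        paddle.device.set_device("gpu")
        ctx = mp.get_context("spawn")
        q = ctx.SimpleQueue()          # assumption: stdlib-compatible queue
        p = ctx.Process(target=worker, args=(q,))
        p.start()
        q.put(paddle.to_tensor([1, 2, 3]))  # pickles only the IPC metadata
        p.join()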
11 changes: 5 additions & 6 deletions python/paddle/optimizer/fusion_utils.py
@@ -42,11 +42,8 @@
def _share_tensor_ipc_meta(tensor):
if tensor is None:
return None
- if (
-     paddle.core.is_compiled_with_cuda()
-     or paddle.core.is_compiled_with_xpu()
- ) and not paddle.core.is_compiled_with_rocm():
-     return tensor.value().get_tensor()._share_device_ipc()
+ if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
+     return tensor.value().get_tensor()._share_cuda()
return None


@@ -245,7 +242,9 @@ def reset_meta(
assert len(buffer_ipc_meta) in (5, 7), (
"buffer_ipc_meta must be a tuple with length 5 when FLAGS_use_virtual_memory_auto_growth is True or 7 when FLAGS_use_virtual_memory_auto_growth is False."
)
- new_tensor = paddle.base.core.DenseTensor._new_from_ipc(buffer_ipc_meta)
+ new_tensor = paddle.base.core.DenseTensor._new_shared_cuda(
+     buffer_ipc_meta
+ )

self.buffer = paddle.to_tensor(new_tensor)
self.cpu_buffer = self.buffer.pin_memory()
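For reference, a minimal sketch of the contract reset_meta asserts above: the metadata tuple has 5 fields under virtual-memory auto-growth (blob, dtype index, dims, LoD, device) and 7 under classic CUDA IPC (handle, byte offset, size, dtype index, dims, LoD, device), and either form is accepted by _new_shared_cuda. The helper name is illustrative, not part of the module.

    import paddle

    def rebuild_buffer(buffer_ipc_meta):
        # 5-tuple: VMM path; 7-tuple: cudaIpcMemHandle path (see tensor.cc).
        assert len(buffer_ipc_meta) in (5, 7)
        dense = paddle.base.core.DenseTensor._new_shared_cuda(buffer_ipc_meta)
        return paddle.to_tensor(dense)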
10 changes: 5 additions & 5 deletions test/legacy_test/test_cuda_vmm_memory.py
@@ -46,8 +46,8 @@ def _vmm_runtime_available() -> bool:
return False
try:
tensor = paddle.randn([32], dtype="float32")
- meta = tensor.get_tensor()._share_device_ipc()
- rebuilt = paddle.base.core.DenseTensor._new_from_ipc(meta)
+ meta = tensor.get_tensor()._share_cuda()
+ rebuilt = paddle.base.core.DenseTensor._new_shared_cuda(meta)
_ = paddle.to_tensor(rebuilt)
_VMM_RUNTIME_AVAILABLE = True
except Exception:
@@ -138,7 +138,7 @@ def test_reduce_scatter_buffer_uses_vmm(self):
values = paddle.arange(param_storage.numel(), dtype=param_storage.dtype)
values_md5sum = values._md5sum()
param_storage.set_value(values)
- imported = paddle.base.core.DenseTensor._new_from_ipc(refreshed_meta)
+ imported = paddle.base.core.DenseTensor._new_shared_cuda(refreshed_meta)
imported_tensor = paddle.to_tensor(imported)
np.testing.assert_allclose(imported_tensor.numpy(), values.numpy())
del imported_tensor
Expand Down Expand Up @@ -167,7 +167,7 @@ def test_reduce_scatter_meta_refresh_after_tensor_swap(self):
fused_comm_buffer._param_buffer_meta_tensor = param_storage
meta_a = fused_comm_buffer.param_buffer_ipc_meta
imported_a = paddle.to_tensor(
- paddle.base.core.DenseTensor._new_from_ipc(meta_a)
+ paddle.base.core.DenseTensor._new_shared_cuda(meta_a)
)
np.testing.assert_allclose(
imported_a.numpy(), param_storage.numpy(), rtol=0, atol=0
@@ -177,7 +177,7 @@
fused_comm_buffer._param_buffer_meta_tensor = new_storage
meta_b = fused_comm_buffer.param_buffer_ipc_meta
imported_b = paddle.to_tensor(
- paddle.base.core.DenseTensor._new_from_ipc(meta_b)
+ paddle.base.core.DenseTensor._new_shared_cuda(meta_b)
)
np.testing.assert_allclose(
imported_b.numpy(), new_storage.numpy(), rtol=0, atol=0
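A hedged sketch of steering between the two code paths these tests cover. It assumes FLAGS_use_virtual_memory_auto_growth is consumed when the allocator initializes, so the flag must be set before paddle is imported:

    import os

    # Assumption: allocator flags are read once at startup.
    os.environ["FLAGS_use_virtual_memory_auto_growth"] = "1"

    import paddle

    t = paddle.randn([32], dtype="float32")
    meta = t.get_tensor()._share_cuda()
    # Expected: 5 fields on the VMM path, 7 on the cudaIpcMemHandle path.
    print(len(meta))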
4 changes: 2 additions & 2 deletions test/legacy_test/test_paddle_multiprocessing.py
@@ -56,7 +56,7 @@ def check_ipc_tensor(event, ipc_metas):
ground_truth1 = paddle.to_tensor([1, 2, 3])
ground_truth2 = paddle.to_tensor([3, 4, 5])
shared_ipc_tensor = paddle.to_tensor(
- paddle.base.core.DenseTensor._new_from_ipc(ipc_metas)
+ paddle.base.core.DenseTensor._new_shared_cuda(ipc_metas)
)
paddle.cuda.ipc_collect()

@@ -235,7 +235,7 @@ def test_ipc_tensor(self):
paddle.device.set_device(get_device())
initial_tensor = paddle.to_tensor([1, 2, 3])
bonus = paddle.to_tensor([2])
- ipc_metas = initial_tensor.value().get_tensor()._share_device_ipc()
+ ipc_metas = initial_tensor.value().get_tensor()._share_cuda()
ctx = mp.get_context("spawn")
event = ctx.Event()
process = ctx.Process(target=check_ipc_tensor, args=(event, ipc_metas))