
Commit 58f5263

update

1 parent 7a9c595 commit 58f5263

File tree: 4 files changed, +103 −29 lines changed

src/xccl/ProcessGroupXCCL.cpp

Lines changed: 8 additions & 11 deletions

@@ -267,11 +267,9 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(
           : nullptr;
       xcclEndEvent_ = XPUEventCache::get(device.index())->create(enableTiming);
     } else {
-      xcclStartEvent_ = enableTiming
-          ? std::make_shared<at::xpu::XPUEvent>(xpuEventDefault)
-          : nullptr;
-      xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>(
-          enableTiming ? xpuEventDefault : xpuEventDisableTiming);
+      xcclStartEvent_ =
+          enableTiming ? std::make_shared<at::xpu::XPUEvent>(1) : nullptr;
+      xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>(enableTiming ? 1 : 0);
     }
     stashed_for_allocator_safety_ = std::make_shared<TensorShelf>();
   }

@@ -902,7 +900,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
   auto cclstream = xcclStreamsMap_.at(key).second;
   syncStream(device, xcclEventsMap_[key], stream);

-  c10::intrusive_ptr<ProcessGroupNCCL::WorkXCCL> work;
+  c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
   if (!coalescing_state_) {
     work = initWork(device, rank_, opType, true, profilingTitle, {tensor}, {});
     work->outputs_ = std::make_shared<std::vector<at::Tensor>>();

@@ -944,9 +942,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
     checkForNan(tensor, stream);
   }
   if (!coalescing_state_) {
-    // Start event should only be recorded before the ncclGroupStart()
    if (work->timingEnabled_) {
-      work->ncclStartEvent_->record(stream);
+      work->xcclStartEvent_->record(stream);
    }

    pre(stream, work);

@@ -956,10 +953,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
  c10::xpu::XPUCachingAllocator::recordStream(
      tensor.storage().data_ptr(), stream);

-  xcclGroupStart();
+  ccl::group_start();
  fn(tensor, *comm, stream, cclstream, p2pTargetRank);
-  xcclGroupEnd();
-
+  ccl::group_end();
+
  if (!coalescing_state_) {
    post(stream);

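Two notable fixes in this file: the stray ProcessGroupNCCL/ncclStartEvent identifiers left over from the NCCL backend are renamed to their XCCL counterparts, and the xcclGroupStart()/xcclGroupEnd() wrappers are replaced with direct calls to oneCCL's ccl::group_start()/ccl::group_end(). Since the group calls must stay paired on every exit path, one common way to make that exception-safe is an RAII guard; below is a minimal sketch under that idea (GroupGuard is an illustrative name, not part of this commit, and the oneCCL header path is assumed):

#include <oneapi/ccl.hpp>

// Hypothetical RAII guard pairing ccl::group_start() with ccl::group_end(),
// so the group is closed even if the enclosed call throws.
class GroupGuard {
 public:
  GroupGuard() { ccl::group_start(); }  // open the batching window
  ~GroupGuard() { ccl::group_end(); }   // close it on scope exit
  GroupGuard(const GroupGuard&) = delete;
  GroupGuard& operator=(const GroupGuard&) = delete;
};

// Usage mirroring the pointToPoint hot path above:
//   {
//     GroupGuard guard;
//     fn(tensor, *comm, stream, cclstream, p2pTargetRank);
//   }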
src/xccl/ProcessGroupXCCL.hpp

Lines changed: 23 additions & 18 deletions

@@ -22,6 +22,7 @@
 #include <torch/csrc/distributed/c10d/TraceUtils.h>
 #include <torch/csrc/distributed/c10d/logger.hpp>
 #include <xccl/ProcessGroupXCCLMonitor.hpp>
+#include <xccl/XPUEventCache.hpp>
 namespace c10d {

 static std::vector<std::string> TORCH_XCCL_HIGH_PRIORITY = {

@@ -73,7 +74,9 @@ class TORCH_API ProcessGroupXCCL : public Backend {
       uint64_t seq,
       bool isP2P,
       const char* profilingTitle = nullptr,
-      const std::optional<std::vector<at::Tensor>>& inputs = std::nullopt);
+      const std::optional<std::vector<at::Tensor>>& inputs = std::nullopt,
+      bool enableTiming = false,
+      bool xpuEventCacheEnabled = false);
   WorkXCCL(const WorkXCCL& w);
   ~WorkXCCL() override;

@@ -87,6 +90,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {

   void synchronizeStream();

+  float getDuration() const override;
+
   bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

   c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {

@@ -308,23 +313,23 @@ class TORCH_API ProcessGroupXCCL : public Backend {
         /*nanCheck =*/false);
   }

-  template <typename Fn>
-  c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
-      at::Tensor& tensor,
-      Fn fn,
-      int peer,
-      OpType opType,
-      const char* profilingTitle) {
-    return pointToPoint(
-        tensor,
-        fn,
-        peer,
-        opType,
-        [](at::xpu::XPUStream&,
-           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {},
-        [](at::xpu::XPUStream&) {},
-        profilingTitle);
-  }
+  template <typename Fn>
+  c10::intrusive_ptr<Work> pointToPoint(
+      at::Tensor& tensor,
+      Fn fn,
+      int peer,
+      OpType opType,
+      const char* profilingTitle) {
+    return pointToPoint(
+        tensor,
+        fn,
+        peer,
+        opType,
+        [](at::xpu::XPUStream&,
+           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {},
+        [](at::xpu::XPUStream&) {},
+        profilingTitle);
+  }

   template <typename Fn, typename PreProcess, typename PostProcess>
   c10::intrusive_ptr<Work> pointToPoint(

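The deleted overload was declared with a ProcessGroupNCCL:: qualifier inside the class body, which is ill-formed C++ (and names the wrong class); the replacement is a plain inline member that forwards to the full pointToPoint overload with no-op pre/post lambdas. A self-contained sketch of that delegation pattern, with all names illustrative:

#include <functional>
#include <iostream>

struct Worker {
  using Hook = std::function<void(int /*stream*/)>;

  // Full version: the caller supplies pre/post hooks around the work.
  int run(int value, const Hook& pre, const Hook& post) {
    pre(0);
    int result = value * 2;  // the actual work
    post(0);
    return result;
  }

  // Convenience overload: same entry point, no-op hooks.
  int run(int value) {
    return run(value, [](int) {}, [](int) {});
  }
};

int main() {
  Worker w;
  std::cout << w.run(21) << '\n';  // prints 42
}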
src/xccl/XPUEventCache.cpp

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+#include <c10/xpu/XPUStream.h>
+#include <xccl/XPUEventCache.hpp>
+#include <map>
+
+namespace c10d {
+
+XPUEventCache::XPUEventCache() = default;
+
+std::shared_ptr<at::xpu::XPUEvent> XPUEventCache::create(bool timing) {
+  auto deleter = [cache = shared_from_this(),
+                  timing](at::xpu::XPUEvent* event) {
+    std::lock_guard<std::mutex> lock(cache->cacheMutex_);
+
+    cache->eventsArray_[timing ? 1 : 0].push_back(event);
+  };
+  at::xpu::XPUEvent* event = nullptr;
+  {
+    std::lock_guard<std::mutex> lock(cacheMutex_);
+    auto& events = eventsArray_[timing ? 1 : 0];
+    // If we still have events in the cache, we reuse one. Otherwise, we
+    // create a new one.
+    if (!events.empty()) {
+      event = events.front();
+      events.pop_front();
+    } else {
+      event = new at::xpu::XPUEvent(timing ? 1 : 0);
+    }
+  }
+  return std::shared_ptr<at::xpu::XPUEvent>(event, std::move(deleter));
+}
+
+std::shared_ptr<XPUEventCache> XPUEventCache::get(at::DeviceIndex device) {
+  static thread_local std::map<at::DeviceIndex, std::shared_ptr<XPUEventCache>>
+      cacheDeviceMap;
+  // Check whether the device is already in the map; if not, add a new entry.
+  auto it = cacheDeviceMap.find(device);
+  if (it == cacheDeviceMap.end()) {
+    cacheDeviceMap.emplace(device, std::make_shared<XPUEventCache>());
+  }
+  return cacheDeviceMap[device];
+}
+
+} // namespace c10d
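The interesting part of create() is the custom deleter: the returned shared_ptr never destroys the event when its last reference drops; instead the deleter pushes the raw pointer back onto the per-timing-mode deque, and capturing shared_from_this() keeps the cache alive as long as any handed-out event. A generic, self-contained sketch of this recycle-on-release pattern (all names illustrative):

#include <deque>
#include <memory>
#include <mutex>

template <typename T>
class Pool : public std::enable_shared_from_this<Pool<T>> {
 public:
  std::shared_ptr<T> acquire() {
    T* obj = nullptr;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (!free_.empty()) {  // reuse a cached object when possible
        obj = free_.front();
        free_.pop_front();
      }
    }
    if (obj == nullptr) {
      obj = new T();  // cache miss: allocate a fresh one
    }
    // The deleter recycles instead of deleting; capturing shared_from_this()
    // keeps the pool alive while objects are outstanding.
    return std::shared_ptr<T>(obj, [pool = this->shared_from_this()](T* p) {
      std::lock_guard<std::mutex> lock(pool->mutex_);
      pool->free_.push_back(p);
    });
  }

 private:
  std::mutex mutex_;
  std::deque<T*> free_;
};

// The pool itself must be owned by a shared_ptr for shared_from_this():
//   auto pool = std::make_shared<Pool<int>>();
//   auto x = pool->acquire();  // on release, x is recycled, not deleted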

src/xccl/XPUEventCache.hpp

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+#pragma once
+
+#include <array>
+#include <deque>
+#include <memory>
+#include <mutex>
+
+#include <ATen/xpu/XPUEvent.h>
+#include <c10/macros/Export.h>
+
+namespace c10d {
+
+class TORCH_API XPUEventCache
+    : public std::enable_shared_from_this<XPUEventCache> {
+ public:
+  XPUEventCache();
+  std::shared_ptr<at::xpu::XPUEvent> create(bool timing);
+  static std::shared_ptr<XPUEventCache> get(at::DeviceIndex device);
+
+ private:
+  std::mutex cacheMutex_;
+  // NOTE: We intentionally store raw pointers so that we do not attempt to
+  // destroy the event objects on process exit, because the XPU runtime may
+  // already be gone by then.
+  std::array<std::deque<at::xpu::XPUEvent*>, 2>
+      eventsArray_; // 0 for timing=false, 1 for timing=true
+};
+
+} // namespace c10d
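Note that get() hands out caches from a thread_local per-device map, so the map itself needs no locking; the deques are still mutex-protected because the shared_ptr deleters can run on whichever thread drops the last reference. A short usage sketch matching the WorkXCCL constructor in the first file (device, stream, and enableTiming stand in for the caller's values):

auto cache = c10d::XPUEventCache::get(device.index());
std::shared_ptr<at::xpu::XPUEvent> endEvent = cache->create(enableTiming);
endEvent->record(stream);  // record on the communication stream
// When the last reference to endEvent drops, the deleter returns the raw
// event to the per-device cache for reuse instead of destroying it.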
