Skip to content

Commit 0c69f08

Browse files
committed
update
1 parent 7a9c595 commit 0c69f08

File tree

4 files changed

+93
-18
lines changed

4 files changed

+93
-18
lines changed

src/xccl/ProcessGroupXCCL.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -959,7 +959,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
959959
xcclGroupStart();
960960
fn(tensor, *comm, stream, cclstream, p2pTargetRank);
961961
xcclGroupEnd();
962-
962+
963963
if (!coalescing_state_) {
964964
post(stream);
965965

src/xccl/ProcessGroupXCCL.hpp

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <torch/csrc/distributed/c10d/TraceUtils.h>
2323
#include <torch/csrc/distributed/c10d/logger.hpp>
2424
#include <xccl/ProcessGroupXCCLMonitor.hpp>
25+
#include <xccl/XPUEventCache.hpp>
2526
namespace c10d {
2627

2728
static std::vector<std::string> TORCH_XCCL_HIGH_PRIORITY = {
@@ -308,23 +309,23 @@ class TORCH_API ProcessGroupXCCL : public Backend {
308309
/*nanCheck =*/false);
309310
}
310311

311-
template <typename Fn>
312-
c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
313-
at::Tensor& tensor,
314-
Fn fn,
315-
int peer,
316-
OpType opType,
317-
const char* profilingTitle) {
318-
return pointToPoint(
319-
tensor,
320-
fn,
321-
peer,
322-
opType,
323-
[](at::xpu::XPUStream&,
324-
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {},
325-
[](at::xpu::XPUStream&) {},
326-
profilingTitle);
327-
}
312+
template <typename Fn>
313+
c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
314+
at::Tensor& tensor,
315+
Fn fn,
316+
int peer,
317+
OpType opType,
318+
const char* profilingTitle) {
319+
return pointToPoint(
320+
tensor,
321+
fn,
322+
peer,
323+
opType,
324+
[](at::xpu::XPUStream&,
325+
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {},
326+
[](at::xpu::XPUStream&) {},
327+
profilingTitle);
328+
}
328329

329330
template <typename Fn, typename PreProcess, typename PostProcess>
330331
c10::intrusive_ptr<Work> pointToPoint(
@@ -441,6 +442,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
441442

442443
uint64_t getSequenceNumberForGroup() override;
443444

445+
float getDuration() const override;
446+
444447
std::string createLogPrefix() const;
445448

446449
const std::string& logPrefix() const;

src/xccl/XPUEventCache.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#include <c10/xpu/XPUStream.h>
2+
#include <xccl/XPUEventCache.hpp>
3+
#include <map>
4+
5+
namespace c10d {
6+
7+
XPUEventCache::XPUEventCache() = default;
8+
9+
std::shared_ptr<at::xpu::XPUEvent> XPUEventCache::create(bool timing) {
10+
auto deleter = [cache = shared_from_this(),
11+
timing](at::xpu::XPUEvent* event) {
12+
std::lock_guard<std::mutex> lock(cache->cacheMutex_);
13+
14+
cache->eventsArray_[timing ? 1 : 0].push_back(event);
15+
};
16+
at::xpu::XPUEvent* event = nullptr;
17+
{
18+
std::lock_guard<std::mutex> lock(cacheMutex_);
19+
auto& events = eventsArray_[timing ? 1 : 0];
20+
// If we still have events in the cache, we reuse it. Otherwise, we create a
21+
// new one.
22+
if (!events.empty()) {
23+
event = events.front();
24+
events.pop_front();
25+
} else {
26+
event = new at::xpu::XPUEvent(timing ? 1 : 0);
27+
}
28+
}
29+
return std::shared_ptr<at::xpu::XPUEvent>(event, std::move(deleter));
30+
}
31+
32+
// Returns this thread's event cache for `device`, creating it on first
// use. The map is thread_local, so each thread keeps its own per-device
// caches and no locking is needed here.
std::shared_ptr<XPUEventCache> XPUEventCache::get(at::DeviceIndex device) {
  static thread_local std::map<at::DeviceIndex, std::shared_ptr<XPUEventCache>>
      cacheDeviceMap;
  // try_emplace does a single tree traversal and only constructs the
  // mapped value when the key is absent; the original performed up to
  // three lookups (find + emplace + operator[]).
  auto [it, inserted] = cacheDeviceMap.try_emplace(device);
  if (inserted) {
    it->second = std::make_shared<XPUEventCache>();
  }
  return it->second;
}
42+
43+
} // namespace c10d

src/xccl/XPUEventCache.hpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#pragma once
2+
3+
#include <array>
4+
#include <deque>
5+
#include <memory>
6+
#include <mutex>
7+
8+
#include <ATen/xpu/XPUEvent.h>
9+
#include <c10/macros/Export.h>
10+
11+
namespace c10d {
12+
13+
// Pool of reusable XPU events, keyed by timing mode. Obtain an instance
// via the thread-local, per-device accessor get(); enable_shared_from_this
// allows create()'s deleter to hold the cache alive while handed-out
// events are still outstanding.
class TORCH_API XPUEventCache
    : public std::enable_shared_from_this<XPUEventCache> {
 public:
  XPUEventCache();
  // Returns a pooled (or newly allocated) event; the shared_ptr's
  // deleter returns the event to the cache rather than destroying it.
  std::shared_ptr<at::xpu::XPUEvent> create(bool timing);
  // Thread-local, per-device singleton accessor.
  static std::shared_ptr<XPUEventCache> get(at::DeviceIndex device);

 private:
  // Guards eventsArray_; create() and its deleter may race.
  std::mutex cacheMutex_;
  // NOTE: We intentionally store raw pointers so that
  // we do not attempt to destroy the event objects on process exit,
  // because the XPU runtime may already be torn down by then ("cuda"
  // in the original comment was a leftover from the NCCL/CUDA version
  // this was adapted from).
  std::array<std::deque<at::xpu::XPUEvent*>, 2>
      eventsArray_; // 0 for timing=false, 1 for timing=true
};
28+
29+
} // namespace c10d

0 commit comments

Comments
 (0)