Skip to content

Commit 2d7ad56

Browse files
issue/810 add common graph op macros
1 parent 006d530 commit 2d7ad56

File tree

12 files changed

+151
-106
lines changed

12 files changed

+151
-106
lines changed

include/infinicore/graph/graph.hpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,50 @@ class Graph {
4343
friend class GraphManager;
4444
};
4545
} // namespace infinicore::graph
46+
47+
// Declares the boilerplate class for a graph operator:
//   INFINICORE_GRAPH_OP_CLASS(Name, Arg1, Arg2, ...)
// expands to `class Name : public graph::GraphOperator` with
//   - `schema` / `plan_schema` function-pointer aliases over the argument list,
//   - per-stage dispatcher accessors (plan / run / cleanup),
//   - a constructor and a static execute() taking the same argument list.
// run_schema and cleanup_schema are not declared here; presumably they are
// inherited from graph::GraphOperator -- TODO confirm.
// NOTE(review): macro parameter renamed from __OP_NAME__ -- identifiers
// containing a double underscore are reserved to the implementation
// ([lex.name]); the rename is purely internal to the macro.
#define INFINICORE_GRAPH_OP_CLASS(OP_NAME_, ...)                           \
    class OP_NAME_ : public graph::GraphOperator {                         \
    public:                                                                \
        using schema = void (*)(__VA_ARGS__);                              \
        using plan_schema = void *(*)(__VA_ARGS__);                        \
        static common::OpDispatcher<plan_schema> &plan_dispatcher();       \
        static common::OpDispatcher<run_schema> &run_dispatcher();         \
        static common::OpDispatcher<cleanup_schema> &cleanup_dispatcher(); \
        OP_NAME_(__VA_ARGS__);                                             \
        static void execute(__VA_ARGS__);                                  \
    };
58+
59+
// Defines the three dispatcher accessors declared by
// INFINICORE_GRAPH_OP_CLASS. Each accessor holds its dispatcher as a
// function-local static (Meyers singleton), so construction is lazy and
// thread-safe under C++11 magic statics.
// NOTE(review): macro parameter renamed from __OP_NAME__ -- identifiers
// containing a double underscore are reserved to the implementation.
#define INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(OP_NAME_)                               \
    common::OpDispatcher<OP_NAME_::plan_schema> &OP_NAME_::plan_dispatcher() {       \
        static common::OpDispatcher<OP_NAME_::plan_schema> dispatcher_;              \
        return dispatcher_;                                                          \
    }                                                                                \
    common::OpDispatcher<OP_NAME_::run_schema> &OP_NAME_::run_dispatcher() {         \
        static common::OpDispatcher<OP_NAME_::run_schema> dispatcher_;               \
        return dispatcher_;                                                          \
    }                                                                                \
    common::OpDispatcher<OP_NAME_::cleanup_schema> &OP_NAME_::cleanup_dispatcher() { \
        static common::OpDispatcher<OP_NAME_::cleanup_schema> dispatcher_;           \
        return dispatcher_;                                                          \
    }
72+
73+
// Looks up the plan/run/cleanup entry points for DEVICE_TYPE_ and stores
// them in the operator's members (planned_meta_, runner_, deleter_).
// Intended to be invoked from a graph-operator constructor.
// Fixes vs. previous version:
//  - DEVICE_TYPE_ is evaluated exactly once (it used to be re-evaluated for
//    each of the three lookups),
//  - wrapped in do { } while (0) so the macro expands to a single statement
//    and is safe inside an unbraced if/else.
#define INFINICORE_GRAPH_OP_DISPATCH(DEVICE_TYPE_, ...)                               \
    do {                                                                              \
        const auto dispatch_device_type_ = (DEVICE_TYPE_);                            \
        planned_meta_ = plan_dispatcher().lookup(dispatch_device_type_)(__VA_ARGS__); \
        runner_ = run_dispatcher().lookup(dispatch_device_type_);                     \
        deleter_ = cleanup_dispatcher().lookup(dispatch_device_type_);                \
    } while (0)
77+
78+
// Shared execute() body: construct the operator, then either record it into
// the graph currently being captured or run it eagerly.
// Wrapped in do { } while (0) so that
//  - the temporary shared_ptr no longer leaks the name `op` into the
//    caller's scope (the macro can now appear twice in one function), and
//  - the macro expands to a single statement, safe inside unbraced if/else.
#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(OP_NAME_, ...)          \
    do {                                                          \
        auto graph_op_ = std::make_shared<OP_NAME_>(__VA_ARGS__); \
        if (context::isGraphRecording()) {                        \
            context::addGraphOperator(graph_op_);                 \
        } else {                                                  \
            graph_op_->run();                                     \
        }                                                         \
    } while (0)
85+
86+
// Registers plan/run/cleanup callbacks for every device type at
// static-initialization time.
// Fixes vs. previous version: the registrar variable is token-pasted to
// registered_<Op>, so two operators registered in the same namespace /
// translation unit no longer collide on the identifier `registered`;
// [[maybe_unused]] silences unused-variable warnings.
#define INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(OP_NAME_, PLAN_F_, RUN_F_, CLEANUP_F_) \
    [[maybe_unused]] static bool registered_##OP_NAME_ = []() {                       \
        OP_NAME_::plan_dispatcher().registerAll(PLAN_F_, false);                      \
        OP_NAME_::run_dispatcher().registerAll(RUN_F_, false);                        \
        OP_NAME_::cleanup_dispatcher().registerAll(CLEANUP_F_, false);                \
        return true;                                                                  \
    }();

include/infinicore/ops/gemm.hpp

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,7 @@
66

77
namespace infinicore::op {
88

9-
class Gemm : public graph::GraphOperator {
10-
public:
11-
using schema = void (*)(Tensor, Tensor, Tensor, float, float);
12-
using plan_schema = void *(*)(Tensor, Tensor, Tensor, float, float);
13-
14-
Gemm(Tensor c, Tensor a, Tensor b, float alpha, float beta);
15-
16-
static void execute(Tensor c, Tensor a, Tensor b, float alpha, float beta);
17-
18-
static common::OpDispatcher<schema> &dispatcher();
19-
static common::OpDispatcher<plan_schema> &plan_dispatcher();
20-
static common::OpDispatcher<run_schema> &run_dispatcher();
21-
static common::OpDispatcher<cleanup_schema> &cleanup_dispatcher();
22-
};
9+
INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, Tensor, Tensor, float, float);
2310

2411
Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f);
2512
void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta);

src/infinicore/context/allocators/device_pinned_allocator.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,18 @@ DevicePinnedHostAllocator::~DevicePinnedHostAllocator() {
1212
}
1313

1414
// Allocate `size` bytes of pinned (page-locked) host memory via the
// infinirt runtime. A zero-size request returns nullptr without calling
// into the runtime.
std::byte *DevicePinnedHostAllocator::allocate(size_t size) {
15+
if (size == 0) {
16+
return nullptr;
17+
}
1518
void *ptr;
1619
// INFINICORE_CHECK_ERROR presumably reports/propagates a failing infinirt
// status -- defined elsewhere in the project.
INFINICORE_CHECK_ERROR(infinirtMallocHost(&ptr, size));
1720
return (std::byte *)ptr;
1821
}
1922

2023
void DevicePinnedHostAllocator::deallocate(std::byte *ptr) {
24+
if (ptr == nullptr) {
25+
return;
26+
}
2127
if (owner_ == context::getDevice()) {
2228
INFINICORE_CHECK_ERROR(infinirtFreeHost(ptr));
2329
gc();

src/infinicore/context/allocators/host_allocator.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,16 @@
44

55
namespace infinicore {
66
// Allocate `size` bytes of ordinary host memory with std::malloc.
// Zero-size requests short-circuit to nullptr rather than relying on
// malloc's implementation-defined behavior for size 0.
std::byte *HostAllocator::allocate(size_t size) {
7+
if (size == 0) {
8+
return nullptr;
9+
}
710
return (std::byte *)std::malloc(size);
811
}
912

1013
// Release memory obtained from allocate(). The nullptr guard is redundant
// for std::free (freeing nullptr is a no-op) but mirrors the nullptr
// contract of the other allocators.
void HostAllocator::deallocate(std::byte *ptr) {
14+
if (ptr == nullptr) {
15+
return;
16+
}
1117
std::free(ptr);
1218
}
1319

src/infinicore/context/allocators/pinnable_block_allocator.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ PinnableBlockAllocator::PinnableBlockAllocator(Device device)
3737

3838
// ------------------- allocate -------------------
3939
std::byte *PinnableBlockAllocator::allocate(size_t size) {
40+
if (size == 0) {
41+
return nullptr;
42+
}
4043
std::lock_guard<std::mutex> lock(mutex_);
4144

4245
// Align size to 256 bytes for GPU
@@ -94,7 +97,7 @@ std::byte *PinnableBlockAllocator::allocate(size_t size) {
9497

9598
// ------------------- deallocate -------------------
9699
void PinnableBlockAllocator::deallocate(std::byte *ptr) {
97-
if (!ptr) {
100+
if (ptr == nullptr) {
98101
return;
99102
}
100103

src/infinicore/context/allocators/stream_ordered_allocator.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,18 @@ namespace infinicore {
88
StreamOrderedAllocator::StreamOrderedAllocator(Device device) : MemoryAllocator(), device_(device) {}
99

1010
// Allocate `size` bytes of device memory asynchronously, ordered on the
// current context stream. Zero-size requests return nullptr without
// touching the runtime.
std::byte *StreamOrderedAllocator::allocate(size_t size) {
11+
if (size == 0) {
12+
return nullptr;
13+
}
1114
void *ptr = nullptr;
1215
INFINICORE_CHECK_ERROR(infinirtMallocAsync(&ptr, size, context::getStream()));
1316
return (std::byte *)ptr;
1417
}
1518

1619
// Free memory asynchronously, ordered on the current context stream.
// nullptr (as returned by allocate(0)) is accepted and ignored.
// NOTE(review): assumes the current stream is appropriate for freeing this
// pointer -- confirm allocation/free stream pairing against the runtime's
// stream-ordered-memory rules.
void StreamOrderedAllocator::deallocate(std::byte *ptr) {
20+
if (ptr == nullptr) {
21+
return;
22+
}
1723
INFINICORE_CHECK_ERROR(infinirtFreeAsync(ptr, context::getStream()));
1824
}
1925
} // namespace infinicore

src/infinicore/ops/gemm/gemm.cc

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,15 @@
33
#include "../../utils.hpp"
44

55
namespace infinicore::op {
6-
7-
common::OpDispatcher<Gemm::schema> &Gemm::dispatcher() {
8-
static common::OpDispatcher<Gemm::schema> dispatcher_;
9-
return dispatcher_;
10-
};
11-
12-
common::OpDispatcher<Gemm::plan_schema> &Gemm::plan_dispatcher() {
13-
static common::OpDispatcher<Gemm::plan_schema> dispatcher_;
14-
return dispatcher_;
15-
}
16-
common::OpDispatcher<Gemm::run_schema> &Gemm::run_dispatcher() {
17-
static common::OpDispatcher<Gemm::run_schema> dispatcher_;
18-
return dispatcher_;
19-
}
20-
common::OpDispatcher<Gemm::cleanup_schema> &Gemm::cleanup_dispatcher() {
21-
static common::OpDispatcher<Gemm::cleanup_schema> dispatcher_;
22-
return dispatcher_;
23-
}
6+
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Gemm);
247

258
Gemm::Gemm(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
269
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
27-
planned_meta_ = plan_dispatcher().lookup(c->device().getType())(c, a, b, alpha, beta);
28-
runner_ = run_dispatcher().lookup(c->device().getType());
29-
deleter_ = cleanup_dispatcher().lookup(c->device().getType());
10+
INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b, alpha, beta);
3011
}
3112

3213
void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
33-
34-
auto op = std::make_shared<Gemm>(c, a, b, alpha, beta);
35-
if (context::isGraphRecording()) {
36-
context::addGraphOperator(op);
37-
} else {
38-
op->run();
39-
}
14+
INFINICORE_GRAPH_OP_RECORD_OR_RUN(Gemm, c, a, b, alpha, beta);
4015
}
4116

4217
Tensor gemm(Tensor a, Tensor b, float alpha, float beta) {
Lines changed: 8 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,9 @@
1-
#include "../../utils.hpp"
2-
#include "infinicore/common/hash.hpp"
3-
#include "infinicore/ops/common/cache.hpp"
1+
#include "../infiniop_impl.hpp"
42
#include "infinicore/ops/gemm.hpp"
5-
#include <infiniop.h>
63

74
namespace infinicore::op::gemm_impl::infiniop {
8-
// A desc holder to make it a shared pointer that can auto clean-up
9-
struct Descriptor {
10-
infiniopGemmDescriptor_t desc;
11-
Descriptor(infiniopGemmDescriptor_t desc) : desc(desc) {}
12-
~Descriptor() {
13-
if (desc != nullptr) {
14-
infiniopDestroyGemmDescriptor(desc);
15-
desc = nullptr;
16-
}
17-
}
18-
};
195

20-
thread_local common::OpCache<size_t, std::shared_ptr<Descriptor>>
21-
caches(
22-
// capacity
23-
100,
24-
// on evict
25-
[](std::shared_ptr<Descriptor> &desc) {
26-
desc = nullptr;
27-
});
6+
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Gemm, 100);
287

298
struct PlannedMeta {
309
std::shared_ptr<Descriptor> descriptor;
@@ -33,25 +12,13 @@ struct PlannedMeta {
3312
};
3413

3514
void *plan(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
36-
size_t seed = hash_combine(c, b, a, alpha, beta);
37-
38-
auto device = context::getDevice();
39-
auto &cache = caches.getCache(device);
40-
41-
auto descriptor = cache.get(seed).value_or(nullptr);
15+
size_t seed = hash_combine(c, a, b);
4216

43-
if (!descriptor) {
44-
descriptor = std::make_shared<Descriptor>(nullptr);
45-
INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor(
46-
context::getInfiniopHandle(device),
47-
&descriptor->desc,
48-
c->desc(), a->desc(), b->desc()));
49-
cache.put(seed, descriptor);
50-
}
17+
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
18+
Descriptor, descriptor, Gemm,
19+
seed, c->desc(), a->desc(), b->desc());
5120

52-
size_t workspace_size = 0;
53-
INFINICORE_CHECK_ERROR(infiniopGetGemmWorkspaceSize(descriptor->desc, &workspace_size));
54-
Tensor workspace = Tensor::empty({workspace_size}, DataType::U8, device);
21+
INFINIOP_WORKSPACE_TENSOR(workspace, Gemm, descriptor);
5522

5623
auto planned = new PlannedMeta{
5724
descriptor,
@@ -77,18 +44,6 @@ void cleanup(void **planned_meta_ptr) {
7744
*planned_meta_ptr = nullptr;
7845
}
7946

80-
void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
81-
auto planned = plan(c, a, b, alpha, beta);
82-
run(planned);
83-
cleanup(&planned);
84-
}
85-
86-
static bool registered = []() {
87-
Gemm::dispatcher().registerAll(&calculate, false);
88-
Gemm::plan_dispatcher().registerAll(&plan, false);
89-
Gemm::run_dispatcher().registerAll(&run, false);
90-
Gemm::cleanup_dispatcher().registerAll(&cleanup, false);
91-
return true;
92-
}();
47+
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Gemm, &plan, &run, &cleanup);
9348

9449
} // namespace infinicore::op::gemm_impl::infiniop
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#pragma once
2+
3+
#include "../utils.hpp"
4+
#include "infinicore/common/hash.hpp"
5+
#include "infinicore/ops/common/cache.hpp"
6+
#include <infiniop.h>
7+
8+
// Declares an RAII holder DESC_TYPE_ for an infiniop<OP_NAME_>Descriptor_t
// (destroyed via infiniopDestroy<OP_NAME_>Descriptor on destruction) plus a
// thread_local, per-device, capacity-bounded cache named `caches` of shared
// pointers to it (eviction simply drops the shared_ptr).
// Bug fix vs. previous version: the constructor and destructor were spelled
// `Descriptor` literally, so the macro only compiled when DESC_TYPE_
// happened to be `Descriptor`; they now use the DESC_TYPE_ parameter.
#define INFINIOP_CACHABLE_DESCRIPTOR(DESC_TYPE_, OP_NAME_, SIZE_)         \
    struct DESC_TYPE_ {                                                   \
        infiniop##OP_NAME_##Descriptor_t desc;                            \
        DESC_TYPE_(infiniop##OP_NAME_##Descriptor_t desc) : desc(desc) {} \
        ~DESC_TYPE_() {                                                   \
            if (desc != nullptr) {                                        \
                infiniopDestroy##OP_NAME_##Descriptor(desc);              \
                desc = nullptr;                                           \
            }                                                             \
        }                                                                 \
    };                                                                    \
                                                                          \
    thread_local common::OpCache<size_t, std::shared_ptr<DESC_TYPE_>>     \
        caches(                                                           \
            SIZE_,                                                        \
            [](std::shared_ptr<DESC_TYPE_> &desc) {                       \
                desc = nullptr;                                           \
            });
26+
27+
// Fetches the descriptor cached under HASH_KEY_ from the thread-local
// `caches` (declared by INFINIOP_CACHABLE_DESCRIPTOR); on a miss, creates it
// via infiniopCreate<INFINIOP_NAME_>Descriptor(handle, &desc, __VA_ARGS__)
// and inserts it. Declares `std::shared_ptr<DESC_TYPE_> DESC_NAME_` in the
// enclosing scope; all other locals live in an inner block.
// Fixes vs. previous version:
//  - HASH_KEY_ is evaluated exactly once (it used to be evaluated for both
//    the get and the put),
//  - scratch locals no longer use identifiers containing a double
//    underscore (reserved to the implementation, [lex.name]).
#define INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(DESC_TYPE_, DESC_NAME_, INFINIOP_NAME_, HASH_KEY_, ...) \
    std::shared_ptr<DESC_TYPE_> DESC_NAME_;                                                                \
    {                                                                                                      \
        const size_t cache_key_ = (HASH_KEY_);                                                             \
        auto device_ = context::getDevice();                                                               \
        auto &cache_ = caches.getCache(device_);                                                           \
        DESC_NAME_ = cache_.get(cache_key_).value_or(nullptr);                                             \
        if (!DESC_NAME_) {                                                                                 \
            DESC_NAME_ = std::make_shared<DESC_TYPE_>(nullptr);                                            \
            INFINICORE_CHECK_ERROR(infiniopCreate##INFINIOP_NAME_##Descriptor(                             \
                context::getInfiniopHandle(device_),                                                       \
                &DESC_NAME_->desc,                                                                         \
                __VA_ARGS__));                                                                             \
            cache_.put(cache_key_, DESC_NAME_);                                                            \
        }                                                                                                  \
    }
42+
43+
// Queries the workspace size required by DESC_NAME_ and declares a U8 tensor
// TENSOR_NAME_ of that many bytes on the current device.
// Bug fixes vs. previous version:
//  - the size query hard-coded `infiniopGetGemmWorkspaceSize` instead of
//    token-pasting INFINIOP_NAME_, so the macro silently ignored its
//    operator argument and only worked for Gemm;
//  - it read `descriptor->desc` instead of the DESC_NAME_ parameter, so it
//    only compiled when the caller's variable was literally `descriptor`.
#define INFINIOP_WORKSPACE_TENSOR(TENSOR_NAME_, INFINIOP_NAME_, DESC_NAME_)     \
    Tensor TENSOR_NAME_;                                                        \
    {                                                                           \
        auto device_ = context::getDevice();                                    \
        size_t workspace_size_ = 0;                                             \
        INFINICORE_CHECK_ERROR(infiniopGet##INFINIOP_NAME_##WorkspaceSize(      \
            DESC_NAME_->desc, &workspace_size_));                               \
        TENSOR_NAME_ = Tensor::empty({workspace_size_}, DataType::U8, device_); \
    }

src/infinicore/ops/linear/linear.cc

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#include "infinicore/ops/linear.hpp"
2-
#include "infinicore/ops/add.hpp"
3-
#include "infinicore/ops/matmul.hpp"
2+
#include "infinicore/ops/gemm.hpp"
3+
#include "infinicore/ops/rearrange.hpp"
44

55
namespace infinicore::op {
66

@@ -42,16 +42,18 @@ void linear_(Tensor out,
4242

4343
// linear transformation
4444
Tensor out_view = out->view({N, out_features});
45-
matmul_(out_view,
46-
input->view({N, in_features}),
47-
weight->permute({1, 0}));
48-
4945
// Add bias
46+
float alpha = 1.0f;
47+
float beta = 0.0f;
5048
if (bias.has_value()) {
51-
add_(out_view,
52-
out_view,
53-
bias.value()->as_strided({N, out_features}, {0, 1}));
49+
rearrange_(out_view,
50+
bias.value()->as_strided({N, out_features}, {0, 1}));
51+
beta = 1.0f;
5452
}
53+
54+
gemm_(out_view,
55+
input->view({N, in_features}),
56+
weight->permute({1, 0}), alpha, beta);
5557
}
5658

5759
} // namespace infinicore::op

0 commit comments

Comments
 (0)