1- #include "../../utils.hpp"
2- #include "infinicore/common/hash.hpp"
3- #include "infinicore/ops/common/cache.hpp"
1+ #include "../infiniop_impl.hpp"
4 2 #include "infinicore/ops/gemm.hpp"
5- #include <infiniop.h>
6 3
7 4 namespace infinicore::op::gemm_impl::infiniop {
8- // A desc holder to make it a shared pointer that can auto clean-up
9- struct Descriptor {
10-     infiniopGemmDescriptor_t desc;
11-     Descriptor(infiniopGemmDescriptor_t desc) : desc(desc) {}
12-     ~Descriptor() {
13-         if (desc != nullptr) {
14-             infiniopDestroyGemmDescriptor(desc);
15-             desc = nullptr;
16-         }
17-     }
18- };
19 5
20- thread_local common::OpCache<size_t, std::shared_ptr<Descriptor>>
21-     caches(
22-         // capacity
23-         100,
24-         // on evict
25-         [](std::shared_ptr<Descriptor> &desc) {
26-             desc = nullptr;
27-         });
6+ INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Gemm, 100);
28 7
29 8 struct PlannedMeta {
30 9     std::shared_ptr<Descriptor> descriptor;
@@ -33,25 +12,13 @@ struct PlannedMeta {
33 12 };
34 13
35 14 void *plan(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
36-     size_t seed = hash_combine(c, b, a, alpha, beta);
37-
38-     auto device = context::getDevice();
39-     auto &cache = caches.getCache(device);
40-
41-     auto descriptor = cache.get(seed).value_or(nullptr);
15+     size_t seed = hash_combine(c, a, b);
42 16
43-     if (!descriptor) {
44-         descriptor = std::make_shared<Descriptor>(nullptr);
45-         INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor(
46-             context::getInfiniopHandle(device),
47-             &descriptor->desc,
48-             c->desc(), a->desc(), b->desc()));
49-         cache.put(seed, descriptor);
50-     }
17+     INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
18+         Descriptor, descriptor, Gemm,
19+         seed, c->desc(), a->desc(), b->desc());
51 20
52-     size_t workspace_size = 0;
53-     INFINICORE_CHECK_ERROR(infiniopGetGemmWorkspaceSize(descriptor->desc, &workspace_size));
54-     Tensor workspace = Tensor::empty({workspace_size}, DataType::U8, device);
21+     INFINIOP_WORKSPACE_TENSOR(workspace, Gemm, descriptor);
55 22
56 23     auto planned = new PlannedMeta{
57 24         descriptor,
@@ -77,18 +44,6 @@ void cleanup(void **planned_meta_ptr) {
77 44     *planned_meta_ptr = nullptr;
78 45 }
79 46
80- void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
81-     auto planned = plan(c, a, b, alpha, beta);
82-     run(planned);
83-     cleanup(&planned);
84- }
85-
86- static bool registered = []() {
87-     Gemm::dispatcher().registerAll(&calculate, false);
88-     Gemm::plan_dispatcher().registerAll(&plan, false);
89-     Gemm::run_dispatcher().registerAll(&run, false);
90-     Gemm::cleanup_dispatcher().registerAll(&cleanup, false);
91-     return true;
92- }();
47+ INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Gemm, &plan, &run, &cleanup);
93 48
94 49 } // namespace infinicore::op::gemm_impl::infiniop
0 commit comments