Skip to content

Commit cb60ceb

Browse files
authored
[SYCLomatic] Refine migration of CuDNN types (#533)
* Add bool and operator= for some dnn types * add const to bool conversion
1 parent a953fc7 commit cb60ceb

File tree

2 files changed

+60
-2
lines changed

2 files changed

+60
-2
lines changed

clang/runtime/dpct-rt/include/dnnl_utils.hpp.inc

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,15 @@ public:
198198
}
199199
return result;
200200
}
201+
202+
operator bool() const {
203+
return bool(_desc);
204+
}
205+
206+
memory_desc_ext &operator=(std::nullptr_t) {
207+
_desc.reset(nullptr);
208+
return *this;
209+
}
201210
};
202211
// DPCT_LABEL_END
203212

@@ -653,6 +662,16 @@ public:
653662
_strides[i - 2];
654663
}
655664
}
665+
666+
convolution_desc &operator=(std::nullptr_t) {
667+
return *this = convolution_desc();
668+
}
669+
670+
operator bool() const {
671+
return _strides.size() == 0
672+
&& _dilates.size() == 0
673+
&& _paddings.size() == 0;
674+
}
656675
};
657676
// DPCT_LABEL_END
658677

@@ -2110,6 +2129,16 @@ public:
21102129
size_t workspace_size, void *workspace);
21112130
// DPCT_LABEL_END
21122131

2132+
operator bool() const {
2133+
return bool(_eng) && bool(_s) && bool(_q);
2134+
}
2135+
2136+
engine_ext &operator=(std::nullptr_t) {
2137+
_eng.reset(nullptr);
2138+
_s.reset(nullptr);
2139+
_q = nullptr;
2140+
return *this;
2141+
}
21132142
// DPCT_LABEL_BEGIN|engine_ext_1|dpct::dnnl
21142143
// DPCT_DEPENDENCY_BEGIN
21152144
// DnnlUtils|engine_ext
@@ -4745,4 +4774,4 @@ sycl::event engine_ext::async_rnn_backward(
47454774
} // namespace dnnl
47464775
} // namespace dpct
47474776

4748-
#endif // __DPCT_DNNL_UTILS_HPP__
4777+
#endif // __DPCT_DNNL_UTILS_HPP__

clang/test/dpct/helper_files_ref/include/dnnl_utils.hpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,15 @@ class memory_desc_ext {
170170
}
171171
return result;
172172
}
173+
174+
operator bool() const {
175+
return bool(_desc);
176+
}
177+
178+
memory_desc_ext &operator=(std::nullptr_t) {
179+
_desc.reset(nullptr);
180+
return *this;
181+
}
173182
};
174183

175184
/// A class holding description for an activation operation.
@@ -585,6 +594,16 @@ class convolution_desc {
585594
_strides[i - 2];
586595
}
587596
}
597+
598+
convolution_desc &operator=(std::nullptr_t) {
599+
return *this = convolution_desc();
600+
}
601+
602+
operator bool() const {
603+
return _strides.size() == 0
604+
&& _dilates.size() == 0
605+
&& _paddings.size() == 0;
606+
}
588607
};
589608

590609
/// An enum class representing rnn mode.
@@ -1757,6 +1776,16 @@ class engine_ext {
17571776
void *weight, void *diff_weight, size_t scratchpad_size, void *scratchpad,
17581777
size_t workspace_size, void *workspace);
17591778

1779+
operator bool() const {
1780+
return bool(_eng) && bool(_s) && bool(_q);
1781+
}
1782+
1783+
engine_ext &operator=(std::nullptr_t) {
1784+
_eng.reset(nullptr);
1785+
_s.reset(nullptr);
1786+
_q = nullptr;
1787+
return *this;
1788+
}
17601789
};
17611790

17621791
inline
@@ -4186,4 +4215,4 @@ sycl::event engine_ext::async_rnn_backward(
41864215
} // namespace dnnl
41874216
} // namespace dpct
41884217

4189-
#endif // __DPCT_DNNL_UTILS_HPP__
4218+
#endif // __DPCT_DNNL_UTILS_HPP__

0 commit comments

Comments
 (0)