diff --git a/include/matx/operators/collapse.h b/include/matx/operators/collapse.h index a6a33bca6..a25c39fe9 100644 --- a/include/matx/operators/collapse.h +++ b/include/matx/operators/collapse.h @@ -51,6 +51,7 @@ namespace matx using scalar_type = typename T1::scalar_type; using shape_type = index_t; using matxoplvalue = bool; + using self_type = LCollapseOp; __MATX_INLINE__ std::string str() const { return "lcollapse<" + std::to_string(DIM) + ">(" + op_.str() + ")"; } __MATX_INLINE__ LCollapseOp(const T1 op) : op_(op) @@ -132,6 +133,12 @@ namespace matx return op_.Size(DIM + dim - 1); } + ~LCollapseOp() = default; + LCollapseOp(const LCollapseOp &rhs) = default; + __MATX_INLINE__ auto operator=(const self_type &rhs) { + return set(*this, rhs); + } + template __MATX_INLINE__ auto operator=(const R &rhs) { if constexpr (is_matx_transform_op()) { @@ -199,6 +206,7 @@ namespace matx using scalar_type = typename T1::scalar_type; using shape_type = index_t; using matxlvalue = bool; + using self_type = RCollapseOp; __MATX_INLINE__ std::string str() const { return "rcollapse<" + std::to_string(DIM) + ">(" + op_.str() + ")"; } @@ -281,6 +289,12 @@ namespace matx return op_.Size(dim); } + ~RCollapseOp() = default; + RCollapseOp(const RCollapseOp &rhs) = default; + __MATX_INLINE__ auto operator=(const self_type &rhs) { + return set(*this, rhs); + } + template __MATX_INLINE__ auto operator=(const R &rhs) { if constexpr (is_matx_transform_op()) { diff --git a/include/matx/operators/concat.h b/include/matx/operators/concat.h index c414baf34..76e442d3f 100644 --- a/include/matx/operators/concat.h +++ b/include/matx/operators/concat.h @@ -51,6 +51,7 @@ namespace matx { using first_type = cuda::std::tuple_element_t<0, cuda::std::tuple>; using first_value_type = typename first_type::scalar_type; + using self_type = ConcatOp; static constexpr int RANK = first_type::Rank(); @@ -165,6 +166,12 @@ namespace matx return cuda::std::get<0>(ops_).Size(dim); } + ~ConcatOp() = default; + ConcatOp(const ConcatOp &rhs) = default; + __MATX_INLINE__ auto operator=(const self_type &rhs) { + return set(*this, rhs); + } + template __MATX_INLINE__ auto operator=(const R &rhs) { if constexpr (is_matx_transform_op()) { diff --git a/include/matx/operators/overlap.h b/include/matx/operators/overlap.h index 788835afd..41f718a58 100644 --- a/include/matx/operators/overlap.h +++ b/include/matx/operators/overlap.h @@ -48,6 +48,7 @@ namespace matx public: using scalar_type = typename T::scalar_type; using shape_type = index_t; + using self_type = OverlapOp; private: typename base_type::type op_; @@ -118,7 +119,13 @@ namespace matx if constexpr (is_matx_op()) { op_.PostRun(std::forward(shape), std::forward(ex)); } - } + } + + ~OverlapOp() = default; + OverlapOp(const OverlapOp &rhs) = default; + __MATX_INLINE__ auto operator=(const self_type &rhs) { + return set(*this, rhs); + } template __MATX_INLINE__ auto operator=(const R &rhs) { diff --git a/include/matx/operators/permute.h b/include/matx/operators/permute.h index 34cf57407..cfde597a4 100644 --- a/include/matx/operators/permute.h +++ b/include/matx/operators/permute.h @@ -47,6 +47,7 @@ namespace matx { public: using scalar_type = typename T::scalar_type; + using self_type = PermuteOp; private: typename base_type::type op_; @@ -133,7 +134,13 @@ namespace matx if constexpr (is_matx_op()) { op_.PostRun(std::forward(shape), std::forward(ex)); } - } + } + + ~PermuteOp() = default; + PermuteOp(const PermuteOp &rhs) = default; + __MATX_INLINE__ auto operator=(const self_type &rhs) { + return set(*this, rhs); + } template __MATX_INLINE__ auto operator=(const R &rhs) { diff --git a/include/matx/operators/remap.h b/include/matx/operators/remap.h index 3d1567978..ba3cda1a6 100644 --- a/include/matx/operators/remap.h +++ b/include/matx/operators/remap.h @@ -57,6 +57,7 @@ namespace matx using scalar_type = typename T::scalar_type; using shape_type = std::conditional_t, typename T::shape_type, index_t>; using index_type = typename IdxType::scalar_type; + using self_type = RemapOp; static_assert(std::is_integral::value, "RemapOp: Type for index operator must be integral"); static_assert(IdxType::Rank() <= 1, "RemapOp: Rank of index operator must be 0 or 1"); static_assert(DIM()) { op_.PostRun(std::forward(shape), std::forward(ex)); } - } + } + + ~RemapOp() = default; + RemapOp(const RemapOp &rhs) = default; + __MATX_INLINE__ auto operator=(const self_type &rhs) { + return set(*this, rhs); + } template __MATX_INLINE__ auto operator=(const R &rhs) { diff --git a/include/matx/operators/reshape.h b/include/matx/operators/reshape.h index 035b4caef..de3a06302 100644 --- a/include/matx/operators/reshape.h +++ b/include/matx/operators/reshape.h @@ -56,6 +56,7 @@ namespace matx public: using matxop = bool; using matxoplvalue = bool; + using self_type = ReshapeOp; __MATX_INLINE__ std::string str() const { return "reshape(" + op_.str() + ")"; } @@ -149,7 +150,13 @@ namespace matx if constexpr (is_matx_op()) { op_.PostRun(std::forward(shape), std::forward(ex)); } - } + } + + ~ReshapeOp() = default; + ReshapeOp(const ReshapeOp &rhs) = default; + __MATX_INLINE__ auto operator=(const self_type &rhs) { + return set(*this, rhs); + } template __MATX_INLINE__ auto operator=(const R &rhs) { diff --git a/include/matx/operators/reverse.h b/include/matx/operators/reverse.h index d72b03d0d..39701d8fa 100644 --- a/include/matx/operators/reverse.h +++ b/include/matx/operators/reverse.h @@ -57,6 +57,7 @@ namespace matx using matxop = bool; using matxoplvalue = bool; using scalar_type = typename T1::scalar_type; + using self_type = ReverseOp; __MATX_INLINE__ std::string str() const { return "reverse(" + op_.str() + ")"; } @@ -112,7 +113,13 @@ namespace matx if constexpr (is_matx_op()) { op_.PostRun(std::forward(shape), std::forward(ex)); } - } + } + + ~ReverseOp() = default; + ReverseOp(const ReverseOp &rhs) = default; + __MATX_INLINE__ auto operator=(const self_type &rhs) { + return set(*this, rhs); + } template __MATX_INLINE__ auto operator=(const R &rhs) { diff --git a/include/matx/operators/set.h b/include/matx/operators/set.h index cb1b1dc2c..8bdc8e411 100644 --- a/include/matx/operators/set.h +++ b/include/matx/operators/set.h @@ -68,7 +68,6 @@ class set : public BaseOp> { public: // Type specifier for reflection on class using scalar_type = typename T::scalar_type; - using shape_type = std::conditional_t, typename T::shape_type, index_t>; using tensor_type = T; using op_type = Op; using matx_setop = bool; @@ -136,7 +135,9 @@ class set : public BaseOp> { return r; } } - __MATX_DEVICE__ __MATX_HOST__ inline decltype(auto) operator()(cuda::std::array idx) const noexcept + + template + __MATX_DEVICE__ __MATX_HOST__ inline decltype(auto) operator()(cuda::std::array idx) const noexcept { auto res = cuda::std::apply([&](auto &&...args) { return _internal_mapply(args...); diff --git a/include/matx/operators/shift.h b/include/matx/operators/shift.h index c347876ea..e8a02663f 100644 --- a/include/matx/operators/shift.h +++ b/include/matx/operators/shift.h @@ -61,6 +61,7 @@ namespace matx using matxop = bool; using matxoplvalue = bool; using scalar_type = typename T1::scalar_type; + using self_type = ShiftOp; __MATX_INLINE__ std::string str() const { return "shift(" + op_.str() + ")"; } @@ -131,6 +132,12 @@ namespace matx return detail::matx_max(size1,size2); } + ~ShiftOp() = default; + ShiftOp(const ShiftOp &rhs) = default; + __MATX_INLINE__ auto operator=(const self_type &rhs) { + return set(*this, rhs); + } + template __MATX_INLINE__ auto operator=(const R &rhs) { if constexpr (is_matx_transform_op()) { diff --git a/include/matx/operators/slice.h b/include/matx/operators/slice.h index 28c9334b6..e8a518614 100644 --- a/include/matx/operators/slice.h +++ b/include/matx/operators/slice.h @@ -48,6 +48,7 @@ namespace matx public: using scalar_type = typename T::scalar_type; using shape_type = index_t; + using self_type = SliceOp; private: typename base_type::type op_; @@ -158,6 +159,12 @@ namespace matx return sizes_[dim]; } + ~SliceOp() = default; + SliceOp(const SliceOp &rhs) = default; + __MATX_INLINE__ auto operator=(const self_type &rhs) { + return set(*this, rhs); + } + template __MATX_INLINE__ auto operator=(const R &rhs) { if constexpr (is_matx_transform_op()) { diff --git a/include/matx/operators/stack.h b/include/matx/operators/stack.h index d7ab9258c..c2de85245 100644 --- a/include/matx/operators/stack.h +++ b/include/matx/operators/stack.h @@ -50,6 +50,7 @@ namespace matx { using first_type = cuda::std::tuple_element_t<0, cuda::std::tuple>; using first_value_type = typename first_type::scalar_type; + using self_type = StackOp; static constexpr int RANK = first_type::Rank(); @@ -181,6 +182,12 @@ namespace matx } } + ~StackOp() = default; + StackOp(const StackOp &rhs) = default; + __MATX_INLINE__ auto operator=(const self_type &rhs) { + return set(*this, rhs); + } + template __MATX_INLINE__ auto operator=(const R &rhs) { if constexpr (is_matx_transform_op()) { diff --git a/include/matx/operators/transpose.h b/include/matx/operators/transpose.h index 7a56f840f..c8d7a61ab 100644 --- a/include/matx/operators/transpose.h +++ b/include/matx/operators/transpose.h @@ -57,6 +57,7 @@ namespace detail { using matx_transform_op = bool; using matxoplvalue = bool; using transpose_xform_op = bool; + using self_type = TransposeMatrixOp; __MATX_INLINE__ std::string str() const { return "transpose_matrix(" + get_type_str(a_) + ")"; } __MATX_INLINE__ TransposeMatrixOp(OpA a) : a_(a) { @@ -121,6 +122,12 @@ namespace detail { return out_dims_[dim]; } + ~TransposeMatrixOp() = default; + TransposeMatrixOp(const TransposeMatrixOp &rhs) = default; + __MATX_INLINE__ auto operator=(const self_type &rhs) { + return set(*this, rhs); + } + template __MATX_INLINE__ auto operator=(const R &rhs) { if constexpr (is_matx_transform_op()) { diff --git a/include/matx/operators/updownsample.h b/include/matx/operators/updownsample.h index bda9513f2..d211c60ab 100644 --- a/include/matx/operators/updownsample.h +++ b/include/matx/operators/updownsample.h @@ -54,6 +54,7 @@ namespace matx using matxop = bool; using matxoplvalue = bool; using scalar_type = typename T::scalar_type; + using self_type = UpsampleOp; static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() { @@ -109,6 +110,12 @@ namespace matx } } + ~UpsampleOp() = default; + UpsampleOp(const UpsampleOp &rhs) = default; + __MATX_INLINE__ auto operator=(const self_type &rhs) { + return set(*this, rhs); + } + template __MATX_INLINE__ auto operator=(const R &rhs) { return set(*this, rhs); } }; } diff --git a/test/00_operators/OperatorTests.cu b/test/00_operators/OperatorTests.cu index 1d5d6b5fc..4b7e3b4f9 100644 --- a/test/00_operators/OperatorTests.cu +++ b/test/00_operators/OperatorTests.cu @@ -1589,6 +1589,44 @@ TYPED_TEST(OperatorTestsNumericAllExecs, RemapOp) } } + { + // Remap as both LHS and RHS + auto in = make_tensor({4,4,4}); + auto out = make_tensor({4,4,4}); + TestType c = GenerateData(); + for (int i = 0; i < in.Size(0); i++){ + for (int j = 0; j < in.Size(1); j++){ + for (int k = 0; k < in.Size(2); k++){ + in(i,j,k) = c; + } + } + } + + auto map1 = matx::make_tensor({2}); + auto map2 = matx::make_tensor({2}); + map1(0) = 1; + map1(1) = 2; + map2(0) = 0; + map2(1) = 1; + + (out=0).run(); + (matx::remap<2>(out, map2) = matx::remap<2>(in, map1)).run(); + exec.sync(); + + for (int i = 0; i < in.Size(0); i++){ + for (int j = 0; j < in.Size(1); j++){ + for (int k = 0; k < in.Size(2); k++){ + if (k > 1) { + EXPECT_TRUE(out(i,j,k) == (TestType)0); + } + else { + EXPECT_TRUE(out(i,j,k) == in(i,j,k)); + } + } + } + } + } + MATX_EXIT_HANDLER(); }