Fixed issue with operator types used as both lvalue/rvalue not assigning
cliffburdick committed Jun 25, 2024
1 parent ed09e1c commit 9bcfeab
Showing 14 changed files with 137 additions and 7 deletions.
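
For context, the failure mode this commit addresses is an operator expression appearing on both the left- and right-hand side of an assignment. The sketch below is a standalone adaptation of the RemapOp test added at the bottom of this commit; the float element type, the matx.h include, the tensor shapes, and the cudaExecutor setup are assumptions for illustration. With the explicit operator=(const self_type &) overloads added in each operator class, assigning one remap view to another builds a set() expression rather than hitting the compiler-generated member-wise copy, which is what the commit title refers to as "not assigning".

// Standalone sketch of the lvalue/rvalue usage pattern exercised by the new test.
// Assumed boilerplate: matx.h, float element type, cudaExecutor, 4x4x4 shapes.
#include "matx.h"

int main() {
  matx::cudaExecutor exec{};

  auto in  = matx::make_tensor<float>({4, 4, 4});
  auto out = matx::make_tensor<float>({4, 4, 4});

  auto map1 = matx::make_tensor<int>({2});
  auto map2 = matx::make_tensor<int>({2});
  map1(0) = 1; map1(1) = 2;   // source indices read from `in`
  map2(0) = 0; map2(1) = 1;   // destination indices written in `out`

  (in = 1.0f).run(exec);
  (out = 0.0f).run(exec);

  // remap is used as both lvalue and rvalue; the operator=(const self_type &)
  // overloads added in this commit route this through set(*this, rhs) so the
  // copy is actually executed.
  (matx::remap<2>(out, map2) = matx::remap<2>(in, map1)).run(exec);
  exec.sync();

  return 0;
}
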
14 changes: 14 additions & 0 deletions include/matx/operators/collapse.h
@@ -51,6 +51,7 @@ namespace matx
using scalar_type = typename T1::scalar_type;
using shape_type = index_t;
using matxoplvalue = bool;
using self_type = LCollapseOp<DIM, T1>;

__MATX_INLINE__ std::string str() const { return "lcollapse<" + std::to_string(DIM) + ">(" + op_.str() + ")"; }
__MATX_INLINE__ LCollapseOp(const T1 op) : op_(op)
@@ -132,6 +133,12 @@ namespace matx
return op_.Size(DIM + dim - 1);
}

~LCollapseOp() = default;
LCollapseOp(const LCollapseOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
if constexpr (is_matx_transform_op<R>()) {
@@ -199,6 +206,7 @@ namespace matx
using scalar_type = typename T1::scalar_type;
using shape_type = index_t;
using matxlvalue = bool;
using self_type = RCollapseOp<DIM, T1>;

__MATX_INLINE__ std::string str() const { return "rcollapse<" + std::to_string(DIM) + ">(" + op_.str() + ")"; }

@@ -281,6 +289,12 @@ namespace matx
return op_.Size(dim);
}

~RCollapseOp() = default;
RCollapseOp(const RCollapseOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
if constexpr (is_matx_transform_op<R>()) {
7 changes: 7 additions & 0 deletions include/matx/operators/concat.h
@@ -51,6 +51,7 @@ namespace matx
{
using first_type = cuda::std::tuple_element_t<0, cuda::std::tuple<Ts...>>;
using first_value_type = typename first_type::scalar_type;
using self_type = ConcatOp<Ts...>;

static constexpr int RANK = first_type::Rank();

@@ -165,6 +166,12 @@ namespace matx
return cuda::std::get<0>(ops_).Size(dim);
}

~ConcatOp() = default;
ConcatOp(const ConcatOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
if constexpr (is_matx_transform_op<R>()) {
9 changes: 8 additions & 1 deletion include/matx/operators/overlap.h
@@ -48,6 +48,7 @@ namespace matx
public:
using scalar_type = typename T::scalar_type;
using shape_type = index_t;
using self_type = OverlapOp<DIM, T>;

private:
typename base_type<T>::type op_;
@@ -118,7 +119,13 @@ namespace matx
if constexpr (is_matx_op<T>()) {
op_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}
}

~OverlapOp() = default;
OverlapOp(const OverlapOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
9 changes: 8 additions & 1 deletion include/matx/operators/permute.h
@@ -47,6 +47,7 @@ namespace matx
{
public:
using scalar_type = typename T::scalar_type;
using self_type = PermuteOp<T>;

private:
typename base_type<T>::type op_;
@@ -133,7 +134,13 @@ namespace matx
if constexpr (is_matx_op<T>()) {
op_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}
}

~PermuteOp() = default;
PermuteOp(const PermuteOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
9 changes: 8 additions & 1 deletion include/matx/operators/remap.h
@@ -57,6 +57,7 @@ namespace matx
using scalar_type = typename T::scalar_type;
using shape_type = std::conditional_t<has_shape_type_v<T>, typename T::shape_type, index_t>;
using index_type = typename IdxType::scalar_type;
using self_type = RemapOp<DIM, T, IdxType>;
static_assert(std::is_integral<index_type>::value, "RemapOp: Type for index operator must be integral");
static_assert(IdxType::Rank() <= 1, "RemapOp: Rank of index operator must be 0 or 1");
static_assert(DIM<T::Rank(), "RemapOp: DIM must be less than Rank of tensor");
@@ -134,7 +135,13 @@ namespace matx
if constexpr (is_matx_op<T>()) {
op_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}
}

~RemapOp() = default;
RemapOp(const RemapOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
9 changes: 8 additions & 1 deletion include/matx/operators/reshape.h
@@ -56,6 +56,7 @@ namespace matx
public:
using matxop = bool;
using matxoplvalue = bool;
using self_type = ReshapeOp<RANK, T, ShapeType>;

__MATX_INLINE__ std::string str() const { return "reshape(" + op_.str() + ")"; }

@@ -149,7 +150,13 @@ namespace matx
if constexpr (is_matx_op<T>()) {
op_.PostRun(std::forward<S2>(shape), std::forward<Executor>(ex));
}
}
}

~ReshapeOp() = default;
ReshapeOp(const ReshapeOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
9 changes: 8 additions & 1 deletion include/matx/operators/reverse.h
@@ -57,6 +57,7 @@ namespace matx
using matxop = bool;
using matxoplvalue = bool;
using scalar_type = typename T1::scalar_type;
using self_type = ReverseOp<DIM, T1>;

__MATX_INLINE__ std::string str() const { return "reverse(" + op_.str() + ")"; }

@@ -112,7 +113,13 @@ namespace matx
if constexpr (is_matx_op<T1>()) {
op_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}
}

~ReverseOp() = default;
ReverseOp(const ReverseOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
5 changes: 3 additions & 2 deletions include/matx/operators/set.h
@@ -68,7 +68,6 @@ class set : public BaseOp<set<T, Op>> {
public:
// Type specifier for reflection on class
using scalar_type = typename T::scalar_type;
using shape_type = std::conditional_t<has_shape_type_v<T>, typename T::shape_type, index_t>;
using tensor_type = T;
using op_type = Op;
using matx_setop = bool;
@@ -136,7 +135,9 @@ class set : public BaseOp<set<T, Op>> {
return r;
}
}
__MATX_DEVICE__ __MATX_HOST__ inline decltype(auto) operator()(cuda::std::array<shape_type, T::Rank()> idx) const noexcept

template <typename ShapeType>
__MATX_DEVICE__ __MATX_HOST__ inline decltype(auto) operator()(cuda::std::array<ShapeType, T::Rank()> idx) const noexcept
{
auto res = cuda::std::apply([&](auto &&...args) {
return _internal_mapply(args...);
7 changes: 7 additions & 0 deletions include/matx/operators/shift.h
@@ -61,6 +61,7 @@ namespace matx
using matxop = bool;
using matxoplvalue = bool;
using scalar_type = typename T1::scalar_type;
using self_type = ShiftOp<DIM, T1, T2>;

__MATX_INLINE__ std::string str() const { return "shift(" + op_.str() + ")"; }

@@ -131,6 +132,12 @@ namespace matx
return detail::matx_max(size1,size2);
}

~ShiftOp() = default;
ShiftOp(const ShiftOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
if constexpr (is_matx_transform_op<R>()) {
7 changes: 7 additions & 0 deletions include/matx/operators/slice.h
@@ -48,6 +48,7 @@ namespace matx
public:
using scalar_type = typename T::scalar_type;
using shape_type = index_t;
using self_type = SliceOp<DIM, T>;

private:
typename base_type<T>::type op_;
@@ -158,6 +159,12 @@ namespace matx
return sizes_[dim];
}

~SliceOp() = default;
SliceOp(const SliceOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
if constexpr (is_matx_transform_op<R>()) {
7 changes: 7 additions & 0 deletions include/matx/operators/stack.h
@@ -50,6 +50,7 @@ namespace matx
{
using first_type = cuda::std::tuple_element_t<0, cuda::std::tuple<Ts...>>;
using first_value_type = typename first_type::scalar_type;
using self_type = StackOp<Ts...>;

static constexpr int RANK = first_type::Rank();

@@ -181,6 +182,12 @@ namespace matx
}
}

~StackOp() = default;
StackOp(const StackOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
if constexpr (is_matx_transform_op<R>()) {
7 changes: 7 additions & 0 deletions include/matx/operators/transpose.h
@@ -57,6 +57,7 @@ namespace detail {
using matx_transform_op = bool;
using matxoplvalue = bool;
using transpose_xform_op = bool;
using self_type = TransposeMatrixOp<OpA>;

__MATX_INLINE__ std::string str() const { return "transpose_matrix(" + get_type_str(a_) + ")"; }
__MATX_INLINE__ TransposeMatrixOp(OpA a) : a_(a) {
@@ -121,6 +122,12 @@ namespace detail {
return out_dims_[dim];
}

~TransposeMatrixOp() = default;
TransposeMatrixOp(const TransposeMatrixOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
if constexpr (is_matx_transform_op<R>()) {
7 changes: 7 additions & 0 deletions include/matx/operators/updownsample.h
@@ -54,6 +54,7 @@ namespace matx
using matxop = bool;
using matxoplvalue = bool;
using scalar_type = typename T::scalar_type;
using self_type = UpsampleOp<T>;

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
{
@@ -109,6 +110,12 @@ namespace matx
}
}

~UpsampleOp() = default;
UpsampleOp(const UpsampleOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R> __MATX_INLINE__ auto operator=(const R &rhs) { return set(*this, rhs); }
};
}
38 changes: 38 additions & 0 deletions test/00_operators/OperatorTests.cu
@@ -1589,6 +1589,44 @@ TYPED_TEST(OperatorTestsNumericAllExecs, RemapOp)
}
}

{
// Remap as both LHS and RHS
auto in = make_tensor<TestType>({4,4,4});
auto out = make_tensor<TestType>({4,4,4});
TestType c = GenerateData<TestType>();
for (int i = 0; i < in.Size(0); i++){
for (int j = 0; j < in.Size(1); j++){
for (int k = 0; k < in.Size(2); k++){
in(i,j,k) = c;
}
}
}

auto map1 = matx::make_tensor<int>({2});
auto map2 = matx::make_tensor<int>({2});
map1(0) = 1;
map1(1) = 2;
map2(0) = 0;
map2(1) = 1;

(out=0).run();
(matx::remap<2>(out, map2) = matx::remap<2>(in, map1)).run();
exec.sync();

for (int i = 0; i < in.Size(0); i++){
for (int j = 0; j < in.Size(1); j++){
for (int k = 0; k < in.Size(2); k++){
if (k > 1) {
EXPECT_TRUE(out(i,j,k) == (TestType)0);
}
else {
EXPECT_TRUE(out(i,j,k) == in(i,j,k));
}
}
}
}
}

MATX_EXIT_HANDLER();
}
