Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed issue with operator types used as both lvalue/rvalue not assigning #655

Merged
merged 1 commit into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions include/matx/operators/collapse.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ namespace matx
using scalar_type = typename T1::scalar_type;
using shape_type = index_t;
using matxoplvalue = bool;
using self_type = LCollapseOp<DIM, T1>;

__MATX_INLINE__ std::string str() const { return "lcollapse<" + std::to_string(DIM) + ">(" + op_.str() + ")"; }
__MATX_INLINE__ LCollapseOp(const T1 op) : op_(op)
Expand Down Expand Up @@ -132,6 +133,12 @@ namespace matx
return op_.Size(DIM + dim - 1);
}

~LCollapseOp() = default;
LCollapseOp(const LCollapseOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
if constexpr (is_matx_transform_op<R>()) {
Expand Down Expand Up @@ -199,6 +206,7 @@ namespace matx
using scalar_type = typename T1::scalar_type;
using shape_type = index_t;
using matxlvalue = bool;
using self_type = RCollapseOp<DIM, T1>;

__MATX_INLINE__ std::string str() const { return "rcollapse<" + std::to_string(DIM) + ">(" + op_.str() + ")"; }

Expand Down Expand Up @@ -281,6 +289,12 @@ namespace matx
return op_.Size(dim);
}

~RCollapseOp() = default;
RCollapseOp(const RCollapseOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
if constexpr (is_matx_transform_op<R>()) {
Expand Down
7 changes: 7 additions & 0 deletions include/matx/operators/concat.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ namespace matx
{
using first_type = cuda::std::tuple_element_t<0, cuda::std::tuple<Ts...>>;
using first_value_type = typename first_type::scalar_type;
using self_type = ConcatOp<Ts...>;

static constexpr int RANK = first_type::Rank();

Expand Down Expand Up @@ -165,6 +166,12 @@ namespace matx
return cuda::std::get<0>(ops_).Size(dim);
}

~ConcatOp() = default;
ConcatOp(const ConcatOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
if constexpr (is_matx_transform_op<R>()) {
Expand Down
9 changes: 8 additions & 1 deletion include/matx/operators/overlap.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ namespace matx
public:
using scalar_type = typename T::scalar_type;
using shape_type = index_t;
using self_type = OverlapOp<DIM, T>;

private:
typename base_type<T>::type op_;
Expand Down Expand Up @@ -118,7 +119,13 @@ namespace matx
if constexpr (is_matx_op<T>()) {
op_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}
}

~OverlapOp() = default;
OverlapOp(const OverlapOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
Expand Down
9 changes: 8 additions & 1 deletion include/matx/operators/permute.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ namespace matx
{
public:
using scalar_type = typename T::scalar_type;
using self_type = PermuteOp<T>;

private:
typename base_type<T>::type op_;
Expand Down Expand Up @@ -133,7 +134,13 @@ namespace matx
if constexpr (is_matx_op<T>()) {
op_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}
}

~PermuteOp() = default;
PermuteOp(const PermuteOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
Expand Down
9 changes: 8 additions & 1 deletion include/matx/operators/remap.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ namespace matx
using scalar_type = typename T::scalar_type;
using shape_type = std::conditional_t<has_shape_type_v<T>, typename T::shape_type, index_t>;
using index_type = typename IdxType::scalar_type;
using self_type = RemapOp<DIM, T, IdxType>;
static_assert(std::is_integral<index_type>::value, "RemapOp: Type for index operator must be integral");
static_assert(IdxType::Rank() <= 1, "RemapOp: Rank of index operator must be 0 or 1");
static_assert(DIM<T::Rank(), "RemapOp: DIM must be less than Rank of tensor");
Expand Down Expand Up @@ -134,7 +135,13 @@ namespace matx
if constexpr (is_matx_op<T>()) {
op_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}
}

~RemapOp() = default;
RemapOp(const RemapOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
Expand Down
9 changes: 8 additions & 1 deletion include/matx/operators/reshape.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ namespace matx
public:
using matxop = bool;
using matxoplvalue = bool;
using self_type = ReshapeOp<RANK, T, ShapeType>;

__MATX_INLINE__ std::string str() const { return "reshape(" + op_.str() + ")"; }

Expand Down Expand Up @@ -149,7 +150,13 @@ namespace matx
if constexpr (is_matx_op<T>()) {
op_.PostRun(std::forward<S2>(shape), std::forward<Executor>(ex));
}
}
}

~ReshapeOp() = default;
ReshapeOp(const ReshapeOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
Expand Down
9 changes: 8 additions & 1 deletion include/matx/operators/reverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ namespace matx
using matxop = bool;
using matxoplvalue = bool;
using scalar_type = typename T1::scalar_type;
using self_type = ReverseOp<DIM, T1>;

__MATX_INLINE__ std::string str() const { return "reverse(" + op_.str() + ")"; }

Expand Down Expand Up @@ -112,7 +113,13 @@ namespace matx
if constexpr (is_matx_op<T1>()) {
op_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}
}

~ReverseOp() = default;
ReverseOp(const ReverseOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
Expand Down
5 changes: 3 additions & 2 deletions include/matx/operators/set.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ class set : public BaseOp<set<T, Op>> {
public:
// Type specifier for reflection on class
using scalar_type = typename T::scalar_type;
using shape_type = std::conditional_t<has_shape_type_v<T>, typename T::shape_type, index_t>;
using tensor_type = T;
using op_type = Op;
using matx_setop = bool;
Expand Down Expand Up @@ -136,7 +135,9 @@ class set : public BaseOp<set<T, Op>> {
return r;
}
}
__MATX_DEVICE__ __MATX_HOST__ inline decltype(auto) operator()(cuda::std::array<shape_type, T::Rank()> idx) const noexcept

template <typename ShapeType>
__MATX_DEVICE__ __MATX_HOST__ inline decltype(auto) operator()(cuda::std::array<ShapeType, T::Rank()> idx) const noexcept
{
auto res = cuda::std::apply([&](auto &&...args) {
return _internal_mapply(args...);
Expand Down
7 changes: 7 additions & 0 deletions include/matx/operators/shift.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ namespace matx
using matxop = bool;
using matxoplvalue = bool;
using scalar_type = typename T1::scalar_type;
using self_type = ShiftOp<DIM, T1, T2>;

__MATX_INLINE__ std::string str() const { return "shift(" + op_.str() + ")"; }

Expand Down Expand Up @@ -131,6 +132,12 @@ namespace matx
return detail::matx_max(size1,size2);
}

~ShiftOp() = default;
ShiftOp(const ShiftOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
if constexpr (is_matx_transform_op<R>()) {
Expand Down
7 changes: 7 additions & 0 deletions include/matx/operators/slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ namespace matx
public:
using scalar_type = typename T::scalar_type;
using shape_type = index_t;
using self_type = SliceOp<DIM, T>;

private:
typename base_type<T>::type op_;
Expand Down Expand Up @@ -158,6 +159,12 @@ namespace matx
return sizes_[dim];
}

~SliceOp() = default;
SliceOp(const SliceOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
if constexpr (is_matx_transform_op<R>()) {
Expand Down
7 changes: 7 additions & 0 deletions include/matx/operators/stack.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ namespace matx
{
using first_type = cuda::std::tuple_element_t<0, cuda::std::tuple<Ts...>>;
using first_value_type = typename first_type::scalar_type;
using self_type = StackOp<Ts...>;

static constexpr int RANK = first_type::Rank();

Expand Down Expand Up @@ -181,6 +182,12 @@ namespace matx
}
}

~StackOp() = default;
StackOp(const StackOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
if constexpr (is_matx_transform_op<R>()) {
Expand Down
7 changes: 7 additions & 0 deletions include/matx/operators/transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ namespace detail {
using matx_transform_op = bool;
using matxoplvalue = bool;
using transpose_xform_op = bool;
using self_type = TransposeMatrixOp<OpA>;

__MATX_INLINE__ std::string str() const { return "transpose_matrix(" + get_type_str(a_) + ")"; }
__MATX_INLINE__ TransposeMatrixOp(OpA a) : a_(a) {
Expand Down Expand Up @@ -121,6 +122,12 @@ namespace detail {
return out_dims_[dim];
}

~TransposeMatrixOp() = default;
TransposeMatrixOp(const TransposeMatrixOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R>
__MATX_INLINE__ auto operator=(const R &rhs) {
if constexpr (is_matx_transform_op<R>()) {
Expand Down
7 changes: 7 additions & 0 deletions include/matx/operators/updownsample.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ namespace matx
using matxop = bool;
using matxoplvalue = bool;
using scalar_type = typename T::scalar_type;
using self_type = UpsampleOp<T>;

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
{
Expand Down Expand Up @@ -109,6 +110,12 @@ namespace matx
}
}

~UpsampleOp() = default;
UpsampleOp(const UpsampleOp &rhs) = default;
__MATX_INLINE__ auto operator=(const self_type &rhs) {
return set(*this, rhs);
}

template<typename R> __MATX_INLINE__ auto operator=(const R &rhs) { return set(*this, rhs); }
};
}
Expand Down
38 changes: 38 additions & 0 deletions test/00_operators/OperatorTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1589,6 +1589,44 @@ TYPED_TEST(OperatorTestsNumericAllExecs, RemapOp)
}
}

{
// Remap as both LHS and RHS
auto in = make_tensor<TestType>({4,4,4});
auto out = make_tensor<TestType>({4,4,4});
TestType c = GenerateData<TestType>();
for (int i = 0; i < in.Size(0); i++){
for (int j = 0; j < in.Size(1); j++){
for (int k = 0; k < in.Size(2); k++){
in(i,j,k) = c;
}
}
}

auto map1 = matx::make_tensor<int>({2});
auto map2 = matx::make_tensor<int>({2});
map1(0) = 1;
map1(1) = 2;
map2(0) = 0;
map2(1) = 1;

(out = static_cast<TestType>(0)).run(exec);
(matx::remap<2>(out, map2) = matx::remap<2>(in, map1)).run(exec);
exec.sync();

for (int i = 0; i < in.Size(0); i++){
for (int j = 0; j < in.Size(1); j++){
for (int k = 0; k < in.Size(2); k++){
if (k > 1) {
ASSERT_EQ(out(i,j,k), (TestType)0);
}
else {
ASSERT_EQ(out(i,j,k), in(i,j,k));
}
}
}
}
}

MATX_EXIT_HANDLER();
}

Expand Down