Skip to content

Commit

Permalink
clean up based on pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
artv3 committed Dec 26, 2024
1 parent a1a254d commit a691f40
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 36 deletions.
2 changes: 1 addition & 1 deletion examples/reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ void checkResult(int *ptr, int K, int N, int M)
for(int n = 0; n < N; ++n) {
for(int m = 0; m < M; ++m) {
const int idx = m + M * (n + N * k);
if (std::abs(ptr[idx] - idx) > 0) {
if (ptr[idx] != idx) {
status = false;
}
}
Expand Down
75 changes: 40 additions & 35 deletions include/RAJA/util/View.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@
#ifndef RAJA_VIEW_HPP
#define RAJA_VIEW_HPP

#include <type_traits>
#include <array>
#include <type_traits>

#include "RAJA/config.hpp"

#include "RAJA/pattern/atomic.hpp"

#include "RAJA/util/IndexLayout.hpp"
#include "RAJA/util/Layout.hpp"
#include "RAJA/util/OffsetLayout.hpp"
Expand Down Expand Up @@ -297,81 +295,88 @@ RAJA_INLINE AtomicViewWrapper<ViewType, AtomicPolicy> make_atomic_view(
return RAJA::AtomicViewWrapper<ViewType, AtomicPolicy>(view);
}

struct layout_left{};
struct layout_right{};

template<typename LAYOUT>
struct Reshape;
struct layout_left {
};
struct layout_right {
};

template<typename LAYOUT>
template <typename LAYOUT>
struct Reshape;

template<typename T>
namespace detail
{
template <typename T>
constexpr auto get_last_index(T last)
{
return last;
}

template<typename T, typename... Args>
constexpr auto get_last_index(T , Args... args)
template <typename T, typename... Args>
constexpr auto get_last_index(T, Args... args)
{
return get_last_index(args...);
}
} // namespace detail

template<std::size_t...Is>
struct Reshape<std::index_sequence<Is...>>
{
template<typename T, typename...Ts>
template <std::size_t... Is>
struct Reshape<std::index_sequence<Is...>> {
template <typename T, typename... Ts>
static auto get(T *ptr, Ts... s)
{
constexpr int N = sizeof...(Ts);
std::array<RAJA::idx_t, N> extent{s...};

auto custom_layout =
RAJA::make_permuted_layout(extent, std::array<RAJA::idx_t, N>{Is...});
RAJA::make_permuted_layout(extent, std::array<RAJA::idx_t, N>{Is...});

constexpr auto unit_stride = get_last_index(Is...);
constexpr auto unit_stride = detail::get_last_index(Is...);
using view_t = RAJA::View<T, RAJA::Layout<N, RAJA::Index_type, unit_stride>>;

return RAJA::View<T, RAJA::Layout<N, RAJA::Index_type, unit_stride>>
(ptr, custom_layout);
return view_t(ptr, custom_layout);
}
};

template<>
struct Reshape<layout_right>
{
template<typename T, typename...Ts>
template <>
struct Reshape<layout_right> {
template <typename T, typename... Ts>
static auto get(T *ptr, Ts... s)
{
constexpr int N = sizeof...(Ts);
using view_t = RAJA::View<T, RAJA::Layout<N, RAJA::Index_type, N-1>>;
using view_t = RAJA::View<T, RAJA::Layout<N, RAJA::Index_type, N - 1>>;

return view_t(ptr, s...);
}
};

template<std::size_t... Is>
constexpr std::array<RAJA::idx_t, sizeof...(Is)> make_reverse_array(std::index_sequence<Is...>) {
return std::array<RAJA::idx_t, sizeof...(Is)>{sizeof...(Is) - 1U - Is ...};
}
namespace detail
{

template<>
struct Reshape<layout_left>
template <std::size_t... Is>
constexpr std::array<RAJA::idx_t, sizeof...(Is)> make_reverse_array(
std::index_sequence<Is...>)
{
template<typename T, typename...Ts>
return std::array<RAJA::idx_t, sizeof...(Is)>{sizeof...(Is) - 1U - Is...};
}

} // namespace detail

template <>
struct Reshape<layout_left> {
template <typename T, typename... Ts>
static auto get(T *ptr, Ts... s)
{
constexpr int N = sizeof...(Ts);

std::array<RAJA::idx_t, N> extent{s...};

constexpr auto reverse_array = make_reverse_array(std::make_index_sequence<N>{});
constexpr auto reverse_array =
detail::make_reverse_array(std::make_index_sequence<N>{});

auto reverse_layout = RAJA::make_permuted_layout(extent, reverse_array);
using view_t = RAJA::View<T, RAJA::Layout<N, RAJA::Index_type, 0U>>;

return RAJA::View<T, RAJA::Layout<N, RAJA::Index_type, 0U>>(ptr, reverse_layout);
return view_t(ptr, reverse_layout);
}

};

} // namespace RAJA
Expand Down

0 comments on commit a691f40

Please sign in to comment.