Merge pull request #24 from yasahi-hpc/optimization-executors
Optimization executors
yasahi-hpc authored Jul 2, 2023
2 parents 52843f4 + b0c9d75 commit 22537e3
Showing 8 changed files with 282 additions and 112 deletions.
199 changes: 135 additions & 64 deletions lib/cuda_linalg.hpp
@@ -220,23 +220,70 @@ namespace Impl {
return cusolverDnDsyevjBatched(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, batchSize);
}

/*
* Batched matrix matrix product
* Matrix shape
* A (n, m, l), B (m, k, l), C (n, k, l)
* */
struct blasHandle_t {
cublasHandle_t handle_;

public:
void create() {
cublasCreate(&handle_);
}

void destroy() {
cublasDestroy(handle_);
}
};
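
The struct above lets a caller create one cuBLAS handle and pass it to the batched routines below, instead of paying for cublasCreate/cublasDestroy inside every call. A minimal lifecycle sketch (not part of this commit, assuming the Impl namespace from this file):

Impl::blasHandle_t blas_handle;
blas_handle.create();    // wraps cublasCreate
// ... pass blas_handle to matrix_matrix_product, matrix_vector_product, transpose ...
blas_handle.destroy();   // wraps cublasDestroy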

template <class T>
struct syevjHandle_t {
cusolverDnHandle_t handle_;
thrust::device_vector<T> workspace_;
thrust::device_vector<int> info_;
syevjInfo_t params_;
cusolverEigMode_t jobz_ = CUSOLVER_EIG_MODE_VECTOR;
cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_LOWER;

public:
template <class MatrixView, class VectorView,
std::enable_if_t<MatrixView::rank()==3 && VectorView::rank()==2, std::nullptr_t> = nullptr>
void create(MatrixView& a, VectorView& v, T tol=1.0e-7, int max_sweeps=100, int sort_eig=0) {
cusolverDnCreate(&handle_);

cusolverDnCreateSyevjInfo(&params_);
cusolverDnXsyevjSetTolerance(params_, tol);
cusolverDnXsyevjSetMaxSweeps(params_, max_sweeps);
cusolverDnXsyevjSetSortEig(params_, sort_eig);

const int batchSize = v.extent(1);
int lwork = 0;
syevjBatched_bufferSize(
handle_,
jobz_,
uplo_,
a.extent(0), a.data_handle(),
v.extent(0), v.data_handle(),
&lwork,
params_,
batchSize
);
workspace_.resize(lwork);
info_.resize(batchSize, 0);
}

void destroy() {
cusolverDnDestroy(handle_);
}
};

template <class ViewA, class ViewB, class ViewC,
std::enable_if_t<ViewA::rank()==3 && ViewB::rank()==3 && ViewC::rank()==3, std::nullptr_t> = nullptr>
void matrix_matrix_product(const ViewA& A,
void matrix_matrix_product(const blasHandle_t& blas_handle,
const ViewA& A,
const ViewB& B,
ViewC& C,
std::string _transa,
std::string _transb,
typename ViewA::value_type alpha = 1,
typename ViewA::value_type beta = 0) {
cublasHandle_t handle;
cublasCreate(&handle);

cublasOperation_t transa = _transa == "N" ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t transb = _transb == "N" ? CUBLAS_OP_N : CUBLAS_OP_T;

@@ -252,7 +299,7 @@ namespace Impl {
const auto Bk = _transb == "N" ? B.extent(0) : B.extent(1);
assert(Ak == Bk);

auto status = gemmStridedBatched(handle,
auto status = gemmStridedBatched(blas_handle.handle_,
transa,
transb,
Cm,
@@ -271,9 +318,28 @@ namespace Impl {
C.extent(0) * C.extent(1),
C.extent(2)
);
cublasDestroy(handle);
}


/*
* Batched matrix matrix product
* Matrix shape
* A (n, m, l), B (m, k, l), C (n, k, l)
* */
template <class ViewA, class ViewB, class ViewC,
std::enable_if_t<ViewA::rank()==3 && ViewB::rank()==3 && ViewC::rank()==3, std::nullptr_t> = nullptr>
void matrix_matrix_product(const ViewA& A,
const ViewB& B,
ViewC& C,
std::string _transa,
std::string _transb,
typename ViewA::value_type alpha = 1,
typename ViewA::value_type beta = 0) {
Impl::blasHandle_t blas_handle;
blas_handle.create();
matrix_matrix_product(blas_handle, A, B, C, _transa, _transb, alpha, beta);
blas_handle.destroy();
}
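
A usage sketch for the two overloads above (not part of this commit). It assumes mdspan views with layout_left over thrust::device_vector storage, the stdex alias for std::experimental used in this repository, the includes already present in this file, and illustrative extents n, m, k, batch:

using value_type = double;
using view_3d = stdex::mdspan<value_type, stdex::dextents<std::size_t, 3>, stdex::layout_left>;

const std::size_t n = 16, m = 8, k = 32, batch = 128;
thrust::device_vector<value_type> a_buf(n * m * batch), b_buf(m * k * batch), c_buf(n * k * batch);
view_3d A(thrust::raw_pointer_cast(a_buf.data()), n, m, batch);  // A (n, m, l)
view_3d B(thrust::raw_pointer_cast(b_buf.data()), m, k, batch);  // B (m, k, l)
view_3d C(thrust::raw_pointer_cast(c_buf.data()), n, k, batch);  // C (n, k, l)

// Create the handle once and reuse it across iterations.
Impl::blasHandle_t blas_handle;
blas_handle.create();
for(int it = 0; it < 100; it++) {
  Impl::matrix_matrix_product(blas_handle, A, B, C, "N", "N");  // C = A * B for each batch entry
}
blas_handle.destroy();

// The handle-free overload keeps the original interface but creates and destroys a handle per call.
Impl::matrix_matrix_product(A, B, C, "N", "N");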

/*
* Batched matrix vector product
* Matrix shape
@@ -282,15 +348,13 @@ namespace Impl {
* */
template <class ViewA, class ViewB, class ViewC,
std::enable_if_t<ViewA::rank()==3 && ViewB::rank()==2 && ViewC::rank()==2, std::nullptr_t> = nullptr>
void matrix_vector_product(const ViewA& A,
void matrix_vector_product(const blasHandle_t& blas_handle,
const ViewA& A,
const ViewB& B,
ViewC& C,
std::string _transa,
typename ViewA::value_type alpha = 1
) {
cublasHandle_t handle;
cublasCreate(&handle);

cublasOperation_t transa = _transa == "N" ? CUBLAS_OP_N : CUBLAS_OP_T;

const auto Cm = C.extent(0);
@@ -303,7 +367,7 @@ namespace Impl {

using value_type = ViewA::value_type;
const value_type beta = 0;
auto status = gemmStridedBatched(handle,
auto status = gemmStridedBatched(blas_handle.handle_,
transa,
CUBLAS_OP_N,
Cm,
@@ -322,7 +386,26 @@ namespace Impl {
C.extent(0),
C.extent(1)
);
cublasDestroy(handle);
}

/*
* Batched matrix vector product
* Matrix shape
* A (n, m, l), B (m, l), C (n, l)
* C = A * B
* */
template <class ViewA, class ViewB, class ViewC,
std::enable_if_t<ViewA::rank()==3 && ViewB::rank()==2 && ViewC::rank()==2, std::nullptr_t> = nullptr>
void matrix_vector_product(const ViewA& A,
const ViewB& B,
ViewC& C,
std::string _transa,
typename ViewA::value_type alpha = 1
) {
Impl::blasHandle_t blas_handle;
blas_handle.create();
matrix_vector_product(blas_handle, A, B, C, _transa, alpha);
blas_handle.destroy();
}
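
For the batched matrix-vector product, B and C are rank-2 views whose last extent is the batch size. A brief sketch (not part of this commit), reusing A, blas_handle, and the aliases from the sketch above:

using view_2d = stdex::mdspan<value_type, stdex::dextents<std::size_t, 2>, stdex::layout_left>;
thrust::device_vector<value_type> x_buf(m * batch), y_buf(n * batch);
view_2d x(thrust::raw_pointer_cast(x_buf.data()), m, batch);  // B (m, l)
view_2d y(thrust::raw_pointer_cast(y_buf.data()), n, batch);  // C (n, l)
Impl::matrix_vector_product(blas_handle, A, x, y, "N");       // y = A * x for each batch entry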

/*
@@ -332,70 +415,52 @@ namespace Impl {
* v (m, l)
* w (m, m, l)
* */
template <class MatrixView, class VectorView,
template <class Handle, class MatrixView, class VectorView,
std::enable_if_t<MatrixView::rank()==3 && VectorView::rank()==2, std::nullptr_t> = nullptr>
void eig(MatrixView& a, VectorView& v) {
void eig(const Handle& syevj_handle, MatrixView& a, VectorView& v) {
static_assert( std::is_same_v<typename MatrixView::value_type, typename VectorView::value_type> );
static_assert( std::is_same_v<typename MatrixView::layout_type, typename VectorView::layout_type> );

using value_type = MatrixView::value_type;
assert(a.extent(0) == v.extent(0));
assert(a.extent(0) == a.extent(1)); // Square array
assert(a.extent(2) == v.extent(1)); // batch size

cusolverDnHandle_t handle;
cusolverDnCreate(&handle);

syevjInfo_t params;
constexpr value_type tol = 1.0e-7;
constexpr int max_sweeps = 100;
constexpr int sort_eig = 0;
const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
const cublasFillMode_t uplo = CUBLAS_FILL_MODE_LOWER;

cusolverDnCreateSyevjInfo(&params);
cusolverDnXsyevjSetTolerance(params, tol);
cusolverDnXsyevjSetMaxSweeps(params, max_sweeps);
cusolverDnXsyevjSetSortEig(params, sort_eig);

const int batchSize = v.extent(1);
int lwork = 0;
syevjBatched_bufferSize(
handle,
jobz,
uplo,
a.extent(0), a.data_handle(),
v.extent(0), v.data_handle(),
&lwork,
params,
batchSize
);

thrust::device_vector<value_type> workspace(lwork);
thrust::device_vector<int> info(batchSize, 0);
value_type* workspace_data = (value_type *)thrust::raw_pointer_cast(workspace.data());
int* info_data = (int *)thrust::raw_pointer_cast(info.data());
value_type* workspace_data = (value_type *)thrust::raw_pointer_cast(syevj_handle.workspace_.data());
int* info_data = (int *)thrust::raw_pointer_cast(syevj_handle.info_.data());

auto status = syevjBatched(
handle,
jobz,
uplo,
syevj_handle.handle_,
syevj_handle.jobz_,
syevj_handle.uplo_,
a.extent(0), a.data_handle(),
v.extent(0), v.data_handle(),
workspace_data, workspace.size(),
workspace_data, syevj_handle.workspace_.size(),
info_data,
params,
syevj_handle.params_,
batchSize
);

cudaDeviceSynchronize();
cusolverDnDestroy(handle);
}

template <class MatrixView, class VectorView,
std::enable_if_t<MatrixView::rank()==3 && VectorView::rank()==2, std::nullptr_t> = nullptr>
void eig(MatrixView& a, VectorView& v) {
static_assert( std::is_same_v<typename MatrixView::value_type, typename VectorView::value_type> );
static_assert( std::is_same_v<typename MatrixView::layout_type, typename VectorView::layout_type> );

using value_type = MatrixView::value_type;
Impl::syevjHandle_t<value_type> syevj_handle;
syevj_handle.create(a, v);
eig(syevj_handle, a, v);
syevj_handle.destroy();
}
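
A usage sketch for the batched Jacobi eigensolver (not part of this commit, with illustrative extents): syevjHandle_t::create() sizes the cusolver workspace from the views once, after which eig() can be called repeatedly, e.g. inside a time loop. eig() overwrites a with the eigenvectors and fills v with the eigenvalues.

using view_3d = stdex::mdspan<double, stdex::dextents<std::size_t, 3>, stdex::layout_left>;
using view_2d = stdex::mdspan<double, stdex::dextents<std::size_t, 2>, stdex::layout_left>;

const std::size_t m = 8, batch = 128;
thrust::device_vector<double> a_buf(m * m * batch), v_buf(m * batch);
view_3d a(thrust::raw_pointer_cast(a_buf.data()), m, m, batch);  // batch of m x m symmetric matrices
view_2d v(thrust::raw_pointer_cast(v_buf.data()), m, batch);     // m eigenvalues per batch entry

Impl::syevjHandle_t<double> syevj_handle;
syevj_handle.create(a, v);  // allocates the workspace and syevj parameters once
for(int it = 0; it < 100; it++) {
  // ... refill a with the current symmetric matrices ...
  Impl::eig(syevj_handle, a, v);
}
syevj_handle.destroy();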

// 2D transpose
template <class InputView, class OutputView,
std::enable_if_t<InputView::rank()==2 && OutputView::rank()==2, std::nullptr_t> = nullptr>
void transpose(const InputView& in, OutputView& out) {
void transpose(const blasHandle_t& blas_handle, const InputView& in, OutputView& out) {
static_assert( std::is_same_v<typename InputView::value_type, typename OutputView::value_type> );
static_assert( std::is_same_v<typename InputView::layout_type, typename OutputView::layout_type> );
static_assert( std::is_same_v<typename InputView::layout_type, stdex::layout_left> );
@@ -406,11 +471,9 @@ namespace Impl {
using value_type = InputView::value_type;
constexpr value_type alpha = 1;
constexpr value_type beta = 0;
// transpose by cublas
cublasHandle_t handle;
cublasCreate(&handle);

geam(handle,
// transpose by cublas
geam(blas_handle.handle_,
CUBLAS_OP_T,
CUBLAS_OP_T,
in.extent(1),
@@ -424,8 +487,16 @@ namespace Impl {
out.data_handle(),
out.extent(0)
);
}

cublasDestroy(handle);
// 2D transpose
template <class InputView, class OutputView,
std::enable_if_t<InputView::rank()==2 && OutputView::rank()==2, std::nullptr_t> = nullptr>
void transpose(const InputView& in, OutputView& out) {
Impl::blasHandle_t blas_handle;
blas_handle.create();
transpose(blas_handle, in, out);
blas_handle.destroy();
}
};

18 changes: 15 additions & 3 deletions lib/executors/Transpose.hpp
@@ -20,7 +20,7 @@ namespace Impl {
/* Transpose batched matrix */
template <class InputView, class OutputView,
std::enable_if_t<InputView::rank()==3 && OutputView::rank()==3, std::nullptr_t> = nullptr>
void transpose(const InputView& in, OutputView& out, const std::array<int, 3>& axes) {
void transpose(const blasHandle_t& blas_handle, const InputView& in, OutputView& out, const std::array<int, 3>& axes) {
static_assert( std::is_same_v<typename InputView::value_type, typename OutputView::value_type> );
static_assert( std::is_same_v<typename InputView::layout_type, typename OutputView::layout_type> );
using value_type = InputView::value_type;
@@ -57,7 +57,7 @@ namespace Impl {

const mdspan2d_type sub_in(in.data_handle(), in_shape);
mdspan2d_type sub_out(out.data_handle(), out_shape);
transpose(sub_in, sub_out);
transpose(blas_handle, sub_in, sub_out);
} else if(axes == axes_type({2, 0, 1})) {
using mdspan2d_type = stdex::mdspan<value_type, stdex::dextents<size_type, 2>, layout_type>;
using extent2d_type = std::array<std::size_t, 2>;
@@ -66,7 +66,7 @@ namespace Impl {

const mdspan2d_type sub_in(in.data_handle(), in_shape);
mdspan2d_type sub_out(out.data_handle(), out_shape);
transpose(sub_in, sub_out);
transpose(blas_handle, sub_in, sub_out);
} else if(axes == axes_type({2, 1, 0})) {
Impl::for_each(policy3d,
[=](const int i0, const int i1, const int i2) {
@@ -76,6 +76,18 @@ namespace Impl {
std::runtime_error("Invalid axes specified.");
}
}

/* Transpose batched matrix */
template <class InputView, class OutputView,
std::enable_if_t<InputView::rank()==3 && OutputView::rank()==3, std::nullptr_t> = nullptr>
void transpose(const InputView& in, OutputView& out, const std::array<int, 3>& axes) {
static_assert( std::is_same_v<typename InputView::value_type, typename OutputView::value_type> );
static_assert( std::is_same_v<typename InputView::layout_type, typename OutputView::layout_type> );
Impl::blasHandle_t blas_handle;
blas_handle.create();
transpose(blas_handle, in, out, axes);
blas_handle.destroy();
}
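
A brief sketch of the batched transpose (not part of this commit), again assuming mdspan views with layout_left over device memory. For axes {2, 0, 1} the output extents are the cyclic permutation of the input extents:

using view_3d = stdex::mdspan<double, stdex::dextents<std::size_t, 3>, stdex::layout_left>;

const std::size_t n0 = 8, n1 = 16, n2 = 32;
thrust::device_vector<double> in_buf(n0 * n1 * n2), out_buf(n0 * n1 * n2);
view_3d in(thrust::raw_pointer_cast(in_buf.data()), n0, n1, n2);
view_3d out(thrust::raw_pointer_cast(out_buf.data()), n2, n0, n1);  // out(i2, i0, i1) = in(i0, i1, i2)

Impl::blasHandle_t blas_handle;
blas_handle.create();
Impl::transpose(blas_handle, in, out, {2, 0, 1});  // reuses the caller's handle
blas_handle.destroy();

// The handle-free overload is equivalent but creates a handle internally on every call.
Impl::transpose(in, out, {2, 0, 1});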
};

#endif
