Skip to content

Commit

Permalink
Merge (#964): Create CSR submatrix from Index sets
Browse files Browse the repository at this point in the history
This PR adds functionality to create submatrices from IndexSet objects, allowing one to create submatrices not only from contiguous spans, but also with dis-contiguous sets of indices.

Some index set related changes were also made:
+ Make index set a non-polymorphic class.
+ Rename IndexSet to index_set

Related PR: #964
  • Loading branch information
pratikvn authored Mar 30, 2022
2 parents 4430fb8 + 43545a0 commit 6df4a68
Show file tree
Hide file tree
Showing 26 changed files with 1,324 additions and 346 deletions.
6 changes: 3 additions & 3 deletions common/unified/base/index_set_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ namespace gko {
namespace kernels {
namespace GKO_DEVICE_NAMESPACE {
/**
* @brief The IndexSet namespace.
* @brief The index_set namespace.
*
* @ingroup index_set
*/
namespace index_set {
namespace idx_set {


template <typename IndexType>
Expand All @@ -68,7 +68,7 @@ GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(
GKO_DECLARE_INDEX_SET_COMPUTE_VALIDITY_KERNEL);


} // namespace index_set
} // namespace idx_set
} // namespace GKO_DEVICE_NAMESPACE
} // namespace kernels
} // namespace gko
58 changes: 31 additions & 27 deletions core/base/index_set.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,40 +47,40 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


namespace gko {
namespace index_set {
namespace idx_set {


GKO_REGISTER_OPERATION(to_global_indices, index_set::to_global_indices);
GKO_REGISTER_OPERATION(populate_subsets, index_set::populate_subsets);
GKO_REGISTER_OPERATION(global_to_local, index_set::global_to_local);
GKO_REGISTER_OPERATION(local_to_global, index_set::local_to_global);
GKO_REGISTER_OPERATION(to_global_indices, idx_set::to_global_indices);
GKO_REGISTER_OPERATION(populate_subsets, idx_set::populate_subsets);
GKO_REGISTER_OPERATION(global_to_local, idx_set::global_to_local);
GKO_REGISTER_OPERATION(local_to_global, idx_set::local_to_global);


} // namespace index_set
} // namespace idx_set


template <typename IndexType>
void IndexSet<IndexType>::populate_subsets(const gko::Array<IndexType>& indices,
const bool is_sorted)
void index_set<IndexType>::populate_subsets(
const gko::Array<IndexType>& indices, const bool is_sorted)
{
auto exec = this->get_executor();
this->num_stored_indices_ = indices.get_num_elems();
exec->run(index_set::make_populate_subsets(
exec->run(idx_set::make_populate_subsets(
this->index_space_size_, &indices, &this->subsets_begin_,
&this->subsets_end_, &this->superset_cumulative_indices_, is_sorted));
}


template <typename IndexType>
bool IndexSet<IndexType>::contains(const IndexType input_index) const
bool index_set<IndexType>::contains(const IndexType input_index) const
{
auto local_index = this->get_local_index(input_index);
return local_index != invalid_index<IndexType>();
}


template <typename IndexType>
IndexType IndexSet<IndexType>::get_global_index(const IndexType index) const
IndexType index_set<IndexType>::get_global_index(const IndexType index) const
{
auto exec = this->get_executor();
const auto local_idx =
Expand All @@ -93,7 +93,7 @@ IndexType IndexSet<IndexType>::get_global_index(const IndexType index) const


template <typename IndexType>
IndexType IndexSet<IndexType>::get_local_index(const IndexType index) const
IndexType index_set<IndexType>::get_local_index(const IndexType index) const
{
auto exec = this->get_executor();
const auto global_idx =
Expand All @@ -106,56 +106,60 @@ IndexType IndexSet<IndexType>::get_local_index(const IndexType index) const


template <typename IndexType>
Array<IndexType> IndexSet<IndexType>::to_global_indices() const
Array<IndexType> index_set<IndexType>::to_global_indices() const
{
auto exec = this->get_executor();
auto num_elems = exec->copy_val_to_host(
this->superset_cumulative_indices_.get_const_data() +
this->superset_cumulative_indices_.get_num_elems() - 1);
auto decomp_indices = gko::Array<IndexType>(exec, num_elems);
exec->run(index_set::make_to_global_indices(
this->index_space_size_, &this->subsets_begin_, &this->subsets_end_,
&this->superset_cumulative_indices_, &decomp_indices));
exec->run(idx_set::make_to_global_indices(
this->get_num_subsets(), this->get_subsets_begin(),
this->get_subsets_end(), this->get_superset_indices(),
decomp_indices.get_data()));

return decomp_indices;
}


template <typename IndexType>
Array<IndexType> IndexSet<IndexType>::map_local_to_global(
Array<IndexType> index_set<IndexType>::map_local_to_global(
const Array<IndexType>& local_indices, const bool is_sorted) const
{
auto exec = this->get_executor();
auto global_indices =
gko::Array<IndexType>(exec, local_indices.get_num_elems());

GKO_ASSERT(this->get_num_subsets() >= 1);
exec->run(index_set::make_local_to_global(
this->index_space_size_, &this->subsets_begin_, &this->subsets_end_,
&this->superset_cumulative_indices_, &local_indices, &global_indices,
is_sorted));
exec->run(idx_set::make_local_to_global(
this->get_num_subsets(), this->get_subsets_begin(),
this->get_superset_indices(),
static_cast<IndexType>(local_indices.get_num_elems()),
local_indices.get_const_data(), global_indices.get_data(), is_sorted));
return global_indices;
}


template <typename IndexType>
Array<IndexType> IndexSet<IndexType>::map_global_to_local(
Array<IndexType> index_set<IndexType>::map_global_to_local(
const Array<IndexType>& global_indices, const bool is_sorted) const
{
auto exec = this->get_executor();
auto local_indices =
gko::Array<IndexType>(exec, global_indices.get_num_elems());

GKO_ASSERT(this->get_num_subsets() >= 1);
exec->run(index_set::make_global_to_local(
this->index_space_size_, &this->subsets_begin_, &this->subsets_end_,
&this->superset_cumulative_indices_, &global_indices, &local_indices,
is_sorted));
exec->run(idx_set::make_global_to_local(
this->index_space_size_, this->get_num_subsets(),
this->get_subsets_begin(), this->get_subsets_end(),
this->get_superset_indices(),
static_cast<IndexType>(local_indices.get_num_elems()),
global_indices.get_const_data(), local_indices.get_data(), is_sorted));
return local_indices;
}


#define GKO_DECLARE_INDEX_SET(_type) class IndexSet<_type>
#define GKO_DECLARE_INDEX_SET(_type) class index_set<_type>
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_INDEX_SET);


Expand Down
49 changes: 22 additions & 27 deletions core/base/index_set_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,11 @@ namespace kernels {
Array<bool>* validity_array)

#define GKO_DECLARE_INDEX_SET_TO_GLOBAL_INDICES_KERNEL(IndexType) \
void to_global_indices(std::shared_ptr<const DefaultExecutor> exec, \
const IndexType index_space_size, \
const Array<IndexType>* subset_begin, \
const Array<IndexType>* subset_end, \
const Array<IndexType>* superset_indices, \
Array<IndexType>* decomp_indices)
void to_global_indices( \
std::shared_ptr<const DefaultExecutor> exec, \
const IndexType num_subsets, const IndexType* subset_begin, \
const IndexType* subset_end, const IndexType* superset_indices, \
IndexType* decomp_indices)

#define GKO_DECLARE_INDEX_SET_POPULATE_KERNEL(IndexType) \
void populate_subsets( \
Expand All @@ -67,25 +66,22 @@ namespace kernels {
Array<IndexType>* subset_begin, Array<IndexType>* subset_end, \
Array<IndexType>* superset_indices, const bool is_sorted)

#define GKO_DECLARE_INDEX_SET_GLOBAL_TO_LOCAL_KERNEL(IndexType) \
void global_to_local(std::shared_ptr<const DefaultExecutor> exec, \
const IndexType index_space_size, \
const Array<IndexType>* subset_begin, \
const Array<IndexType>* subset_end, \
const Array<IndexType>* superset_indices, \
const Array<IndexType>* global_indices, \
Array<IndexType>* local_indices, \
const bool is_sorted)

#define GKO_DECLARE_INDEX_SET_LOCAL_TO_GLOBAL_KERNEL(IndexType) \
void local_to_global(std::shared_ptr<const DefaultExecutor> exec, \
const IndexType index_space_size, \
const Array<IndexType>* subset_begin, \
const Array<IndexType>* subset_end, \
const Array<IndexType>* superset_indices, \
const Array<IndexType>* local_indices, \
Array<IndexType>* global_indices, \
const bool is_sorted)
#define GKO_DECLARE_INDEX_SET_GLOBAL_TO_LOCAL_KERNEL(IndexType) \
void global_to_local( \
std::shared_ptr<const DefaultExecutor> exec, \
const IndexType index_space_size, const IndexType num_subsets, \
const IndexType* subset_begin, const IndexType* subset_end, \
const IndexType* superset_indices, const IndexType num_indices, \
const IndexType* global_indices, IndexType* local_indices, \
const bool is_sorted)

#define GKO_DECLARE_INDEX_SET_LOCAL_TO_GLOBAL_KERNEL(IndexType) \
void local_to_global( \
std::shared_ptr<const DefaultExecutor> exec, \
const IndexType num_subsets, const IndexType* subset_begin, \
const IndexType* superset_indices, const IndexType num_indices, \
const IndexType* local_indices, IndexType* global_indices, \
const bool is_sorted)


#define GKO_DECLARE_ALL_AS_TEMPLATES \
Expand All @@ -101,8 +97,7 @@ namespace kernels {
GKO_DECLARE_INDEX_SET_LOCAL_TO_GLOBAL_KERNEL(IndexType)


GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(index_set,
GKO_DECLARE_ALL_AS_TEMPLATES);
GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(idx_set, GKO_DECLARE_ALL_AS_TEMPLATES);


#undef GKO_DECLARE_ALL_AS_TEMPLATES
Expand Down
8 changes: 6 additions & 2 deletions core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_CONVERT_PTRS_TO_SIZES);
} // namespace components


namespace index_set {
namespace idx_set {


GKO_STUB_INDEX_TYPE(GKO_DECLARE_INDEX_SET_COMPUTE_VALIDITY_KERNEL);
Expand All @@ -224,7 +224,7 @@ GKO_STUB_INDEX_TYPE(GKO_DECLARE_INDEX_SET_GLOBAL_TO_LOCAL_KERNEL);
GKO_STUB_INDEX_TYPE(GKO_DECLARE_INDEX_SET_LOCAL_TO_GLOBAL_KERNEL);


} // namespace index_set
} // namespace idx_set


namespace partition {
Expand Down Expand Up @@ -494,9 +494,13 @@ GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SORT_BY_COLUMN_INDEX);
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_IS_SORTED_BY_COLUMN_INDEX);
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_EXTRACT_DIAGONAL);
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_SPAN_KERNEL);
GKO_STUB_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_INDEX_SET_KERNEL);
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_KERNEL);
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_CHECK_DIAGONAL_ENTRIES_EXIST);
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_ADD_SCALED_IDENTITY_KERNEL);
GKO_STUB_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_FROM_INDEX_SET_KERNEL);

template <typename ValueType, typename IndexType>
GKO_DECLARE_CSR_SCALE_KERNEL(ValueType, IndexType)
Expand Down
51 changes: 51 additions & 0 deletions core/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/base/array.hpp>
#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/executor.hpp>
#include <ginkgo/core/base/index_set.hpp>
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/base/precision_dispatch.hpp>
#include <ginkgo/core/base/utils.hpp>
Expand Down Expand Up @@ -83,7 +84,11 @@ GKO_REGISTER_OPERATION(compute_hybrid_coo_row_ptrs,
GKO_REGISTER_OPERATION(convert_to_hybrid, csr::convert_to_hybrid);
GKO_REGISTER_OPERATION(calculate_nonzeros_per_row_in_span,
csr::calculate_nonzeros_per_row_in_span);
GKO_REGISTER_OPERATION(calculate_nonzeros_per_row_in_index_set,
csr::calculate_nonzeros_per_row_in_index_set);
GKO_REGISTER_OPERATION(compute_submatrix, csr::compute_submatrix);
GKO_REGISTER_OPERATION(compute_submatrix_from_index_set,
csr::compute_submatrix_from_index_set);
GKO_REGISTER_OPERATION(transpose, csr::transpose);
GKO_REGISTER_OPERATION(conj_transpose, csr::conj_transpose);
GKO_REGISTER_OPERATION(inv_symm_permute, csr::inv_symm_permute);
Expand Down Expand Up @@ -612,6 +617,52 @@ Csr<ValueType, IndexType>::create_submatrix(const gko::span& row_span,
}


template <typename ValueType, typename IndexType>
std::unique_ptr<Csr<ValueType, IndexType>>
Csr<ValueType, IndexType>::create_submatrix(
const index_set<IndexType>& row_index_set,
const index_set<IndexType>& col_index_set) const
{
using Mat = Csr<ValueType, IndexType>;
auto exec = this->get_executor();
if (!row_index_set.get_num_elems() || !col_index_set.get_num_elems()) {
return Mat::create(exec);
}
if (row_index_set.is_contiguous() && col_index_set.is_contiguous()) {
auto row_st = row_index_set.get_executor()->copy_val_to_host(
row_index_set.get_subsets_begin());
auto row_end = row_index_set.get_executor()->copy_val_to_host(
row_index_set.get_subsets_end());
auto col_st = col_index_set.get_executor()->copy_val_to_host(
col_index_set.get_subsets_begin());
auto col_end = col_index_set.get_executor()->copy_val_to_host(
col_index_set.get_subsets_end());

return this->create_submatrix(span(row_st, row_end),
span(col_st, col_end));
} else {
auto submat_num_rows = row_index_set.get_num_elems();
auto submat_num_cols = col_index_set.get_num_elems();
auto sub_mat_size = gko::dim<2>(submat_num_rows, submat_num_cols);
Array<IndexType> row_ptrs(exec, submat_num_rows + 1);
exec->run(csr::make_calculate_nonzeros_per_row_in_index_set(
this, row_index_set, col_index_set, row_ptrs.get_data()));
exec->run(
csr::make_prefix_sum(row_ptrs.get_data(), submat_num_rows + 1));
auto num_nnz =
exec->copy_val_to_host(row_ptrs.get_data() + sub_mat_size[0]);
auto sub_mat = Mat::create(exec, sub_mat_size,
std::move(Array<ValueType>(exec, num_nnz)),
std::move(Array<IndexType>(exec, num_nnz)),
std::move(row_ptrs), this->get_strategy());
exec->run(csr::make_compute_submatrix_from_index_set(
this, row_index_set, col_index_set, sub_mat.get()));
sub_mat->make_srow();
return sub_mat;
}
}


template <typename ValueType, typename IndexType>
std::unique_ptr<Diagonal<ValueType>>
Csr<ValueType, IndexType>::extract_diagonal() const
Expand Down
24 changes: 24 additions & 0 deletions core/matrix/csr_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


#include <ginkgo/core/base/array.hpp>
#include <ginkgo/core/base/index_set.hpp>
#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/matrix/coo.hpp>
#include <ginkgo/core/matrix/dense.hpp>
Expand Down Expand Up @@ -165,12 +166,29 @@ namespace kernels {
const matrix::Csr<ValueType, IndexType>* source, const span& row_span, \
const span& col_span, Array<IndexType>* row_nnz)

#define GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_INDEX_SET_KERNEL(ValueType, \
IndexType) \
void calculate_nonzeros_per_row_in_index_set( \
std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<ValueType, IndexType>* source, \
const gko::index_set<IndexType>& row_index_set, \
const gko::index_set<IndexType>& col_index_set, IndexType* row_nnz)

#define GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_KERNEL(ValueType, IndexType) \
void compute_submatrix(std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<ValueType, IndexType>* source, \
gko::span row_span, gko::span col_span, \
matrix::Csr<ValueType, IndexType>* result)

#define GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_FROM_INDEX_SET_KERNEL(ValueType, \
IndexType) \
void compute_submatrix_from_index_set( \
std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<ValueType, IndexType>* source, \
const gko::index_set<IndexType>& row_index_set, \
const gko::index_set<IndexType>& col_index_set, \
matrix::Csr<ValueType, IndexType>* result)

#define GKO_DECLARE_CSR_SORT_BY_COLUMN_INDEX(ValueType, IndexType) \
void sort_by_column_index(std::shared_ptr<const DefaultExecutor> exec, \
matrix::Csr<ValueType, IndexType>* to_sort)
Expand Down Expand Up @@ -247,6 +265,12 @@ namespace kernels {
template <typename ValueType, typename IndexType> \
GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_KERNEL(ValueType, IndexType); \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_INDEX_SET_KERNEL(ValueType, \
IndexType); \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_FROM_INDEX_SET_KERNEL(ValueType, \
IndexType); \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_CSR_SORT_BY_COLUMN_INDEX(ValueType, IndexType); \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_CSR_IS_SORTED_BY_COLUMN_INDEX(ValueType, IndexType); \
Expand Down
Loading

0 comments on commit 6df4a68

Please sign in to comment.