Skip to content

Commit

Permalink
More qualifications
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet committed Nov 21, 2023
1 parent 0453e13 commit 5727843
Show file tree
Hide file tree
Showing 45 changed files with 312 additions and 294 deletions.
27 changes: 14 additions & 13 deletions cpp/include/cuvs/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void initRandom(raft::resources const& handle,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<DataT, IndexT> centroids)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("initRandom");
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope("initRandom");
cudaStream_t stream = resource::get_cuda_stream(handle);
auto n_clusters = params.n_clusters;
detail::shuffleAndGather<DataT, IndexT>(handle, X, centroids, n_clusters, params.rng_state.seed);
Expand All @@ -93,7 +93,7 @@ void kmeansPlusPlus(raft::resources const& handle,
raft::device_matrix_view<DataT, IndexT> centroidsRawData,
rmm::device_uvector<char>& workspace)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("kmeansPlusPlus");
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope("kmeansPlusPlus");
cudaStream_t stream = resource::get_cuda_stream(handle);
auto n_samples = X.extent(0);
auto n_features = X.extent(1);
Expand Down Expand Up @@ -367,7 +367,7 @@ void kmeans_fit_main(raft::resources const& handle,
raft::host_scalar_view<IndexT> n_iter,
rmm::device_uvector<char>& workspace)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("kmeans_fit_main");
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope("kmeans_fit_main");
logger::get(RAFT_NAME).set_level(params.verbosity);
cudaStream_t stream = resource::get_cuda_stream(handle);
auto n_samples = X.extent(0);
Expand Down Expand Up @@ -524,7 +524,7 @@ void kmeans_fit_main(raft::resources const& handle,
workspace);

// TODO: add different templates for InType of binaryOp to avoid thrust transform
thrust::transform(resource::get_thrust_policy(handle),
thrust::transform(raft::resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
weight.data_handle(),
Expand Down Expand Up @@ -581,7 +581,8 @@ void initScalableKMeansPlusPlus(raft::resources const& handle,
raft::device_matrix_view<DataT, IndexT> centroidsRawData,
rmm::device_uvector<char>& workspace)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("initScalableKMeansPlusPlus");
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope(
"initScalableKMeansPlusPlus");
cudaStream_t stream = resource::get_cuda_stream(handle);
auto n_samples = X.extent(0);
auto n_features = X.extent(1);
Expand Down Expand Up @@ -826,7 +827,7 @@ void kmeans_fit(raft::resources const& handle,
raft::host_scalar_view<DataT> inertia,
raft::host_scalar_view<IndexT> n_iter)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("kmeans_fit");
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope("kmeans_fit");
auto n_samples = X.extent(0);
auto n_features = X.extent(1);
auto n_clusters = params.n_clusters;
Expand Down Expand Up @@ -872,7 +873,7 @@ void kmeans_fit(raft::resources const& handle,
if (sample_weight.has_value())
raft::copy(weight.data_handle(), sample_weight.value().data_handle(), n_samples, stream);
else
thrust::fill(resource::get_thrust_policy(handle),
thrust::fill(raft::resource::get_thrust_policy(handle),
weight.data_handle(),
weight.data_handle() + weight.size(),
1);
Expand Down Expand Up @@ -993,7 +994,7 @@ void kmeans_predict(raft::resources const& handle,
bool normalize_weight,
raft::host_scalar_view<DataT> inertia)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("kmeans_predict");
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope("kmeans_predict");
auto n_samples = X.extent(0);
auto n_features = X.extent(1);
cudaStream_t stream = resource::get_cuda_stream(handle);
Expand All @@ -1019,7 +1020,7 @@ void kmeans_predict(raft::resources const& handle,
if (sample_weight.has_value())
raft::copy(weight.data_handle(), sample_weight.value().data_handle(), n_samples, stream);
else
thrust::fill(resource::get_thrust_policy(handle),
thrust::fill(raft::resource::get_thrust_policy(handle),
weight.data_handle(),
weight.data_handle() + weight.size(),
1);
Expand Down Expand Up @@ -1065,7 +1066,7 @@ void kmeans_predict(raft::resources const& handle,
// calculate cluster cost phi_x(C)
rmm::device_scalar<DataT> clusterCostD(stream);
// TODO: add different templates for InType of binaryOp to avoid thrust transform
thrust::transform(resource::get_thrust_policy(handle),
thrust::transform(raft::resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
weight.data_handle(),
Expand All @@ -1084,7 +1085,7 @@ void kmeans_predict(raft::resources const& handle,
raft::value_op{},
raft::add_op{});

thrust::transform(resource::get_thrust_policy(handle),
thrust::transform(raft::resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
labels.data_handle(),
Expand Down Expand Up @@ -1135,7 +1136,7 @@ void kmeans_fit_predict(raft::resources const& handle,
raft::host_scalar_view<DataT> inertia,
raft::host_scalar_view<IndexT> n_iter)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("kmeans_fit_predict");
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope("kmeans_fit_predict");
if (!centroids.has_value()) {
auto n_features = X.extent(1);
auto centroids_matrix =
Expand Down Expand Up @@ -1199,7 +1200,7 @@ void kmeans_transform(raft::resources const& handle,
raft::device_matrix_view<const DataT> centroids,
raft::device_matrix_view<DataT> X_new)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("kmeans_transform");
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope("kmeans_transform");
logger::get(RAFT_NAME).set_level(params.verbosity);
cudaStream_t stream = resource::get_cuda_stream(handle);
auto n_samples = X.extent(0);
Expand Down
16 changes: 8 additions & 8 deletions cpp/include/cuvs/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(
auto minClusterAndDistance = raft::make_device_mdarray<raft::KeyValuePair<IdxT, MathT>, IdxT>(
handle, mr, make_extents<IdxT>(n_rows));
raft::KeyValuePair<IdxT, MathT> initial_value(0, std::numeric_limits<MathT>::max());
thrust::fill(resource::get_thrust_policy(handle),
thrust::fill(raft::resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
initial_value);
Expand All @@ -130,7 +130,7 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(

// todo(lsugy): use KVP + iterator in caller.
// Copy keys to output labels
thrust::transform(resource::get_thrust_policy(handle),
thrust::transform(raft::resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + n_rows,
labels,
Expand Down Expand Up @@ -325,7 +325,7 @@ void compute_norm(const raft::resources& handle,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* mr = nullptr)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("compute_norm");
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope("compute_norm");
auto stream = resource::get_cuda_stream(handle);
if (mr == nullptr) { mr = resource::get_workspace_resource(handle); }
rmm::device_uvector<MathT> mapped_dataset(0, stream, mr);
Expand Down Expand Up @@ -381,7 +381,7 @@ void predict(const raft::resources& handle,
const MathT* dataset_norm = nullptr)
{
auto stream = resource::get_cuda_stream(handle);
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope(
"predict(%zu, %u)", static_cast<size_t>(n_rows), n_clusters);
if (mr == nullptr) { mr = resource::get_workspace_resource(handle); }
auto [max_minibatch_size, _mem_per_row] =
Expand Down Expand Up @@ -473,7 +473,7 @@ __launch_bounds__((WarpSize * BlockDimY)) RAFT_KERNEL
const MathT wc = min(static_cast<MathT>(csize), static_cast<MathT>(kAdjustCentersWeight));
// Weight for the datapoint used to shift the center.
const MathT wd = 1.0;
for (; j < dim; j += WarpSize) {
for (; j < dim; j += raft::WarpSize) {
MathT val = 0;
val += wc * centers[j + dim * li];
val += wd * mapping_op(dataset[j + dim * i]);
Expand Down Expand Up @@ -533,7 +533,7 @@ auto adjust_centers(MathT* centers,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* device_memory) -> bool
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope(
"adjust_centers(%zu, %u)", static_cast<size_t>(n_rows), n_clusters);
if (n_clusters == 0) { return false; }
constexpr static std::array kPrimes{29, 71, 113, 173, 229, 281, 349, 409, 463, 541,
Expand Down Expand Up @@ -901,7 +901,7 @@ auto build_fine_clusters(const raft::resources& handle,
raft::matrix::gather(mapping_itr, dim, n_rows, mc_trainset_ids, k, mc_trainset, stream);
if (params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
thrust::gather(resource::get_thrust_policy(handle),
thrust::gather(raft::resource::get_thrust_policy(handle),
mc_trainset_ids,
mc_trainset_ids + k,
dataset_norm_mptr,
Expand Down Expand Up @@ -964,7 +964,7 @@ void build_hierarchical(const raft::resources& handle,
auto stream = resource::get_cuda_stream(handle);
using LabelT = uint32_t;

common::nvtx::range<common::nvtx::domain::raft> fun_scope(
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope(
"build_hierarchical(%zu, %u)", static_cast<size_t>(n_rows), n_clusters);

IdxT n_mesoclusters = std::min(n_clusters, static_cast<IdxT>(std::sqrt(n_clusters) + 0.5));
Expand Down
6 changes: 3 additions & 3 deletions cpp/include/cuvs/cluster/detail/kmeans_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ void sampleCentroids(raft::resources const& handle,
resource::sync_stream(handle, stream);

uint8_t* rawPtr_isSampleCentroid = isSampleCentroid.data_handle();
thrust::for_each_n(resource::get_thrust_policy(handle),
thrust::for_each_n(raft::resource::get_thrust_policy(handle),
sampledMinClusterDistance.data_handle(),
nPtsSampledInRank,
[=] __device__(raft::KeyValuePair<ptrdiff_t, DataT> val) {
Expand Down Expand Up @@ -399,7 +399,7 @@ void minClusterAndDistanceCompute(

raft::KeyValuePair<IndexT, DataT> initial_value(0, std::numeric_limits<DataT>::max());

thrust::fill(resource::get_thrust_policy(handle),
thrust::fill(raft::resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
initial_value);
Expand Down Expand Up @@ -527,7 +527,7 @@ void minClusterDistanceCompute(raft::resources const& handle,
auto pairwiseDistance = raft::make_device_matrix_view<DataT, IndexT>(
L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize);

thrust::fill(resource::get_thrust_policy(handle),
thrust::fill(raft::resource::get_thrust_policy(handle),
minClusterDistance.data_handle(),
minClusterDistance.data_handle() + minClusterDistance.size(),
std::numeric_limits<DataT>::max());
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/cuvs/distance/detail/fused_l2_nn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ void initialize(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp, cudaStream_t s
}

// TODO: specialize this function for MinAndDistanceReduceOp<int, float>
// with atomicCAS of 64 bit which will eliminate mutex and shfls
// with a 64-bit atomicCAS, which will eliminate the mutex and the shfl calls
template <typename P, typename OutT, typename IdxT, typename KVPair, typename ReduceOpT>
DI void updateReducedVal(
int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, IdxT m, IdxT gridStrideY)
Expand Down Expand Up @@ -204,7 +204,7 @@ __launch_bounds__(P::Nthreads, 2) RAFT_KERNEL fusedL2NNkernel(OutT* min,
#pragma unroll
for (int j = P::AccThCols / 2; j > 0; j >>= 1) {
// Actually, the srcLane (lid + j) should be (lid + j) % P::AccThCols,
// but the shfl op applies the modulo internally.
// but the raft::shfl op applies the modulo internally.
auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols);
auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols);
KVPair tmp = {tmpkey, tmpvalue};
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/cuvs/distance/detail/masked_nn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ void masked_l2_nn_impl(raft::resources const& handle,

// Get stream and workspace memory resource
rmm::mr::device_memory_resource* ws_mr =
dynamic_cast<rmm::mr::device_memory_resource*>(resource::get_workspace_resource(handle));
dynamic_cast<rmm::mr::device_memory_resource*>(raft::resource::get_workspace_resource(handle));
auto stream = resource::get_cuda_stream(handle);

// Acquire temporary buffers and initialize to zero:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ struct PairwiseDistances : public BaseClass {

DI void accumulate()
{
// We have a separate ldsXY and accumulate_reg_tile outside the loop body,
// We have a separate ldsXY and accumulate_reg_tile outside the loop body,
// so that these separated calls can be interspersed with preceding and
// following instructions, thereby hiding latency.
this->ldsXY(0);
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/cuvs/distance/distance-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ void distance(raft::resources const& handle,
raft::device_matrix_view<OutT, IdxT, layout> dist,
DataT metric_arg = 2.0f) RAFT_EXPLICIT;

template <typename Type, typename layout = layout_c_contiguous, typename IdxT = int>
template <typename Type, typename layout = raft::layout_c_contiguous, typename IdxT = int>
void pairwise_distance(raft::resources const& handle,
device_matrix_view<Type, IdxT, layout> const x,
device_matrix_view<Type, IdxT, layout> const y,
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/cuvs/neighbors/brute_force-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ void knn(raft::resources const& handle,
RAFT_EXPECTS(indices.extent(1) == distances.extent(1) && distances.extent(1),
"Number of columns in output indices and distances matrices must the same");

bool rowMajorIndex = std::is_same_v<index_layout, layout_c_contiguous>;
bool rowMajorQuery = std::is_same_v<search_layout, layout_c_contiguous>;
bool rowMajorIndex = std::is_same_v<index_layout, raft::layout_c_contiguous>;
bool rowMajorQuery = std::is_same_v<search_layout, raft::layout_c_contiguous>;

std::vector<value_t*> inputs;
std::vector<matrix_idx> sizes;
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/cuvs/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ struct index : ann::index {
* @endcode
* In the above example, we have passed a host dataset to build. The returned index will own a
* device copy of the dataset and the knn_graph. In contrast, if we pass the dataset as a
* device_mdspan to build, then it will only store a reference to it.
* raft::device_mdspan to build, then it will only store a reference to it.
*
* - Constructing index using existing knn-graph
* @code{.cpp}
Expand Down
9 changes: 5 additions & 4 deletions cpp/include/cuvs/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ void build_knn_graph(
"Currently only L2Expanded metric is supported");

uint32_t node_degree = knn_graph.extent(1);
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::build_graph(%zu, %zu, %u)",
size_t(dataset.extent(0)),
size_t(dataset.extent(1)),
node_degree);
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope(
"cagra::build_graph(%zu, %zu, %u)",
size_t(dataset.extent(0)),
size_t(dataset.extent(1)),
node_degree);

if (!build_params) {
build_params = ivf_pq::index_params{};
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/cuvs/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ void search_main(raft::resources const& res,

if (params.max_queries == 0) { params.max_queries = queries.extent(0); }

common::nvtx::range<common::nvtx::domain::raft> fun_scope(
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope(
"cagra::search(max_queries = %u, k = %u, dim = %zu)", params.max_queries, topk, index.dim());

using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
Expand Down
7 changes: 4 additions & 3 deletions cpp/include/cuvs/neighbors/detail/cagra/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ void serialize(raft::resources const& res,
const index<T, IdxT>& index_,
bool include_dataset)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::serialize");
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope("cagra::serialize");

RAFT_LOG_DEBUG(
"Saving CAGRA index, size %zu, dim %u", static_cast<size_t>(index_.size()), index_.dim());
Expand Down Expand Up @@ -103,7 +103,8 @@ void serialize_to_hnswlib(raft::resources const& res,
std::ostream& os,
const index<T, IdxT>& index_)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::serialize_to_hnswlib");
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope(
"cagra::serialize_to_hnswlib");
RAFT_LOG_DEBUG("Saving CAGRA index to hnswlib format, size %zu, dim %u",
static_cast<size_t>(index_.size()),
index_.dim());
Expand Down Expand Up @@ -233,7 +234,7 @@ void serialize_to_hnswlib(raft::resources const& res,
template <typename T, typename IdxT>
auto deserialize(raft::resources const& res, std::istream& is) -> index<T, IdxT>
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::deserialize");
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope("cagra::deserialize");

char dtype_string[4];
is.read(dtype_string, 4);
Expand Down
Loading

0 comments on commit 5727843

Please sign in to comment.