Skip to content

Commit

Permalink
FIX style fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dantegd committed Feb 27, 2024
1 parent d6929a9 commit 3fbd19d
Show file tree
Hide file tree
Showing 13 changed files with 42 additions and 32 deletions.
2 changes: 1 addition & 1 deletion conda/recipes/cuvs/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
# Copyright (c) 2022-2024, NVIDIA CORPORATION.

# Usage:
# conda build . -c conda-forge -c numba -c rapidsai -c pytorch
Expand Down
5 changes: 1 addition & 4 deletions python/cuvs/cuvs/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,4 @@

from .temp_raft import auto_sync_resources


__all__ = [
"auto_sync_resources"
]
__all__ = ["auto_sync_resources"]
2 changes: 1 addition & 1 deletion python/cuvs/cuvs/common/c_api.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
# cython: language_level=3


from libc.stdint cimport uintptr_t
from cuda.ccudart cimport cudaStream_t
from libc.stdint cimport uintptr_t


cdef extern from "cuvs/core/c_api.h":
Expand Down
1 change: 1 addition & 0 deletions python/cuvs/cuvs/common/cydlpack.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from libc.stdint cimport int32_t, int64_t, uint8_t, uint16_t, uint64_t


cdef extern from 'dlpack.h' nogil:
ctypedef enum DLDeviceType:
kDLCPU
Expand Down
3 changes: 1 addition & 2 deletions python/cuvs/cuvs/common/cydlpack.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ cdef void deleter(DLManagedTensor* tensor) noexcept:


cdef DLManagedTensor dlpack_c(ary):
#todo(dgd): add checking options/parameters
# todo(dgd): add checking options/parameters
cdef DLDeviceType dev_type
cdef DLDevice dev
cdef DLDataType dtype
Expand Down Expand Up @@ -65,7 +65,6 @@ cdef DLManagedTensor dlpack_c(ary):
else:
tensor_ptr = ary.__array_interface__["data"][0]


tensor.data = <void*> tensor_ptr
tensor.device = dev
tensor.dtype = dtype
Expand Down
1 change: 0 additions & 1 deletion python/cuvs/cuvs/common/temp_raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from pylibraft.common import DeviceResources


_resources_param_string = """
handle : Optional RAFT resource handle for reusing CUDA resources.
If a handle isn't supplied, CUDA resources will be
Expand Down
2 changes: 1 addition & 1 deletion python/cuvs/cuvs/neighbors/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# =============================================================================
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
Expand Down
5 changes: 1 addition & 4 deletions python/cuvs/cuvs/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,4 @@

from cuvs.neighbors import cagra

__all__ = [
"common",
"cagra"
]
__all__ = ["common", "cagra"]
2 changes: 1 addition & 1 deletion python/cuvs/cuvs/neighbors/cagra/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# =============================================================================
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
Expand Down
17 changes: 10 additions & 7 deletions python/cuvs/cuvs/neighbors/cagra/cagra.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,17 @@
#
# cython: language_level=3

from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t, uint64_t, uintptr_t

from cuvs.common.cydlpack cimport DLDataType, DLManagedTensor
from libc.stdint cimport (
int8_t,
int64_t,
uint8_t,
uint32_t,
uint64_t,
uintptr_t,
)

from cuvs.common.c_api cimport cuvsError_t, cuvsResources_t
from cuvs.common.cydlpack cimport DLDataType, DLManagedTensor


cdef extern from "cuvs/neighbors/cagra_c.h" nogil:
Expand All @@ -28,14 +34,12 @@ cdef extern from "cuvs/neighbors/cagra_c.h" nogil:
IVF_PQ
NN_DESCENT


ctypedef struct cagraIndexParams:
size_t intermediate_graph_degree
size_t graph_degree
cagraGraphBuildAlgo build_algo
size_t nn_descent_niter


ctypedef enum cagraSearchAlgo:
SINGLE_CTA,
MULTI_CTA,
Expand Down Expand Up @@ -75,12 +79,11 @@ cdef extern from "cuvs/neighbors/cagra_c.h" nogil:
cuvsError_t cagraBuild(cuvsResources_t res,
cagraIndexParams* params,
DLManagedTensor* dataset,
cagraIndex_t index);
cagraIndex_t index)

cuvsError_t cagraSearch(cuvsResources_t res,
cagraSearchParams* params,
cagraIndex_t index,
DLManagedTensor* queries,
DLManagedTensor* neighbors,
DLManagedTensor* distances)

30 changes: 22 additions & 8 deletions python/cuvs/cuvs/neighbors/cagra/cagra.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
# cython: language_level=3

import numpy as np

cimport cuvs.common.cydlpack

from cuvs.common.temp_raft import auto_sync_resources
from cuvs.common cimport cydlpack

from cython.operator cimport dereference as deref

from cuvs.common cimport cydlpack

from pylibraft.common import (
DeviceResources,
auto_convert_output,
Expand All @@ -31,12 +33,19 @@ from pylibraft.common import (
)
from pylibraft.common.cai_wrapper import wrap_array
from pylibraft.common.interruptible import cuda_interruptible

from pylibraft.neighbors.common import _check_input_array
from cuvs.common.c_api cimport cuvsError_t, cuvsResources_t

from libc.stdint cimport (
int8_t,
int64_t,
uint8_t,
uint32_t,
uint64_t,
uintptr_t,
)
from pylibraft.common.handle cimport device_resources

from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t, uint64_t, uintptr_t
from cuvs.common.c_api cimport cuvsError_t, cuvsResources_t


cdef class IndexParams:
Expand Down Expand Up @@ -186,7 +195,8 @@ def build_index(IndexParams index_params, dataset, resources=None):

cdef Index idx = Index()
cdef cuvsError_t build_status
cdef cydlpack.DLManagedTensor dataset_dlpack = cydlpack.dlpack_c(dataset_ai)
cdef cydlpack.DLManagedTensor dataset_dlpack = \
cydlpack.dlpack_c(dataset_ai)
cdef cagraIndexParams* params = &index_params.params

with cuda_interruptible():
Expand Down Expand Up @@ -364,6 +374,7 @@ cdef class SearchParams:
def rand_xor_mask(self):
return self.params.rand_xor_mask


@auto_sync_resources
@auto_convert_output
def search(SearchParams search_params,
Expand Down Expand Up @@ -457,9 +468,12 @@ def search(SearchParams search_params,
exp_rows=n_queries, exp_cols=k)

cdef cagraSearchParams* params = &search_params.params
cdef cydlpack.DLManagedTensor queries_dlpack = cydlpack.dlpack_c(queries_cai)
cdef cydlpack.DLManagedTensor neighbors_dlpack = cydlpack.dlpack_c(neighbors_cai)
cdef cydlpack.DLManagedTensor distances_dlpack = cydlpack.dlpack_c(distances_cai)
cdef cydlpack.DLManagedTensor queries_dlpack = \
cydlpack.dlpack_c(queries_cai)
cdef cydlpack.DLManagedTensor neighbors_dlpack = \
cydlpack.dlpack_c(neighbors_cai)
cdef cydlpack.DLManagedTensor distances_dlpack = \
cydlpack.dlpack_c(distances_cai)

with cuda_interruptible():
cagraSearch(
Expand Down
2 changes: 1 addition & 1 deletion python/cuvs/cuvs/test/test_cagra.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

import numpy as np
import pytest
from pylibraft.common import device_ndarray
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import normalize

from pylibraft.common import device_ndarray
from cuvs.neighbors import cagra
from cuvs.test.ann_utils import calc_recall, generate_data

Expand Down
2 changes: 1 addition & 1 deletion python/cuvs/cuvs/test/test_doctests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down

0 comments on commit 3fbd19d

Please sign in to comment.