Skip to content

Commit

Permalink
Fixes for sparse PCA
Browse files Browse the repository at this point in the history
  • Loading branch information
Intron7 committed Jul 6, 2023
1 parent e350b99 commit 4943285
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 3 deletions.
6 changes: 4 additions & 2 deletions python/cuml/decomposition/pca.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ from cuml.internals.input_utils import input_to_cuml_array
from cuml.internals.input_utils import input_to_cupy_array
from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.common import using_output_type
from cuml.prims.stats import cov
from cuml.prims.stats import cov, cov_sparse
from cuml.internals.input_utils import sparse_scipy_to_cp
from cuml.common.exceptions import NotFittedError
from cuml.internals.mixins import FMajorInputTagMixin
Expand Down Expand Up @@ -369,7 +369,7 @@ class PCA(UniversalBase,
# NOTE: All intermediate calculations are done using cupy.ndarray and
# then converted to CumlArray at the end to minimize conversions
# between types
covariance, self.mean_, _ = cov(X, X, return_mean=True)
covariance, self.mean = cov_sparse(X)

self.explained_variance_, self.components_ = \
cp.linalg.eigh(covariance, UPLO='U')
Expand Down Expand Up @@ -428,9 +428,11 @@ class PCA(UniversalBase,
self.n_components_ = self.n_components

if cupyx.scipy.sparse.issparse(X):
X = X.tocsr()
return self._sparse_fit(X)
elif scipy.sparse.issparse(X):
X = sparse_scipy_to_cp(X, dtype=None)
X = X.tocsr()
return self._sparse_fit(X)

X_m, self.n_samples_, self.n_features_in_, self.dtype = \
Expand Down
2 changes: 1 addition & 1 deletion python/cuml/prims/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
# limitations under the License.
#

from cuml.prims.stats.covariance import cov
from cuml.prims.stats.covariance import cov, cov_sparse
80 changes: 80 additions & 0 deletions python/cuml/prims/stats/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,35 @@
}
"""

# CUDA kernel body for a CSR matrix: one thread per row accumulates the
# per-column sums into ``mean`` and the products of every pair of stored
# entries of that row into the dense (ncols x ncols) buffer ``out``.
# For a pair of distinct columns only one of out[i, j] / out[j, i] is written,
# so the caller has to symmetrise ``out`` afterwards.
# ``{0}`` marks the element type; presumably it is substituted by
# ``cuda_kernel_factory`` when the kernel is built -- TODO confirm.
mean_cov_kernel_str = r"""
(const int *indptr, const int *index, {0} *data, int nrows, int ncols,
 {0} *out, {0} *mean) {
    int row = blockDim.x * blockIdx.x + threadIdx.x;
    if (row >= nrows) {
        return;
    }
    int start_idx = indptr[row];
    int stop_idx = indptr[row + 1];
    for (int idx = start_idx; idx < stop_idx; idx++) {
        int index1 = index[idx];
        {0} data1 = data[idx];
        // Flat offsets are computed in size_t: a 32-bit int product of
        // index1 and ncols overflows once ncols exceeds roughly 46k.
        size_t row_offset = (size_t)index1 * ncols;
        atomicAdd(&out[row_offset + index1], data1 * data1);
        atomicAdd(&mean[index1], data1);
        for (int idx2 = idx + 1; idx2 < stop_idx; idx2++) {
            int index2 = index[idx2];
            {0} data2 = data[idx2];
            atomicAdd(&out[row_offset + index2], data1 * data2);
        }
    }
}
"""


def _cov_kernel(dtype):
    """Build the ``cov_kernel`` CUDA kernel for the given element dtype.

    The kernel source is ``cov_kernel_str``; ``dtype`` is handed to
    ``cuda_kernel_factory`` as the single type argument.
    """
    type_args = (dtype,)
    kernel = cuda_kernel_factory(cov_kernel_str, type_args, "cov_kernel")
    return kernel

def _mean_cov_kernel(dtype):
    """Build the sparse mean/gram CUDA kernel for the given element dtype.

    The kernel source is ``mean_cov_kernel_str``; ``dtype`` is handed to
    ``cuda_kernel_factory`` as the single type argument.
    """
    type_args = (dtype,)
    kernel = cuda_kernel_factory(
        mean_cov_kernel_str, type_args, "_mean_cov_kernel"
    )
    return kernel

@cuml.internals.api_return_any()
def cov(x, y, mean_x=None, mean_y=None, return_gram=False, return_mean=False):
Expand Down Expand Up @@ -156,3 +181,58 @@ def cov(x, y, mean_x=None, mean_y=None, return_gram=False, return_mean=False):
return cov_result, mean_x, mean_y
elif return_gram and return_mean:
return cov_result, gram_matrix, mean_x, mean_y

@cuml.internals.api_return_any()
def cov_sparse(x):
    """
    Computes the covariance of the columns of a sparse matrix with itself
    using the form Cov(X, X) = E(XX) - E(X)E(X)

    Both expectations are taken over the m rows, i.e. the sums are scaled
    by 1 / m (not 1 / (m - 1)).

    Parameters
    ----------
    x : cupyx.scipy.sparse matrix of size (m, n)
        Input in any sparse format; it is converted to CSR when needed
        because the accumulation kernel reads ``indptr`` / ``indices``
        as CSR. NOTE: the kernel appears to assume that a row holds at
        most one stored entry per column; call ``sum_duplicates()``
        beforehand if duplicates can occur.

    Returns
    -------
    cov_result : cupy.ndarray of size (n, n)
        cov(X, X)
    mean_x : cupy.ndarray of size (n,)
        mean(X), the per-column mean
    """

    # A non-CSR layout would be misread by the kernel, so normalise first.
    if x.format != "csr":
        x = x.tocsr()

    n_rows, n_cols = x.shape
    dtype = x.data.dtype

    gram_matrix = cp.zeros((n_cols, n_cols), dtype=dtype)
    mean_x = cp.zeros((n_cols,), dtype=dtype)

    # One thread per row of x.
    block = (8,)
    grid = (math.ceil(n_rows / block[0]),)
    compute_mean_cov = _mean_cov_kernel(dtype)
    compute_mean_cov(
        grid,
        block,
        (x.indptr, x.indices, x.data, n_rows, n_cols, gram_matrix, mean_x),
    )

    # The kernel writes each off-diagonal product to only one of [i, j] and
    # [j, i]. Adding the transpose mirrors it, but also doubles the
    # diagonal, so half of the doubled diagonal is subtracted again.
    gram_matrix = gram_matrix + gram_matrix.T
    gram_matrix -= cp.diag(cp.diag(gram_matrix) / 2)

    # Turn the accumulated sums into E(XX) and E(X).
    gram_matrix *= (1 / n_rows)
    mean_x *= (1 / n_rows)

    # The covariance is written over the gram matrix in place: input and
    # output of the kernel call below are deliberately the same buffer.
    cov_result = gram_matrix

    compute_cov = _cov_kernel(dtype)

    block_size = (8, 8)
    grid_size = (
        math.ceil(gram_matrix.shape[0] / 8),
        math.ceil(gram_matrix.shape[1] / 8),
    )

    compute_cov(
        grid_size,
        block_size,
        (cov_result, gram_matrix, mean_x, mean_x, gram_matrix.shape[0]),
    )

    return cov_result, mean_x

0 comments on commit 4943285

Please sign in to comment.