From 2f1157047c4db32897bcfe2c477c55c0708a34ec Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Thu, 29 Aug 2024 00:45:20 -0400 Subject: [PATCH 01/10] Update rapidsai/pre-commit-hooks (#6048) This PR updates rapidsai/pre-commit-hooks to the version 0.4.0. Authors: - Kyle Edwards (https://github.com/KyleFromNVIDIA) - https://github.com/jakirkham Approvers: - James Lamb (https://github.com/jameslamb) - https://github.com/jakirkham URL: https://github.com/rapidsai/cuml/pull/6048 --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bd25832ab7..af2b63a5f6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -60,7 +60,7 @@ repos: pass_filenames: false language: python - repo: https://github.com/rapidsai/pre-commit-hooks - rev: v0.3.1 + rev: v0.4.0 hooks: - id: verify-copyright files: | From d00c1b1bd0b0431a5d1ddcee6ee3ed32896b6778 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Tue, 3 Sep 2024 11:33:12 -0700 Subject: [PATCH 02/10] Fix compiler warning about signed vs unsigned ints (#6053) ``` src/cpp/src/fil/treelite_import.cu:496:26: warning: comparison of integer expressions of different signedness: 'size_t' {aka 'long unsigned int'} and 'const int' [-Wsign-compare] 103.2 496 | ASSERT(leaf_vec_size == model.num_class[0], "treelite model inconsistent"); 103.2 | ~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~ 103.2 /rapids_triton/build/_deps/cuml-src/cpp/src/fil/treelite_import.cu:516:40: warning: comparison of integer expressions of different signedness: 'const int' and 'size_t' {aka 'long unsigned int'} [-Wsign-compare] 103.2 516 | ASSERT(model.class_id[tree_id] == tree_id % static_cast(model.num_class[0]), 103.2 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ``` Authors: - Philip Hyunsu Cho (https://github.com/hcho3) Approvers: - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/cuml/pull/6053 --- cpp/src/fil/treelite_import.cu | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index bc3a13abb8..2a584c0095 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -490,10 +490,11 @@ void tl2fil_common(forest_params_t* params, ASSERT(model.num_target == 1, "FIL does not support multi-target models"); // assuming either all leaves use the .leaf_vector() or all leaves use .leaf_value() - size_t leaf_vec_size = tl_leaf_vector_size(model); + std::size_t leaf_vec_size = tl_leaf_vector_size(model); std::string pred_transform(model.postprocessor); if (leaf_vec_size > 0) { - ASSERT(leaf_vec_size == model.num_class[0], "treelite model inconsistent"); + ASSERT(leaf_vec_size == static_cast(model.num_class[0]), + "treelite model inconsistent"); params->num_classes = leaf_vec_size; params->leaf_algo = leaf_algo_t::VECTOR_LEAF; @@ -513,7 +514,8 @@ void tl2fil_common(forest_params_t* params, // Ensure that the trees follow the grove-per-class layout. for (size_t tree_id = 0; tree_id < model_preset.trees.size(); ++tree_id) { ASSERT(model.target_id[tree_id] == 0, "FIL does not support multi-target models"); - ASSERT(model.class_id[tree_id] == tree_id % static_cast(model.num_class[0]), + ASSERT(static_cast(model.class_id[tree_id]) == + tree_id % static_cast(model.num_class[0]), "The tree model is not compatible with FIL; the trees must be laid out " "such that tree i's output contributes towards class (i %% num_class)."); } From fbae844585cd312b45eaa0c057a74f9b85639b1e Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Tue, 3 Sep 2024 12:34:35 -0700 Subject: [PATCH 03/10] Update README in experimental FIL (#6052) `treelite::frontend::LoadXGBoostModel` is no longer present in the latest Treelite; use `treelite::model_loader::LoadXGBoostModelJSON` instead. Authors: - Philip Hyunsu Cho (https://github.com/hcho3) Approvers: - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/cuml/pull/6052 --- cpp/include/cuml/experimental/fil/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/cuml/experimental/fil/README.md b/cpp/include/cuml/experimental/fil/README.md index e195e6cd64..48d4a4ab16 100644 --- a/cpp/include/cuml/experimental/fil/README.md +++ b/cpp/include/cuml/experimental/fil/README.md @@ -39,7 +39,7 @@ similar load methods for each of the serialization formats it supports. ```cpp auto filename = "xgboost.json"; -auto tl_model = treelite::frontend::LoadXGBoostModel(filename); +auto tl_model = treelite::model_loader::LoadXGBoostModelJSON(filename, "{}"); ``` We then import the Treelite model into FIL via the From d87b0cedb755be7f1a488a294856f3b9669e242b Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Wed, 4 Sep 2024 10:02:56 -0500 Subject: [PATCH 04/10] Fix np.NAN to np.nan. (#6056) NumPy 2 requires `np.nan` instead of `np.NAN`. This appeared as a nightly test failure: ``` FAILED test_kernel_ridge.py::test_estimator - AttributeError: module 'numpy' has no attribute 'NAN' ``` Referencing the NumPy 2 changelog: https://numpy.org/devdocs/release/2.0.0-notes.html > Alias np.NaN has been removed. Use np.nan instead. Authors: - Bradley Dice (https://github.com/bdice) Approvers: - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/cuml/pull/6056 --- python/cuml/cuml/tests/test_kernel_ridge.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cuml/cuml/tests/test_kernel_ridge.py b/python/cuml/cuml/tests/test_kernel_ridge.py index d5534b5662..23148e7907 100644 --- a/python/cuml/cuml/tests/test_kernel_ridge.py +++ b/python/cuml/cuml/tests/test_kernel_ridge.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -46,7 +46,7 @@ def gradient_norm(model, X, y, K, sw=None): ).reshape(y.shape) # initialise to NaN in case below loop has 0 iterations - grads = cp.full_like(y, np.NAN) + grads = cp.full_like(y, np.nan) for i, (beta, target, current_alpha) in enumerate( zip(betas.T, y.T, model.alpha) ): From 8d062385d11beca378d40ae6d6cd5dc07073366b Mon Sep 17 00:00:00 2001 From: Divye Gala Date: Mon, 9 Sep 2024 12:23:31 -0400 Subject: [PATCH 05/10] Enable GPU `fit` and CPU `transform` in UMAP (#6032) Authors: - Divye Gala (https://github.com/divyegala) Approvers: - Victor Lafargue (https://github.com/viclafargue) URL: https://github.com/rapidsai/cuml/pull/6032 --- python/cuml/cuml/manifold/umap.pyx | 65 ++++++++++++++++++- .../cuml/cuml/tests/test_device_selection.py | 4 -- 2 files changed, 63 insertions(+), 6 deletions(-) diff --git a/python/cuml/cuml/manifold/umap.pyx b/python/cuml/cuml/manifold/umap.pyx index 86933ab31b..3aad3e76d3 100644 --- a/python/cuml/cuml/manifold/umap.pyx +++ b/python/cuml/cuml/manifold/umap.pyx @@ -16,9 +16,11 @@ # distutils: language = c++ -from cuml.internals.safe_imports import cpu_only_import +from cuml.internals.safe_imports import cpu_only_import, safe_import_from np = cpu_only_import('numpy') pd = cpu_only_import('pandas') +nearest_neighbors = safe_import_from('umap.umap_', 'nearest_neighbors') +DISCONNECTION_DISTANCES = safe_import_from('umap.umap_', 'DISCONNECTION_DISTANCES') import joblib import warnings @@ -627,6 +629,8 @@ class UMAP(UniversalBase, _knn_dists_ptr = knn_dists.ptr _knn_indices_ptr = knn_indices.ptr + self._knn_dists = knn_dists + self._knn_indices = knn_indices self.n_neighbors = min(self.n_rows, self.n_neighbors) @@ -853,6 +857,60 @@ class UMAP(UniversalBase, del X_m return embedding + @property + def _n_neighbors(self): + return self.n_neighbors + + @_n_neighbors.setter + def _n_neighbors(self, value): + self.n_neighbors = value + + @property + def _a(self): + return self.a + + @_a.setter + def _a(self, value): + self.a = value + + @property + def _b(self): + return self.b + + @_b.setter + def _b(self, value): + self.b = value + + @property + def _initial_alpha(self): + return self.learning_rate + + @_initial_alpha.setter + def _initial_alpha(self, value): + self.learning_rate = value + + @property + def _disconnection_distance(self): + self.disconnection_distance = DISCONNECTION_DISTANCES.get(self.metric, np.inf) + return self.disconnection_distance + + @_disconnection_distance.setter + def _disconnection_distance(self, value): + self.disconnection_distance = value + + def gpu_to_cpu(self): + if hasattr(self, 'knn_dists') and hasattr(self, 'knn_indices'): + self._knn_dists = self.knn_dists + self._knn_indices = self.knn_indices + self._knn_search_index = None + elif hasattr(self, '_raw_data'): + self._raw_data = self._raw_data.to_output('numpy') + self._knn_dists, self._knn_indices, self._knn_search_index = \ + nearest_neighbors(self._raw_data, self.n_neighbors, self.metric, + self.metric_kwds, False, self.random_state) + + super().gpu_to_cpu() + def get_param_names(self): return super().get_param_names() + [ "n_neighbors", @@ -883,4 +941,7 @@ class UMAP(UniversalBase, ] def get_attr_names(self): - return ['_raw_data', 'embedding_', '_input_hash', '_small_data'] + return ['_raw_data', 'embedding_', '_input_hash', '_small_data', + '_knn_dists', '_knn_indices', '_knn_search_index', + '_disconnection_distance', '_n_neighbors', '_a', '_b', + '_initial_alpha'] diff --git a/python/cuml/cuml/tests/test_device_selection.py b/python/cuml/cuml/tests/test_device_selection.py index 1da3b0738e..6c7d1852c1 100644 --- a/python/cuml/cuml/tests/test_device_selection.py +++ b/python/cuml/cuml/tests/test_device_selection.py @@ -596,8 +596,6 @@ def test_train_cpu_infer_cpu(test_data): def test_train_gpu_infer_cpu(test_data): cuEstimator = test_data["cuEstimator"] - if cuEstimator is UMAP: - pytest.skip("UMAP GPU training CPU inference not yet implemented") model = cuEstimator(**test_data["kwargs"]) with using_device_type("gpu"): @@ -655,8 +653,6 @@ def test_pickle_interop(tmp_path, test_data): pickle_filepath = tmp_path / "model.pickle" cuEstimator = test_data["cuEstimator"] - if cuEstimator is UMAP: - pytest.skip("UMAP GPU training CPU inference not yet implemented") model = cuEstimator(**test_data["kwargs"]) with using_device_type("gpu"): if "y_train" in test_data: From 488ed576f43dd8fe950cdc379b1b8b50486376fb Mon Sep 17 00:00:00 2001 From: Divye Gala Date: Tue, 10 Sep 2024 12:49:29 -0400 Subject: [PATCH 06/10] TSNE CPU/GPU Interop (#6063) As TSNE has no standalone `transform` function, this PR simply adds the ability to `fit/fit_transform` a CPU TSNE model Authors: - Divye Gala (https://github.com/divyegala) Approvers: - Victor Lafargue (https://github.com/viclafargue) URL: https://github.com/rapidsai/cuml/pull/6063 --- python/cuml/cuml/manifold/t_sne.pyx | 34 +++++++++++++++++-- .../cuml/cuml/tests/test_device_selection.py | 21 +++++++++++- 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/python/cuml/cuml/manifold/t_sne.pyx b/python/cuml/cuml/manifold/t_sne.pyx index 264722af76..d230ee8467 100644 --- a/python/cuml/cuml/manifold/t_sne.pyx +++ b/python/cuml/cuml/manifold/t_sne.pyx @@ -27,10 +27,13 @@ cupy = gpu_only_import('cupy') import cuml.internals from cuml.common.array_descriptor import CumlArrayDescriptor -from cuml.internals.base import Base +from cuml.internals.base import UniversalBase from pylibraft.common.handle cimport handle_t +from cuml.internals.api_decorators import device_interop_preparation +from cuml.internals.api_decorators import enable_device_interop import cuml.internals.logger as logger + from cuml.internals.array import CumlArray from cuml.internals.array_sparse import SparseCumlArray from cuml.common.sparse_utils import is_sparse @@ -115,7 +118,7 @@ cdef extern from "cuml/manifold/tsne.h" namespace "ML": float* kl_div) except + -class TSNE(Base, +class TSNE(UniversalBase, CMajorInputTagMixin): """ t-SNE (T-Distributed Stochastic Neighbor Embedding) is an extremely @@ -263,9 +266,11 @@ class TSNE(Base, """ + _cpu_estimator_import_path = 'sklearn.manifold.TSNE' X_m = CumlArrayDescriptor() embedding_ = CumlArrayDescriptor() + @device_interop_preparation def __init__(self, *, n_components=2, perplexity=30.0, @@ -405,6 +410,7 @@ class TSNE(Base, @generate_docstring(skip_parameters_heading=True, X='dense_sparse', convert_dtype_cast='np.float32') + @enable_device_interop def fit(self, X, convert_dtype=True, knn_graph=None) -> "TSNE": """ Fit X into an embedded space. @@ -444,6 +450,8 @@ class TSNE(Base, if convert_dtype else None)) + self.n_features_in_ = p + if n <= 1: raise ValueError("There needs to be more than 1 sample to build " "nearest the neighbors graph") @@ -561,6 +569,7 @@ class TSNE(Base, low-dimensional space.', 'shape': '(n_samples, n_components)'}) @cuml.internals.api_base_fit_transform() + @enable_device_interop def fit_transform(self, X, convert_dtype=True, knn_graph=None) -> CumlArray: """ @@ -648,6 +657,22 @@ class TSNE(Base, def kl_divergence_(self, value): self._kl_divergence_ = value + @property + def learning_rate_(self): + return self.learning_rate + + @learning_rate_.setter + def learning_rate_(self, value): + self.learning_rate = value + + @property + def n_iter_(self): + return self.n_iter + + @n_iter_.setter + def n_iter_(self, value): + self.n_iter = value + def __del__(self): if hasattr(self, "embedding_"): @@ -690,3 +715,8 @@ class TSNE(Base, "square_distances", "precomputed_knn" ] + + def get_attr_names(self): + return ["embedding", "kl_divergence_", + "n_features_in_", "learning_rate_", + "n_iter_"] diff --git a/python/cuml/cuml/tests/test_device_selection.py b/python/cuml/cuml/tests/test_device_selection.py index 6c7d1852c1..449c032161 100644 --- a/python/cuml/cuml/tests/test_device_selection.py +++ b/python/cuml/cuml/tests/test_device_selection.py @@ -21,7 +21,10 @@ from cuml.neighbors import NearestNeighbors from cuml.metrics import trustworthiness from cuml.metrics import adjusted_rand_score -from cuml.manifold import UMAP +from cuml.manifold import ( + UMAP, + TSNE, +) from cuml.linear_model import ( ElasticNet, Lasso, @@ -48,6 +51,7 @@ from sklearn.cluster import KMeans as skKMeans from sklearn.cluster import DBSCAN as skDBSCAN from sklearn.datasets import make_regression, make_blobs +from sklearn.manifold import TSNE as refTSNE from pytest_cases import fixture_union, fixture from importlib import import_module import inspect @@ -857,6 +861,21 @@ def test_umap_methods(device): assert ref_trust - tol <= trust <= ref_trust + tol +@pytest.mark.parametrize("device", ["cpu", "gpu"]) +def test_tsne_methods(device): + ref_model = refTSNE() + ref_embedding = ref_model.fit_transform(X_train_blob) + ref_trust = trustworthiness(X_train_blob, ref_embedding, n_neighbors=12) + + model = TSNE(n_neighbors=12) + with using_device_type(device): + embedding = model.fit_transform(X_train_blob) + trust = trustworthiness(X_train_blob, embedding, n_neighbors=12) + + tol = 0.02 + assert trust >= ref_trust - tol + + @pytest.mark.parametrize("train_device", ["cpu", "gpu"]) @pytest.mark.parametrize("infer_device", ["cpu", "gpu"]) def test_pca_methods(train_device, infer_device): From 28aa837c0e9e37164d08402c0b0972065374ed49 Mon Sep 17 00:00:00 2001 From: Victor Lafargue Date: Wed, 11 Sep 2024 01:28:39 +0200 Subject: [PATCH 07/10] Update UMAP doc (#6064) Closes https://github.com/rapidsai/cuml/issues/6062 Authors: - Victor Lafargue (https://github.com/viclafargue) Approvers: - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/cuml/pull/6064 --- python/cuml/cuml/dask/manifold/umap.py | 6 ++---- python/cuml/cuml/manifold/umap.pyx | 4 +--- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/python/cuml/cuml/dask/manifold/umap.py b/python/cuml/cuml/dask/manifold/umap.py index 9af1047050..181bfb0728 100644 --- a/python/cuml/cuml/dask/manifold/umap.py +++ b/python/cuml/cuml/dask/manifold/umap.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -83,9 +83,7 @@ class UMAP(BaseEstimator, DelayedTransformMixin): In addition to these missing features, you should expect to see the final embeddings differing between `cuml.umap` and the reference - UMAP. In particular, the reference UMAP uses an approximate kNN - algorithm for large data sizes while cuml.umap always uses exact - kNN. + UMAP. **Known issue:** If a UMAP model has not yet been fit, it cannot be pickled diff --git a/python/cuml/cuml/manifold/umap.pyx b/python/cuml/cuml/manifold/umap.pyx index 3aad3e76d3..260b32ee6b 100644 --- a/python/cuml/cuml/manifold/umap.pyx +++ b/python/cuml/cuml/manifold/umap.pyx @@ -315,9 +315,7 @@ class UMAP(UniversalBase, In addition to these missing features, you should expect to see the final embeddings differing between cuml.umap and the reference - UMAP. In particular, the reference UMAP uses an approximate kNN - algorithm for large data sizes while cuml.umap always uses exact - kNN. + UMAP. References ---------- From 90d06228a55dddc8cab11009c2c10ceeee8e8824 Mon Sep 17 00:00:00 2001 From: Victor Lafargue Date: Wed, 11 Sep 2024 17:49:39 +0200 Subject: [PATCH 08/10] Fix for `simplicial_set_embedding` (#6043) Answers https://github.com/rapidsai/cuml/issues/6041 and https://github.com/rapidsai/cuml/issues/6035 Authors: - Victor Lafargue (https://github.com/viclafargue) - Divye Gala (https://github.com/divyegala) Approvers: - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/cuml/pull/6043 --- cpp/include/cuml/manifold/umap.hpp | 21 ++++++ cpp/src/umap/runner.cuh | 19 +++++ cpp/src/umap/umap.cu | 14 ++++ python/cuml/cuml/manifold/simpl_set.pyx | 96 +++++++++++++++++------- python/cuml/cuml/tests/test_simpl_set.py | 12 +-- 5 files changed, 130 insertions(+), 32 deletions(-) diff --git a/cpp/include/cuml/manifold/umap.hpp b/cpp/include/cuml/manifold/umap.hpp index 62a875e685..7de08c5488 100644 --- a/cpp/include/cuml/manifold/umap.hpp +++ b/cpp/include/cuml/manifold/umap.hpp @@ -84,6 +84,27 @@ void refine(const raft::handle_t& handle, UMAPParams* params, float* embeddings); +/** + * Initializes embeddings and performs a UMAP fit on them, which enables + * iterative fitting without callbacks. + * + * @param[in] handle: raft::handle_t + * @param[in] X: pointer to input array + * @param[in] n: n_samples of input array + * @param[in] d: n_features of input array + * @param[in] graph: pointer to raft::sparse::COO object computed using ML::UMAP::get_graph + * @param[in] params: pointer to ML::UMAPParams object + * @param[out] embeddings: pointer to current embedding with shape n * n_components, stores updated + * embeddings on executing refine + */ +void init_and_refine(const raft::handle_t& handle, + float* X, + int n, + int d, + raft::sparse::COO* graph, + UMAPParams* params, + float* embeddings); + /** * Dense fit * diff --git a/cpp/src/umap/runner.cuh b/cpp/src/umap/runner.cuh index 41bac31678..0ceeb3acaa 100644 --- a/cpp/src/umap/runner.cuh +++ b/cpp/src/umap/runner.cuh @@ -247,12 +247,31 @@ void _refine(const raft::handle_t& handle, value_t* embeddings) { cudaStream_t stream = handle.get_stream(); + ML::Logger::get().setLevel(params->verbosity); + /** * Run simplicial set embedding to approximate low-dimensional representation */ SimplSetEmbed::run(inputs.n, inputs.d, graph, params, embeddings, stream); } +template +void _init_and_refine(const raft::handle_t& handle, + const umap_inputs& inputs, + UMAPParams* params, + raft::sparse::COO* graph, + value_t* embeddings) +{ + cudaStream_t stream = handle.get_stream(); + ML::Logger::get().setLevel(params->verbosity); + + // Initialize embeddings + InitEmbed::run(handle, inputs.n, inputs.d, graph, params, embeddings, stream, params->init); + + // Run simplicial set embedding + SimplSetEmbed::run(inputs.n, inputs.d, graph, params, embeddings, stream); +} + template void _fit(const raft::handle_t& handle, const umap_inputs& inputs, diff --git a/cpp/src/umap/umap.cu b/cpp/src/umap/umap.cu index 86799ae6bc..899051f8de 100644 --- a/cpp/src/umap/umap.cu +++ b/cpp/src/umap/umap.cu @@ -92,6 +92,20 @@ void refine(const raft::handle_t& handle, handle, inputs, params, graph, embeddings); } +void init_and_refine(const raft::handle_t& handle, + float* X, + int n, + int d, + raft::sparse::COO* graph, + UMAPParams* params, + float* embeddings) +{ + CUML_LOG_DEBUG("Calling UMAP::init_and_refine() with precomputed KNN"); + manifold_dense_inputs_t inputs(X, nullptr, n, d); + UMAPAlgo::_init_and_refine, TPB_X>( + handle, inputs, params, graph, embeddings); +} + void fit(const raft::handle_t& handle, float* X, float* y, diff --git a/python/cuml/cuml/manifold/simpl_set.pyx b/python/cuml/cuml/manifold/simpl_set.pyx index f22f524bf7..b0be2d5de7 100644 --- a/python/cuml/cuml/manifold/simpl_set.pyx +++ b/python/cuml/cuml/manifold/simpl_set.pyx @@ -16,6 +16,7 @@ # distutils: language = c++ +import warnings from cuml.internals.safe_imports import cpu_only_import np = cpu_only_import('numpy') from cuml.internals.safe_imports import gpu_only_import @@ -26,7 +27,7 @@ from cuml.manifold.umap_utils cimport * from cuml.manifold.umap_utils import GraphHolder, find_ab_params, \ metric_parsing -from cuml.internals.input_utils import input_to_cuml_array +from cuml.internals.input_utils import input_to_cuml_array, is_array_like from cuml.internals.array import CumlArray from pylibraft.common.handle cimport handle_t @@ -56,6 +57,14 @@ cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP": UMAPParams* params, float* embeddings) + void init_and_refine(handle_t &handle, + float* X, + int n, + int d, + COO* cgraph_coo, + UMAPParams* params, + float* embeddings) + def fuzzy_simplicial_set(X, n_neighbors, @@ -73,6 +82,7 @@ def fuzzy_simplicial_set(X, locally approximating geodesic distance at each point, creating a fuzzy simplicial set for each such point, and then combining all the local fuzzy simplicial sets into a global one via a fuzzy union. + Parameters ---------- X: array of shape (n_samples, n_features) @@ -212,7 +222,7 @@ def simplicial_set_embedding( initial_alpha=1.0, a=None, b=None, - repulsion_strength=1.0, + gamma=1.0, negative_sample_rate=5, n_epochs=None, init="spectral", @@ -221,6 +231,7 @@ def simplicial_set_embedding( metric_kwds=None, output_metric="euclidean", output_metric_kwds=None, + repulsion_strength=None, convert_dtype=True, verbose=False, ): @@ -228,6 +239,7 @@ def simplicial_set_embedding( initialisation method and then minimizing the fuzzy set cross entropy between the 1-skeletons of the high and low dimensional fuzzy simplicial sets. + Parameters ---------- data: array of shape (n_samples, n_features) @@ -244,7 +256,7 @@ def simplicial_set_embedding( Parameter of differentiable approximation of right adjoint functor b: float Parameter of differentiable approximation of right adjoint functor - repulsion_strength: float + gamma: float Weight to apply to negative samples. negative_sample_rate: int (optional, default 5) The number of negative samples to select per positive sample @@ -260,7 +272,7 @@ def simplicial_set_embedding( How to initialize the low dimensional embedding. Options are: * 'spectral': use a spectral embedding of the fuzzy 1-skeleton * 'random': assign initial embedding positions at random. - * A numpy array of initial embedding positions. + * An array-like with initial embedding positions. random_state: numpy RandomState or equivalent A state capable being used as a numpy random state. metric: string (default='euclidean'). @@ -294,9 +306,6 @@ def simplicial_set_embedding( if output_metric_kwds is None: output_metric_kwds = {} - if init not in ['spectral', 'random']: - raise Exception("Initialization strategy not supported: %d" % init) - if output_metric not in ['euclidean', 'categorical']: raise Exception("Invalid output metric: {}" % output_metric) @@ -320,17 +329,29 @@ def simplicial_set_embedding( cdef UMAPParams* umap_params = new UMAPParams() umap_params.n_components = n_components umap_params.initial_alpha = initial_alpha - umap_params.a = a - umap_params.b = b - umap_params.repulsion_strength = repulsion_strength + umap_params.a = a + umap_params.b = b + + if repulsion_strength: + gamma = repulsion_strength + warnings.simplefilter(action="always", category=FutureWarning) + warnings.warn('Parameter "repulsion_strength" has been' + ' deprecated. It will be removed in version 24.12.' + ' Please use the "gamma" parameter instead.', + FutureWarning) + + umap_params.repulsion_strength = gamma umap_params.negative_sample_rate = negative_sample_rate umap_params.n_epochs = n_epochs - if init == 'spectral': - umap_params.init = 1 - else: # init == 'random' - umap_params.init = 0 umap_params.random_state = random_state umap_params.deterministic = deterministic + if isinstance(init, str): + if init == "random": + umap_params.init = 0 + elif init == 'spectral': + umap_params.init = 1 + else: + raise ValueError("Invalid initialization strategy") try: umap_params.metric = metric_parsing[metric.lower()] except KeyError: @@ -344,7 +365,7 @@ def simplicial_set_embedding( else: # output_metric == 'categorical' umap_params.target_metric = MetricType.CATEGORICAL umap_params.target_weight = output_metric_kwds['p'] \ - if 'p' in output_metric_kwds else 0 + if 'p' in output_metric_kwds else 0.5 umap_params.verbosity = verbose X_m, _, _, _ = \ @@ -365,17 +386,40 @@ def simplicial_set_embedding( handle, graph) - embedding = CumlArray.zeros((X_m.shape[0], n_components), - order="C", dtype=np.float32, - index=X_m.index) - - refine(handle_[0], - X_m.ptr, - X_m.shape[0], - X_m.shape[1], - fss_graph.get(), - umap_params, - embedding.ptr) + if isinstance(init, str): + if init in ['spectral', 'random']: + embedding = CumlArray.zeros((X_m.shape[0], n_components), + order="C", dtype=np.float32, + index=X_m.index) + init_and_refine(handle_[0], + X_m.ptr, + X_m.shape[0], + X_m.shape[1], + fss_graph.get(), + umap_params, + embedding.ptr) + else: + raise ValueError("Invalid initialization strategy") + elif is_array_like(init): + embedding, _, _, _ = \ + input_to_cuml_array(init, + order='C', + convert_to_dtype=(np.float32 if convert_dtype + else None), + check_dtype=np.float32, + check_rows=X_m.shape[0], + check_cols=n_components) + refine(handle_[0], + X_m.ptr, + X_m.shape[0], + X_m.shape[1], + fss_graph.get(), + umap_params, + embedding.ptr) + else: + raise ValueError( + "Initialization not supported. Please provide a valid " + "initialization strategy or a pre-initialized embedding.") free(umap_params) diff --git a/python/cuml/cuml/tests/test_simpl_set.py b/python/cuml/cuml/tests/test_simpl_set.py index cbc5ebc635..7f55155a9f 100644 --- a/python/cuml/cuml/tests/test_simpl_set.py +++ b/python/cuml/cuml/tests/test_simpl_set.py @@ -24,6 +24,7 @@ import pytest from cuml.datasets import make_blobs from cuml.internals.safe_imports import cpu_only_import +from cuml.metrics import trustworthiness np = cpu_only_import("numpy") cp = gpu_only_import("cupy") @@ -133,7 +134,7 @@ def test_simplicial_set_embedding( metric = "euclidean" initial_alpha = 1.0 a, b = UMAP.find_ab_params(1.0, 0.1) - gamma = 0 + gamma = 1.0 negative_sample_rate = 5 n_epochs = 500 init = "random" @@ -180,7 +181,6 @@ def test_simplicial_set_embedding( cu_fss_graph = cu_fuzzy_simplicial_set( X, n_neighbors, random_state, metric ) - cu_embedding = cu_simplicial_set_embedding( X, cu_fss_graph, @@ -199,7 +199,7 @@ def test_simplicial_set_embedding( output_metric_kwds=output_metric_kwds, ) - ref_embedding = cp.array(ref_embedding) - assert correctness_dense( - ref_embedding, cu_embedding, rtol=0.1, threshold=0.95 - ) + ref_t_score = trustworthiness(X, ref_embedding, n_neighbors=n_neighbors) + t_score = trustworthiness(X, cu_embedding, n_neighbors=n_neighbors) + abs_tol = 0.05 + assert t_score >= ref_t_score - abs_tol From cd19b30c080542ccee90a68d419a22d5816a929f Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Wed, 11 Sep 2024 22:29:19 -0500 Subject: [PATCH 09/10] Set default values for conftest options. (#6067) There is no default being set for `run_ucx` or `run_ucxx`. This results in an error on Python 3.12 (maybe other versions too): ``` AttributeError: 'Namespace' object has no attribute 'run_ucx'. Did you mean: 'run_unit'? ``` This PR adds a default value for those conftest options. Authors: - Bradley Dice (https://github.com/bdice) Approvers: - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/cuml/pull/6067 --- python/cuml/cuml/tests/dask/conftest.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/cuml/cuml/tests/dask/conftest.py b/python/cuml/cuml/tests/dask/conftest.py index 27fb746e1c..bdaf591538 100644 --- a/python/cuml/cuml/tests/dask/conftest.py +++ b/python/cuml/cuml/tests/dask/conftest.py @@ -72,11 +72,17 @@ def pytest_addoption(parser): group = parser.getgroup("Dask cuML Custom Options") group.addoption( - "--run_ucx", action="store_true", help="run _only_ UCX-Py tests" + "--run_ucx", + action="store_true", + default=False, + help="run _only_ UCX-Py tests", ) group.addoption( - "--run_ucxx", action="store_true", help="run _only_ UCXX tests" + "--run_ucxx", + action="store_true", + default=False, + help="run _only_ UCXX tests", ) From 28641bb090c0432efa006a55cb755d32f3799a0c Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 12 Sep 2024 09:03:13 -0500 Subject: [PATCH 10/10] Add support for Python 3.12, update to umap-learn==0.5.6 (#6060) Contributes to https://github.com/rapidsai/build-planning/issues/40 This PR adds support for Python 3.12. Other changes required for this: * updating `umap-learn`, `0.5.3 -> 0.5.6` (https://github.com/rapidsai/cuml/pull/6060/files#r1745915933) ## Notes for Reviewers This is part of ongoing work to add Python 3.12 support across RAPIDS. It temporarily introduces a build/test matrix including Python 3.12, from https://github.com/rapidsai/shared-workflows/pull/213. A follow-up PR will revert back to pointing at the `branch-24.10` branch of `shared-workflows` once all RAPIDS repos have added Python 3.12 support. ### This will fail until all dependencies have been updates to Python 3.12 CI here is expected to fail until all of this project's upstream dependencies support Python 3.12. This can be merged whenever all CI jobs are passing. Authors: - James Lamb (https://github.com/jameslamb) - Bradley Dice (https://github.com/bdice) Approvers: - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cuml/pull/6060 --- .github/workflows/build.yaml | 12 ++++---- .github/workflows/pr.yaml | 30 +++++++++---------- .github/workflows/test.yaml | 10 +++---- BUILD.md | 2 +- .../all_cuda-118_arch-x86_64.yaml | 5 ++-- .../all_cuda-125_arch-x86_64.yaml | 5 ++-- conda/recipes/cuml-cpu/meta.yaml | 2 +- dependencies.yaml | 9 ++++-- python/cuml/pyproject.toml | 4 ++- 9 files changed, 44 insertions(+), 35 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index fcc0aec68c..bf3fb52c50 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -28,7 +28,7 @@ concurrency: jobs: cpp-build: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@python-3.12 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -38,7 +38,7 @@ jobs: if: github.ref_type == 'branch' needs: [python-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@python-3.12 with: arch: "amd64" branch: ${{ inputs.branch }} @@ -51,7 +51,7 @@ jobs: python-build: needs: [cpp-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@python-3.12 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -60,7 +60,7 @@ jobs: upload-conda: needs: [cpp-build, python-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@python-3.12 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -68,7 +68,7 @@ jobs: sha: ${{ inputs.sha }} wheel-build-cuml: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@python-3.12 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -84,7 +84,7 @@ jobs: wheel-publish-cuml: needs: wheel-build-cuml secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@python-3.12 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index c0c0601afb..d28a073775 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -26,10 +26,10 @@ jobs: - wheel-tests-cuml - devcontainer secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@python-3.12 checks: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@python-3.12 with: enable_check_generated_files: false ignored_pr_jobs: >- @@ -37,7 +37,7 @@ jobs: clang-tidy: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@python-3.12 with: build_type: pull-request node_type: "cpu8" @@ -47,19 +47,19 @@ jobs: conda-cpp-build: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@python-3.12 with: build_type: pull-request conda-cpp-tests: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@python-3.12 with: build_type: pull-request conda-cpp-checks: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-post-build-checks.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-post-build-checks.yaml@python-3.12 with: build_type: pull-request enable_check_symbols: true @@ -67,20 +67,20 @@ jobs: conda-python-build: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@python-3.12 with: build_type: pull-request conda-python-tests-singlegpu: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@python-3.12 with: build_type: pull-request script: "ci/test_python_singlegpu.sh" optional-job-conda-python-tests-cudf-pandas-integration: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@python-3.12 with: matrix_filter: map(select(.ARCH == "amd64")) build_type: pull-request @@ -88,14 +88,14 @@ jobs: conda-python-tests-dask: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@python-3.12 with: build_type: pull-request script: "ci/test_python_dask.sh" conda-notebook-tests: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@python-3.12 with: build_type: pull-request node_type: "gpu-v100-latest-1" @@ -105,7 +105,7 @@ jobs: docs-build: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@python-3.12 with: build_type: pull-request node_type: "gpu-v100-latest-1" @@ -115,7 +115,7 @@ jobs: wheel-build-cuml: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@python-3.12 with: build_type: pull-request script: ci/build_wheel.sh @@ -125,13 +125,13 @@ jobs: wheel-tests-cuml: needs: wheel-build-cuml secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@python-3.12 with: build_type: pull-request script: ci/test_wheel.sh devcontainer: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/build-in-devcontainer.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/build-in-devcontainer.yaml@python-3.12 with: arch: '["amd64"]' cuda: '["12.5"]' diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 9a2c0086ea..85f10e134b 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -16,7 +16,7 @@ on: jobs: conda-cpp-checks: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-post-build-checks.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-post-build-checks.yaml@python-3.12 with: build_type: nightly branch: ${{ inputs.branch }} @@ -26,7 +26,7 @@ jobs: symbol_exclusions: raft_cutlass conda-cpp-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@python-3.12 with: build_type: nightly branch: ${{ inputs.branch }} @@ -34,7 +34,7 @@ jobs: sha: ${{ inputs.sha }} conda-python-tests-singlegpu: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@python-3.12 with: build_type: nightly branch: ${{ inputs.branch }} @@ -43,7 +43,7 @@ jobs: script: "ci/test_python_singlegpu.sh" conda-python-tests-dask: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@python-3.12 with: build_type: nightly branch: ${{ inputs.branch }} @@ -52,7 +52,7 @@ jobs: script: "ci/test_python_dask.sh" wheel-tests-cuml: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@python-3.12 with: build_type: nightly branch: ${{ inputs.branch }} diff --git a/BUILD.md b/BUILD.md index 4bc8310407..059836e57d 100644 --- a/BUILD.md +++ b/BUILD.md @@ -18,7 +18,7 @@ To install cuML from source, ensure the following dependencies are met: It is recommended to use conda for environment/package management. If doing so, development environment .yaml files are located in `conda/environments/all_*.yaml`. These files contains most of the dependencies mentioned above (notable exceptions are `gcc` and `zlib`). To create a development environment named `cuml_dev`, you can use the follow commands: ```bash -conda create -n cuml_dev python=3.11 +conda create -n cuml_dev python=3.12 conda env update -n cuml_dev --file=conda/environments/all_cuda-118_arch-x86_64.yaml conda activate cuml_dev ``` diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 363ce4f13e..e7dcb0a323 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -58,7 +58,7 @@ dependencies: - pytest-cov - pytest-xdist - pytest==7.* -- python>=3.10,<3.12 +- python>=3.10,<3.13 - raft-dask==24.10.*,>=0.0.0a0 - rapids-build-backend>=0.3.0,<0.4.0.dev0 - rapids-dask-dependency==24.10.*,>=0.0.0a0 @@ -68,13 +68,14 @@ dependencies: - scikit-learn==1.5 - scipy>=1.8.0 - seaborn +- setuptools - sphinx-copybutton - sphinx-markdown-tables - sphinx<6 - statsmodels - sysroot_linux-64==2.17 - treelite==4.3.0 -- umap-learn==0.5.3 +- umap-learn==0.5.6 - pip: - dask-glm==0.3.0 name: all_cuda-118_arch-x86_64 diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index b70186deab..2340040085 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -54,7 +54,7 @@ dependencies: - pytest-cov - pytest-xdist - pytest==7.* -- python>=3.10,<3.12 +- python>=3.10,<3.13 - raft-dask==24.10.*,>=0.0.0a0 - rapids-build-backend>=0.3.0,<0.4.0.dev0 - rapids-dask-dependency==24.10.*,>=0.0.0a0 @@ -64,13 +64,14 @@ dependencies: - scikit-learn==1.5 - scipy>=1.8.0 - seaborn +- setuptools - sphinx-copybutton - sphinx-markdown-tables - sphinx<6 - statsmodels - sysroot_linux-64==2.17 - treelite==4.3.0 -- umap-learn==0.5.3 +- umap-learn==0.5.6 - pip: - dask-glm==0.3.0 name: all_cuda-125_arch-x86_64 diff --git a/conda/recipes/cuml-cpu/meta.yaml b/conda/recipes/cuml-cpu/meta.yaml index bf59fed151..97e5cdd813 100644 --- a/conda/recipes/cuml-cpu/meta.yaml +++ b/conda/recipes/cuml-cpu/meta.yaml @@ -35,7 +35,7 @@ requirements: - pandas - scikit-learn=1.2 - hdbscan>=0.8.38,<0.8.39 - - umap-learn=0.5.3 + - umap-learn=0.5.6 - nvtx tests: # [linux64] diff --git a/dependencies.yaml b/dependencies.yaml index e3a045efea..23a72d1db8 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -472,8 +472,12 @@ dependencies: packages: - python=3.11 - matrix: + py: "3.12" packages: - - python>=3.10,<3.12 + - python=3.12 + - matrix: + packages: + - python>=3.10,<3.13 test_libcuml: common: - output_types: conda @@ -509,8 +513,9 @@ dependencies: - seaborn - *scikit_learn - statsmodels - - umap-learn==0.5.3 + - umap-learn==0.5.6 - pynndescent + - setuptools # Needed on Python 3.12 for dask-glm, which requires pkg_resources but Python 3.12 doesn't have setuptools by default - output_types: conda packages: - pip diff --git a/python/cuml/pyproject.toml b/python/cuml/pyproject.toml index 4149c721ce..8934a0f226 100644 --- a/python/cuml/pyproject.toml +++ b/python/cuml/pyproject.toml @@ -105,6 +105,7 @@ classifiers = [ "Programming Language :: Python", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ] [project.optional-dependencies] @@ -124,8 +125,9 @@ test = [ "pytest==7.*", "scikit-learn==1.5", "seaborn", + "setuptools", "statsmodels", - "umap-learn==0.5.3", + "umap-learn==0.5.6", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. [project.urls]