44from typing import TYPE_CHECKING
55
66import cudf
7+ import cupy as cp
78import numpy as np
89import pandas as pd
910from natsort import natsorted
1516if TYPE_CHECKING :
1617 from collections .abc import Sequence
1718
18- import cupy as cp
1919 from anndata import AnnData
2020 from scipy import sparse
2121
@@ -63,9 +63,64 @@ def _create_graph(adjacency, dtype=np.float64, *, use_weights=True):
6363 return g
6464
6565
def _create_graph_dask(adjacency, dtype=np.float64, *, use_weights=True):
    """Build a multi-GPU cugraph ``Graph`` from a CSR adjacency matrix via Dask.

    Parameters
    ----------
    adjacency
        Sparse CSR adjacency matrix (must expose ``indptr``, ``indices``,
        ``data``, ``nnz`` and ``shape``).
    dtype
        Data type used for the edge weights.
    use_weights
        If ``True``, attach ``adjacency.data`` as edge weights; otherwise
        build an unweighted graph.

    Returns
    -------
    A ``cugraph.Graph`` populated from a dask-cudf edge list, partitioned
    across all visible GPUs. ``Comms.initialize`` is called here; the caller
    is responsible for ``Comms.destroy()`` when done.
    """
    import dask.dataframe as dd
    from cugraph import Graph

    # Expand CSR structure into COO-style (src, dst, weight) arrays on host.
    rows = np.repeat(np.arange(adjacency.shape[0]), np.diff(adjacency.indptr)).astype(
        np.int32
    )
    cols = adjacency.indices
    weights = adjacency.data

    # One chunk per GPU. Use integer ceil division: the previous
    # float-based `int((nnz + n - 1) / n)` loses precision for very large
    # nnz and produces chunksize == 0 when nnz == 0, which would make
    # `range(..., step=0)` raise ValueError. Floor at 1 to keep the step valid.
    n_devices = cp.cuda.runtime.getDeviceCount()
    chunksize = max(1, -(-adjacency.nnz // n_devices))

    pairs = [
        (start, min(start + chunksize, adjacency.nnz))
        for start in range(0, adjacency.nnz, chunksize)
    ]

    def mapper(pair):
        start, end = pair
        return cudf.DataFrame(
            {
                "src": rows[start:end].astype(np.int64),
                "dst": cols[start:end].astype(np.int64),
                "weight": weights[start:end].astype(dtype),
            }
        )

    # meta must match the actual columns produced by `mapper`
    meta = {
        "src": np.int64,
        "dst": np.int64,
        "weight": dtype,
    }

    ddf = dd.from_map(mapper, pairs, meta=meta).to_backend("cudf").persist()
    import cugraph.dask.comms.comms as Comms

    Comms.initialize(p2p=True)
    g = Graph()
    if use_weights:
        g.from_dask_cudf_edgelist(
            ddf,
            source="src",
            destination="dst",
            weight="weight",
        )
    else:
        g.from_dask_cudf_edgelist(
            ddf,
            source="src",
            destination="dst",
        )
    return g
119+
120+
66121def leiden (
67122 adata : AnnData ,
68- resolution : float = 1.0 ,
123+ resolution : float | list [ float ] = 1.0 ,
69124 * ,
70125 random_state : int | None = 0 ,
71126 theta : float = 1.0 ,
@@ -77,6 +132,7 @@ def leiden(
77132 neighbors_key : str | None = None ,
78133 obsp : str | None = None ,
79134 dtype : str | np .dtype | cp .dtype = np .float32 ,
135+ use_dask : bool = False ,
80136 copy : bool = False ,
81137) -> AnnData | None :
82138 """
@@ -93,9 +149,9 @@ def leiden(
93149 annData object
94150
95151 resolution
96- A parameter value controlling the coarseness of the clustering.
152+ A parameter value or a list of parameter values controlling the coarseness of the clustering.
97153 (called gamma in the modularity formula). Higher values lead to
98- more clusters.
154+ more clusters. If a list of values is provided, the Leiden algorithm will be run for each value in the list.
99155
100156 random_state
101157 Change the initialization of the optimization. Defaults to 0.
@@ -140,11 +196,13 @@ def leiden(
140196 dtype
141197 Data type to use for the adjacency matrix.
142198
199+ use_dask
200+ If `True`, use Dask to create the graph and cluster. This will use all GPUs available. This feature is experimental. For datasets with fewer than 10 million cells, it is recommended to use `use_dask=False`.
201+
143202 copy
144203 Whether to copy `adata` or modify it in place.
145204 """
146205 # Adjacency graph
147- from cugraph import leiden as culeiden
148206
149207 adata = adata .copy () if copy else adata
150208
@@ -160,40 +218,61 @@ def leiden(
160218 restrict_categories = restrict_categories ,
161219 adjacency = adjacency ,
162220 )
221+ if use_dask :
222+ from cugraph .dask import leiden as culeiden
223+
224+ g = _create_graph_dask (adjacency , dtype , use_weights = use_weights )
225+ else :
226+ from cugraph import leiden as culeiden
163227
164- g = _create_graph (adjacency , dtype , use_weights = use_weights )
228+ g = _create_graph (adjacency , dtype , use_weights = use_weights )
165229 # Cluster
166- leiden_parts , _ = culeiden (
167- g ,
168- resolution = resolution ,
169- random_state = random_state ,
170- theta = theta ,
171- max_iter = n_iterations ,
172- )
230+ if isinstance (resolution , float | int ):
231+ resolutions = [resolution ]
232+ else :
233+ resolutions = resolution
234+ for resolution in resolutions :
235+ leiden_parts , _ = culeiden (
236+ g ,
237+ resolution = resolution ,
238+ random_state = random_state ,
239+ theta = theta ,
240+ max_iter = n_iterations ,
241+ )
242+ if use_dask :
243+ leiden_parts = leiden_parts .to_backend ("pandas" ).compute ()
244+ else :
245+ leiden_parts = leiden_parts .to_pandas ()
246+
247+ # Format output
248+ groups = leiden_parts .sort_values ("vertex" )[["partition" ]].to_numpy ().ravel ()
249+ key_added_to_use = key_added
250+ if restrict_to is not None :
251+ if key_added == "leiden" :
252+ key_added_to_use += "_R"
253+ groups = rename_groups (
254+ adata ,
255+ key_added = key_added_to_use ,
256+ restrict_key = restrict_key ,
257+ restrict_categories = restrict_categories ,
258+ restrict_indices = restrict_indices ,
259+ groups = groups ,
260+ )
261+ if len (resolutions ) > 1 :
262+ key_added_to_use += f"_{ resolution } "
173263
174- # Format output
175- groups = (
176- leiden_parts .to_pandas ().sort_values ("vertex" )[["partition" ]].to_numpy ().ravel ()
177- )
178- if restrict_to is not None :
179- if key_added == "leiden" :
180- key_added += "_R"
181- groups = rename_groups (
182- adata ,
183- key_added = key_added ,
184- restrict_key = restrict_key ,
185- restrict_categories = restrict_categories ,
186- restrict_indices = restrict_indices ,
187- groups = groups ,
264+ adata .obs [key_added_to_use ] = pd .Categorical (
265+ values = groups .astype ("U" ),
266+ categories = natsorted (map (str , np .unique (groups ))),
188267 )
189- adata . obs [ key_added ] = pd . Categorical (
190- values = groups . astype ( "U" ),
191- categories = natsorted ( map ( str , np . unique ( groups ))),
192- )
268+ if use_dask :
269+ import cugraph . dask . comms . comms as Comms
270+
271+ Comms . destroy ( )
193272 # store information on the clustering parameters
194273 adata .uns [key_added ] = {}
195274 adata .uns [key_added ]["params" ] = {
196- "resolution" : resolution ,
275+ "resolution" : resolutions ,
197276 "random_state" : random_state ,
198277 "n_iterations" : n_iterations ,
199278 }
@@ -202,7 +281,7 @@ def leiden(
202281
203282def louvain (
204283 adata : AnnData ,
205- resolution : float = 1.0 ,
284+ resolution : float | list [ float ] = 1.0 ,
206285 * ,
207286 restrict_to : tuple [str , Sequence [str ]] | None = None ,
208287 key_added : str = "louvain" ,
@@ -213,6 +292,7 @@ def louvain(
213292 neighbors_key : int | None = None ,
214293 obsp : str | None = None ,
215294 dtype : str | np .dtype | cp .dtype = np .float32 ,
295+ use_dask : bool = False ,
216296 copy : bool = False ,
217297) -> AnnData | None :
218298 """
@@ -229,9 +309,9 @@ def louvain(
229309 annData object
230310
231311 resolution
232- A parameter value controlling the coarseness of the clustering
312+ A parameter value or a list of parameter values controlling the coarseness of the clustering.
233313 (called gamma in the modularity formula). Higher values lead to
234- more clusters.
314+ more clusters. If a list of values is provided, the Louvain algorithm will be run for each value in the list.
235315
236316 restrict_to
237317 Restrict the clustering to the categories within the key for
@@ -275,13 +355,14 @@ def louvain(
275355 dtype
276356 Data type to use for the adjacency matrix.
277357
358+ use_dask
359+ If `True`, use Dask to create the graph and cluster. This will use all GPUs available. This feature is experimental. For datasets with fewer than 10 million cells, it is recommended to use `use_dask=False`.
360+
278361 copy
279362 Whether to copy `adata` or modify it in place.
280363
281364 """
282365 # Adjacency graph
283- from cugraph import louvain as culouvain
284-
285366 dtype = _check_dtype (dtype )
286367
287368 adata = adata .copy () if copy else adata
@@ -295,43 +376,60 @@ def louvain(
295376 restrict_categories = restrict_categories ,
296377 adjacency = adjacency ,
297378 )
379+ # Cluster
380+ if use_dask :
381+ from cugraph .dask import louvain as culouvain
298382
299- g = _create_graph (adjacency , dtype , use_weights = use_weights )
383+ g = _create_graph_dask (adjacency , dtype , use_weights = use_weights )
384+ else :
385+ from cugraph import louvain as culouvain
300386
301- # Cluster
302- louvain_parts , _ = culouvain (
303- g ,
304- resolution = resolution ,
305- max_level = n_iterations ,
306- threshold = threshold ,
307- )
387+ g = _create_graph (adjacency , dtype , use_weights = use_weights )
308388
309- # Format output
310- groups = (
311- louvain_parts .to_pandas ()
312- .sort_values ("vertex" )[["partition" ]]
313- .to_numpy ()
314- .ravel ()
315- )
316- if restrict_to is not None :
317- if key_added == "louvain" :
318- key_added += "_R"
319- groups = rename_groups (
320- adata ,
321- key_added = key_added ,
322- restrict_key = restrict_key ,
323- restrict_categories = restrict_categories ,
324- restrict_indices = restrict_indices ,
325- groups = groups ,
389+ if isinstance (resolution , float | int ):
390+ resolutions = [resolution ]
391+ else :
392+ resolutions = resolution
393+ for resolution in resolutions :
394+ louvain_parts , _ = culouvain (
395+ g ,
396+ resolution = resolution ,
397+ max_level = n_iterations ,
398+ threshold = threshold ,
326399 )
400+ if use_dask :
401+ louvain_parts = louvain_parts .to_backend ("pandas" ).compute ()
402+ else :
403+ louvain_parts = louvain_parts .to_pandas ()
404+
405+ # Format output
406+ groups = louvain_parts .sort_values ("vertex" )[["partition" ]].to_numpy ().ravel ()
407+ key_added_to_use = key_added
408+ if restrict_to is not None :
409+ if key_added == "louvain" :
410+ key_added_to_use += "_R"
411+ groups = rename_groups (
412+ adata ,
413+ key_added = key_added_to_use ,
414+ restrict_key = restrict_key ,
415+ restrict_categories = restrict_categories ,
416+ restrict_indices = restrict_indices ,
417+ groups = groups ,
418+ )
419+ if len (resolutions ) > 1 :
420+ key_added_to_use += f"_{ resolution } "
327421
328- adata .obs [key_added ] = pd .Categorical (
329- values = groups .astype ("U" ),
330- categories = natsorted (map (str , np .unique (groups ))),
331- )
422+ adata .obs [key_added_to_use ] = pd .Categorical (
423+ values = groups .astype ("U" ),
424+ categories = natsorted (map (str , np .unique (groups ))),
425+ )
426+ if use_dask :
427+ import cugraph .dask .comms .comms as Comms
428+
429+ Comms .destroy ()
332430 adata .uns [key_added ] = {}
333431 adata .uns [key_added ]["params" ] = {
334- "resolution" : resolution ,
432+ "resolution" : resolutions ,
335433 "n_iterations" : n_iterations ,
336434 "threshold" : threshold ,
337435 }
0 commit comments