diff --git a/_nx_cugraph/__init__.py b/_nx_cugraph/__init__.py index 3beca176c..15a0cdd4a 100644 --- a/_nx_cugraph/__init__.py +++ b/_nx_cugraph/__init__.py @@ -198,7 +198,7 @@ "edge_betweenness_centrality": "`weight` parameter is not yet supported, and RNG with seed may be different.", "ego_graph": "Weighted ego_graph with negative cycles is not yet supported. `NotImplementedError` will be raised if there are negative `distance` edge weights.", "eigenvector_centrality": "`nstart` parameter is not used, but it is checked for validity.", - "forceatlas2_layout": "`node_mass` parameter is currently ignored. Only `dim=2` is supported.", + "forceatlas2_layout": "Only `dim=2` is supported, and there may be minor numeric differences.", "from_pandas_edgelist": "cudf.DataFrame inputs also supported; value columns with str is unsuppported.", "generic_bfs_edges": "`neighbors` parameter is not yet supported.", "katz_centrality": "`nstart` isn't used (but is checked), and `normalized=False` is not supported.", diff --git a/nx_cugraph/drawing/layout.py b/nx_cugraph/drawing/layout.py index 69d4d19d0..49357deef 100644 --- a/nx_cugraph/drawing/layout.py +++ b/nx_cugraph/drawing/layout.py @@ -38,7 +38,7 @@ @networkx_algorithm( extra_params=_dtype_param, - is_incomplete=True, # dim=2-only; no node_mass + is_incomplete=True, # dim=2-only is_different=True, # node_size handled differently, different RNG and results version_added="25.04", _plc="forceatlas2_layout", @@ -64,7 +64,7 @@ def forceatlas2_layout( # nx_cugraph-only argument dtype=None, ): - """`node_mass` parameter is currently ignored. Only `dim=2` is supported.""" + """Only `dim=2` is supported, and there may be minor numeric differences.""" if len(G) == 0: return {} @@ -138,6 +138,25 @@ def forceatlas2_layout( vertex_radius_vertices = None prevent_overlapping = False + if node_mass is not None: + # Default mass is degree + 1 + vertex_mass_values = G._dict_to_nodearray( + node_mass, default=np.nan, dtype=np.float32 + ) + isnan = cp.isnan(vertex_mass_values) + if isnan.any(): + vertex_mass_values = cp.where( + isnan, (G._degrees_array() + 1).astype(np.float32), vertex_mass_values + ) + vertex_mass_vertices = ( + start_vertices + if start_vertices is not None + else cp.arange(G._N, dtype=index_dtype) + ) + else: + vertex_mass_values = None + vertex_mass_vertices = None + seed = _seed_to_int(seed) vertices, x_axis, y_axis = plc.force_atlas2( @@ -164,6 +183,8 @@ def forceatlas2_layout( gravity=gravity, vertex_mobility_vertices=None, vertex_mobility_values=None, + vertex_mass_vertices=vertex_mass_vertices, + vertex_mass_values=vertex_mass_values, verbose=False, do_expensive_check=False, )