diff --git a/dev-requirements.txt b/dev-requirements.txt index 489346c..30f2a5b 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -10,3 +10,4 @@ attrs>=17.4 # to fix pytest compatibility on python 3.6 pylint pytest>=4.6 pytest-cov +networkx>=1.9 diff --git a/setup.py b/setup.py index 9314602..f0246f3 100644 --- a/setup.py +++ b/setup.py @@ -57,6 +57,7 @@ def readme(): # eg: # 'rst': ['docutils>=0.11'], # ':python_version=="2.6"': ['argparse'], + 'networkx>=1.9' }, entry_points={ 'console_scripts': [ diff --git a/src/snkit/network.py b/src/snkit/network.py index 2672f76..6c66e77 100644 --- a/src/snkit/network.py +++ b/src/snkit/network.py @@ -5,6 +5,11 @@ import numpy as np import pandas import shapely.errors +try: + import networkx as nx + USE_NX = True +except ImportError: + USE_NX = False from geopandas import GeoDataFrame from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection, shape, mapping @@ -603,3 +608,60 @@ def set_precision(geom, precision): geom_mapping = mapping(geom) geom_mapping['coordinates'] = np.round(np.array(geom_mapping['coordinates']), precision) return shape(geom_mapping) + + +def to_networkx(network,directed=False,weight_col=None): + """Return a networkx graph + """ + if not USE_NX: + raise ImportError('No module named networkx') + else: + # init graph + if not directed: + G = nx.Graph() + else: + G = nx.MultiDiGraph() + # get nodes from network data + G.add_nodes_from(network.nodes.id.to_list()) + # add nodal positions from geom + network.nodes['pos'] = list(zip(network.nodes.geometry.x, network.nodes.geometry.y)) + pos = network.nodes.set_index('id').to_dict()['pos'] + for n,p in pos.items(): + G.nodes[n]['pos'] = p + # get edges from network data + if weight_col is None: + network.edges['weight'] = network.edges.geometry.length + edges_as_list = list(zip(network.edges.from_id,network.edges.to_id,network.edges.weight)) + else: + edges_as_list = list(zip(network.edges.from_id,network.edges.to_id,network.edges[weight_col])) + # add edges to graph + G.add_weighted_edges_from(edges_as_list) + return G + + +def get_connected_components(network): + """Get connected components within network and id to each individual graph + """ + if not USE_NX: + raise ImportError('No module named networkx') + else: + G = to_networkx(network) + return sorted(nx.connected_components(G), key = len, reverse=True) + + +def add_component_ids(network,id_col='component_id'): + """Add column of component IDs to network data + """ + # get connected components + connected_parts = get_connected_components(network) + # add unique id to each graph + network.edges[id_col] = 0 # init id_col + network.nodes[id_col] = 0 # init id_col + for count, part in enumerate(connected_parts): + # edges + network.edges.loc[ (network.edges.from_id.isin(list(part))) | \ + (network.edges.to_id.isin(list(part))), id_col ] = count + 1 + # nodes + network.nodes.loc[ (network.nodes.id.isin(list(part))), id_col] = count + 1 + # return + return network diff --git a/tests/test_init.py b/tests/test_init.py index c6f594a..a0045a5 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -6,6 +6,13 @@ from pytest import fixture from shapely.geometry import Point, LineString, MultiPoint +try: + import networkx as nx + from networkx.utils.misc import graphs_equal + USE_NX = True +except ImportError: + USE_NX = False + import snkit import snkit.network @@ -346,3 +353,18 @@ def test_passing_slice(): print(actual) assert_frame_equal(actual, expected) + + +def test_to_networkx(connected): + '''test to networkx + ''' + connected.nodes['id'] = ['n'+str(i) for i in connected.nodes.index] + connected = snkit.network.add_topology(connected) + G = snkit.network.to_networkx(connected) + + G_true = nx.Graph() + G_true.add_node('n0',pos=(0, 0)) + G_true.add_node('n1',pos=(0, 2)) + G_true.add_edge('n0', 'n1', weight=2) + + assert graphs_equal(G, G_true) \ No newline at end of file