Skip to content

Commit 5fe7b19

Browse files
authored
Merge pull request #41 from amanmajid/master
Return networkx graph
2 parents 665ee43 + 98979d4 commit 5fe7b19

File tree

4 files changed

+86
-0
lines changed

4 files changed

+86
-0
lines changed

dev-requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ black
77
nbstripout
88
pytest
99
pytest-cov
10+
networkx>=1.9

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def readme():
5454
# eg:
5555
# 'rst': ['docutils>=0.11'],
5656
# ':python_version=="2.6"': ['argparse'],
57+
'networkx>=1.9'
5758
},
5859
entry_points={
5960
"console_scripts": [

src/snkit/network.py

+62
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
import numpy as np
1111
import pandas
1212
import shapely.errors
13+
try:
14+
import networkx as nx
15+
USE_NX = True
16+
except ImportError:
17+
USE_NX = False
1318

1419
from geopandas import GeoDataFrame
1520
from shapely.geometry import (
@@ -673,3 +678,60 @@ def set_precision(geom, precision):
673678
np.array(geom_mapping["coordinates"]), precision
674679
)
675680
return shape(geom_mapping)
681+
682+
683+
def to_networkx(network,directed=False,weight_col=None):
684+
"""Return a networkx graph
685+
"""
686+
if not USE_NX:
687+
raise ImportError('No module named networkx')
688+
else:
689+
# init graph
690+
if not directed:
691+
G = nx.Graph()
692+
else:
693+
G = nx.MultiDiGraph()
694+
# get nodes from network data
695+
G.add_nodes_from(network.nodes.id.to_list())
696+
# add nodal positions from geom
697+
network.nodes['pos'] = list(zip(network.nodes.geometry.x, network.nodes.geometry.y))
698+
pos = network.nodes.set_index('id').to_dict()['pos']
699+
for n,p in pos.items():
700+
G.nodes[n]['pos'] = p
701+
# get edges from network data
702+
if weight_col is None:
703+
network.edges['weight'] = network.edges.geometry.length
704+
edges_as_list = list(zip(network.edges.from_id,network.edges.to_id,network.edges.weight))
705+
else:
706+
edges_as_list = list(zip(network.edges.from_id,network.edges.to_id,network.edges[weight_col]))
707+
# add edges to graph
708+
G.add_weighted_edges_from(edges_as_list)
709+
return G
710+
711+
712+
def get_connected_components(network):
713+
"""Get connected components within network and id to each individual graph
714+
"""
715+
if not USE_NX:
716+
raise ImportError('No module named networkx')
717+
else:
718+
G = to_networkx(network)
719+
return sorted(nx.connected_components(G), key = len, reverse=True)
720+
721+
722+
def add_component_ids(network,id_col='component_id'):
723+
"""Add column of component IDs to network data
724+
"""
725+
# get connected components
726+
connected_parts = get_connected_components(network)
727+
# add unique id to each graph
728+
network.edges[id_col] = 0 # init id_col
729+
network.nodes[id_col] = 0 # init id_col
730+
for count, part in enumerate(connected_parts):
731+
# edges
732+
network.edges.loc[ (network.edges.from_id.isin(list(part))) | \
733+
(network.edges.to_id.isin(list(part))), id_col ] = count + 1
734+
# nodes
735+
network.nodes.loc[ (network.nodes.id.isin(list(part))), id_col] = count + 1
736+
# return
737+
return network

tests/test_init.py

+22
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@
1111
from pytest import fixture
1212
from shapely.geometry import Point, LineString, MultiPoint
1313

14+
try:
15+
import networkx as nx
16+
from networkx.utils.misc import graphs_equal
17+
USE_NX = True
18+
except ImportError:
19+
USE_NX = False
20+
1421
import snkit
1522
import snkit.network
1623

@@ -350,3 +357,18 @@ def test_passing_slice():
350357

351358
print(actual)
352359
assert_frame_equal(actual, expected)
360+
361+
362+
def test_to_networkx(connected):
363+
'''test to networkx
364+
'''
365+
connected.nodes['id'] = ['n'+str(i) for i in connected.nodes.index]
366+
connected = snkit.network.add_topology(connected)
367+
G = snkit.network.to_networkx(connected)
368+
369+
G_true = nx.Graph()
370+
G_true.add_node('n0',pos=(0, 0))
371+
G_true.add_node('n1',pos=(0, 2))
372+
G_true.add_edge('n0', 'n1', weight=2)
373+
374+
assert graphs_equal(G, G_true)

0 commit comments

Comments
 (0)