Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return networkx graph #41

Merged
merged 15 commits into from
Mar 23, 2022
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ attrs>=17.4 # to fix pytest compatibility on python 3.6
pylint
pytest>=4.6
pytest-cov
networkx>=1.9
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def readme():
# eg:
# 'rst': ['docutils>=0.11'],
# ':python_version=="2.6"': ['argparse'],
'networkx>=1.9'
},
entry_points={
'console_scripts': [
Expand Down
62 changes: 62 additions & 0 deletions src/snkit/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
import numpy as np
import pandas
import shapely.errors
try:
import networkx as nx
USE_NX = True
except ImportError:
USE_NX = False

from geopandas import GeoDataFrame
from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection, shape, mapping
Expand Down Expand Up @@ -603,3 +608,60 @@ def set_precision(geom, precision):
geom_mapping = mapping(geom)
geom_mapping['coordinates'] = np.round(np.array(geom_mapping['coordinates']), precision)
return shape(geom_mapping)


def to_networkx(network,directed=False,weight_col=None):
"""Return a networkx graph
"""
if not USE_NX:
raise ImportError('No module named networkx')
else:
# init graph
if not directed:
G = nx.Graph()
else:
G = nx.MultiDiGraph()
# get nodes from network data
G.add_nodes_from(network.nodes.id.to_list())
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly to edges, can we add node attributes to the graph here?

# add nodal positions from geom
network.nodes['pos'] = list(zip(network.nodes.geometry.x, network.nodes.geometry.y))
pos = network.nodes.set_index('id').to_dict()['pos']
for n,p in pos.items():
G.nodes[n]['pos'] = p
# get edges from network data
if weight_col is None:
network.edges['weight'] = network.edges.geometry.length
edges_as_list = list(zip(network.edges.from_id,network.edges.to_id,network.edges.weight))
else:
edges_as_list = list(zip(network.edges.from_id,network.edges.to_id,network.edges[weight_col]))
# add edges to graph
G.add_weighted_edges_from(edges_as_list)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we pass in a column to use as weight here?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just thinking about this - it would be generally useful to add all the edge attributes here, including geometry.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've changed this part to extract nodal positions and edge weights (https://github.com/amanmajid/snkit/blob/e55bc44c4aa848fb9f1724f8db504a3c6a7ad114/src/snkit/network.py#L613-L639). I'm not sure how to go about efficiently adding all edge attributes when defining the graph?

return G


def get_connected_components(network):
"""Get connected components within network and id to each individual graph
"""
if not USE_NX:
raise ImportError('No module named networkx')
else:
G = to_networkx(network)
return sorted(nx.connected_components(G), key = len, reverse=True)


def add_component_ids(network,id_col='component_id'):
"""Add column of component IDs to network data
"""
# get connected components
connected_parts = get_connected_components(network)
# add unique id to each graph
network.edges[id_col] = 0 # init id_col
network.nodes[id_col] = 0 # init id_col
for count, part in enumerate(connected_parts):
# edges
network.edges.loc[ (network.edges.from_id.isin(list(part))) | \
(network.edges.to_id.isin(list(part))), id_col ] = count + 1
# nodes
network.nodes.loc[ (network.nodes.id.isin(list(part))), id_col] = count + 1
# return
return network
22 changes: 22 additions & 0 deletions tests/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
from pytest import fixture
from shapely.geometry import Point, LineString, MultiPoint

try:
import networkx as nx
from networkx.utils.misc import graphs_equal
USE_NX = True
except ImportError:
USE_NX = False

import snkit
import snkit.network

Expand Down Expand Up @@ -346,3 +353,18 @@ def test_passing_slice():

print(actual)
assert_frame_equal(actual, expected)


def test_to_networkx(connected):
'''test to networkx
'''
connected.nodes['id'] = ['n'+str(i) for i in connected.nodes.index]
connected = snkit.network.add_topology(connected)
G = snkit.network.to_networkx(connected)

G_true = nx.Graph()
G_true.add_node('n0',pos=(0, 0))
G_true.add_node('n1',pos=(0, 2))
G_true.add_edge('n0', 'n1', weight=2)

assert graphs_equal(G, G_true)