-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Return networkx graph #41
Changes from all commits
0099ba7
a5ec6d9
1537469
f156c60
0a90dc7
6c506ee
c783e95
1686d0b
8b07797
68385b5
f71a16e
e3530d2
e55bc44
1bba5c2
98979d4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,3 +10,4 @@ attrs>=17.4 # to fix pytest compatibility on python 3.6 | |
pylint | ||
pytest>=4.6 | ||
pytest-cov | ||
networkx>=1.9 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,11 @@ | |
import numpy as np | ||
import pandas | ||
import shapely.errors | ||
try: | ||
import networkx as nx | ||
USE_NX = True | ||
except ImportError: | ||
USE_NX = False | ||
|
||
from geopandas import GeoDataFrame | ||
from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection, shape, mapping | ||
|
@@ -603,3 +608,60 @@ def set_precision(geom, precision): | |
geom_mapping = mapping(geom) | ||
geom_mapping['coordinates'] = np.round(np.array(geom_mapping['coordinates']), precision) | ||
return shape(geom_mapping) | ||
|
||
|
||
def to_networkx(network,directed=False,weight_col=None): | ||
"""Return a networkx graph | ||
""" | ||
if not USE_NX: | ||
raise ImportError('No module named networkx') | ||
else: | ||
# init graph | ||
if not directed: | ||
G = nx.Graph() | ||
else: | ||
G = nx.MultiDiGraph() | ||
# get nodes from network data | ||
G.add_nodes_from(network.nodes.id.to_list()) | ||
# add nodal positions from geom | ||
network.nodes['pos'] = list(zip(network.nodes.geometry.x, network.nodes.geometry.y)) | ||
pos = network.nodes.set_index('id').to_dict()['pos'] | ||
for n,p in pos.items(): | ||
G.nodes[n]['pos'] = p | ||
# get edges from network data | ||
if weight_col is None: | ||
network.edges['weight'] = network.edges.geometry.length | ||
edges_as_list = list(zip(network.edges.from_id,network.edges.to_id,network.edges.weight)) | ||
else: | ||
edges_as_list = list(zip(network.edges.from_id,network.edges.to_id,network.edges[weight_col])) | ||
# add edges to graph | ||
G.add_weighted_edges_from(edges_as_list) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we pass in a column to use as weight here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just thinking about this - it would be generally useful to add all the edge attributes here, including geometry. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've changed this part to extract nodal positions and edge weights (https://github.com/amanmajid/snkit/blob/e55bc44c4aa848fb9f1724f8db504a3c6a7ad114/src/snkit/network.py#L613-L639). I'm not sure how to go about efficiently adding all edge attributes when defining the graph? |
||
return G | ||
|
||
|
||
def get_connected_components(network): | ||
"""Get connected components within network and id to each individual graph | ||
""" | ||
if not USE_NX: | ||
raise ImportError('No module named networkx') | ||
else: | ||
G = to_networkx(network) | ||
return sorted(nx.connected_components(G), key = len, reverse=True) | ||
|
||
|
||
def add_component_ids(network,id_col='component_id'): | ||
"""Add column of component IDs to network data | ||
""" | ||
# get connected components | ||
connected_parts = get_connected_components(network) | ||
# add unique id to each graph | ||
network.edges[id_col] = 0 # init id_col | ||
network.nodes[id_col] = 0 # init id_col | ||
for count, part in enumerate(connected_parts): | ||
# edges | ||
network.edges.loc[ (network.edges.from_id.isin(list(part))) | \ | ||
(network.edges.to_id.isin(list(part))), id_col ] = count + 1 | ||
# nodes | ||
network.nodes.loc[ (network.nodes.id.isin(list(part))), id_col] = count + 1 | ||
# return | ||
return network |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similarly to edges, can we add node attributes to the graph here?