diff --git a/src/snkit/network.py b/src/snkit/network.py index 949ed2a..3cc5e8b 100644 --- a/src/snkit/network.py +++ b/src/snkit/network.py @@ -694,18 +694,16 @@ def to_networkx(network, directed=False, weight_col=None): G = nx.MultiDiGraph() # get nodes from network data G.add_nodes_from(network.nodes.id.to_list()) + # add nodal positions from geom - network.nodes["pos"] = list( - zip(network.nodes.geometry.x, network.nodes.geometry.y) - ) - pos = network.nodes.set_index("id").to_dict()["pos"] - for n, p in pos.items(): - G.nodes[n]["pos"] = p + for node_id, x, y in zip(network.nodes.id, network.nodes.geometry.x, network.nodes.geometry.y): + G.nodes[node_id]["pos"] = (x, y) + # get edges from network data if weight_col is None: - network.edges["weight"] = network.edges.geometry.length + # default to geometry length edges_as_list = list( - zip(network.edges.from_id, network.edges.to_id, network.edges.weight) + zip(network.edges.from_id, network.edges.to_id, network.edges.geometry.length) ) else: edges_as_list = list(