diff --git a/src/snkit/network.py b/src/snkit/network.py index 7154f63..817a0a3 100644 --- a/src/snkit/network.py +++ b/src/snkit/network.py @@ -1040,3 +1040,31 @@ def add_component_ids(network: Network, id_col: str = "component_id") -> Network network.nodes.loc[node_mask, id_col] = count + 1 return network + + +def merge_networks(networks: List[Network]) -> Network: + """Merge multiple networks, identifying duplicate nodes at shared locations""" + + n = len(networks) + if n == 0: + warnings.warn("Merging zero networks to return empty network.") + return Network() + if n == 1: + return networks[0] + + # TODO update components + # - find duplicated nodes by location + # - for each duplicated node, record (network_idx, component_id) + # - set up nx.Graph + # - a vertex for each duplicated node + # - edges connecting duplicate node sets + # - edges connecting (network_idx, component_id) sets + # - find connected components in this graph + # - set up Dict[Tuple[network_idx, component_id], global_component_id] + + # TODO do we default to concat_dedup? or default low-intervention? + # TODO how to handle expectations of unique node or edge ids? (especially if edges have topology) + nodes = concat_dedup([network.nodes for network in networks]) + edges = concat_dedup([network.edges for network in networks]) + + return Network(nodes, edges) diff --git a/tests/test_init.py b/tests/test_init.py index 827bba7..e9a158b 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -838,6 +838,26 @@ def test_add_component_ids(two_components): assert all(labelled.edges.component_id == pd.Series([2, 1, 1])) +def test_merge_zero_networks(): + merged = snkit.network.merge_networks([]) + assert merged.nodes.empty + assert merged.edges.empty + + +def test_merge_one_network(split): + merged = snkit.network.merge_networks([split]) + assert_frame_equal(merged.nodes, split.nodes) + assert_frame_equal(merged.edges, split.edges) + + +def test_merge_networks(split): + split_abc = snkit.Network(split.nodes[:3], split.edges[:2]) + split_cd = snkit.Network(split.nodes[2:], split.edges[2:]) + merged = snkit.network.merge_networks([split_abc, split_cd]) + assert_frame_equal(merged.nodes, split.nodes) + assert_frame_equal(merged.edges, split.edges) + + def test_matching_gdf_from_geoms(edge_only): expected = edge_only.edges.copy() gdf = edge_only.edges.copy()