diff --git a/src/snkit/network.py b/src/snkit/network.py index cdc6588..b9cbd5d 100644 --- a/src/snkit/network.py +++ b/src/snkit/network.py @@ -1,6 +1,8 @@ """Network representation and utilities """ +import logging import os +from typing import Optional import warnings import geopandas @@ -38,11 +40,24 @@ from snkit.utils import tqdm_standin as tqdm # optional parallel processing -if "SNKIT_PARALLEL" in os.environ and os.environ["SNKIT_PARALLEL"] in ("1", "TRUE"): - PARALLEL = True +if "SNKIT_PROCESSES" in os.environ: + processes_env_var = os.environ["SNKIT_PROCESSES"] + try: + requested_processes = int(processes_env_var) + if requested_processes < 0: + raise ValueError(processes_env_var) + except ValueError as err: + # int() raises ValueError (never TypeError) for a non-numeric string + raise RuntimeError( + "SNKIT_PROCESSES env var must be a non-negative integer. " + "Use 0 or unset for serial operation." + ) from err + # os.cpu_count() may return None; treat that as a single CPU + PARALLEL_PROCESS_COUNT = min([os.cpu_count() or 1, requested_processes]) import multiprocessing + logging.info(f"SNKIT_PROCESSES={processes_env_var}, using {PARALLEL_PROCESS_COUNT} processes") else: - PARALLEL = False + PARALLEL_PROCESS_COUNT = 0 class Network: @@ -316,18 +331,33 @@ def _split_edges_at_nodes( return split_edges -def split_edges_at_nodes(network, tolerance=1e-9): - """Split network edges where they intersect node geometries""" +def split_edges_at_nodes(network: Network, tolerance: float = 1e-9, chunk_size: Optional[int] = None): + """ + Split network edges where they intersect node geometries. + + N.B. Can operate in parallel if SNKIT_PROCESSES is in the environment and a + positive integer. + + Args: + network: Network object to split edges for. + tolerance: Proximity within which nodes are said to intersect an edge. + chunk_size: When splitting in parallel, set the number of edges per + unit of work. + + Returns: + Network with edges split at nodes (within proximity tolerance). 
+ """ split_edges = [] n = len(network.edges) - if PARALLEL and (n > 10_000): - chunk_size = int(n / os.cpu_count()) + if PARALLEL_PROCESS_COUNT > 1: + if chunk_size is None: + chunk_size = max([1, int(n / PARALLEL_PROCESS_COUNT)]) args = [ - (network.edges.iloc[i : i + chunk_size, :], network.nodes, tolerance) + (network.edges.iloc[i: i + chunk_size, :], network.nodes, tolerance) for i in range(0, n, chunk_size) ] - with multiprocessing.Pool() as pool: + with multiprocessing.Pool(PARALLEL_PROCESS_COUNT) as pool: results = pool.starmap(_split_edges_at_nodes, args) # flatten return list