diff --git a/src/open_gira/admin.py b/src/open_gira/admin.py index 5acfcc20..8c815f38 100644 --- a/src/open_gira/admin.py +++ b/src/open_gira/admin.py @@ -5,6 +5,8 @@ import logging import pandas as pd +import geopandas as gpd +import shapely def merge_gadm_admin_levels(preference: pd.DataFrame, alternative: pd.DataFrame) -> pd.DataFrame: @@ -37,4 +39,43 @@ def merge_gadm_admin_levels(preference: pd.DataFrame, alternative: pd.DataFrame) merged = pd.concat([preference, substitute_regions]) - return merged.sort_values("ISO_A3") \ No newline at end of file + return merged.sort_values("ISO_A3") + + +def get_administrative_data(file_path: str, to_epsg: int = None) -> gpd.GeoDataFrame: + """ + Read administrative data (country ISO, country geometry) from disk + + Arguments: + file_path (str): Location of file with country data: + containing an ISO three letter code as 'GID_0' and a geometry as + 'geometry' + to_epsg (int): EPSG code to project data to + + Returns: + gpd.GeoDataFrame: Table of country and geometry data with: + 'iso_a3' and 'geometry' columns + """ + + # read file + gdf = gpd.read_file(file_path) + + # check schema is as expected + expected_columns = {"GID_0", "geometry"} + assert expected_columns.issubset(set(gdf.columns.values)) + + # reproject if desired + if to_epsg is not None: + gdf = gdf.to_crs(epsg=to_epsg) + + # rename these columns first so we don't have to do this twice (to nodes and edges) later + gdf.rename(columns={"GID_0": "iso_a3"}, inplace=True) + + # subset, sort by iso_a3 and return + return gdf[["iso_a3", "geometry"]].sort_values(by=["iso_a3"], ascending=True) + + +def boundary_geom(gdf: gpd.GeoDataFrame, iso_a3: str) -> shapely.Geometry: + """Given administrative data, return the boundary geometry for a given ISO3 country code + """ + return gdf.set_index("iso_a3").loc[iso_a3, "geometry"] diff --git a/workflow/transport/create_rail_network.py b/workflow/transport/create_rail_network.py index 487f93e4..fa230a62 100644 --- a/workflow/transport/create_rail_network.py +++ b/workflow/transport/create_rail_network.py @@ -10,7 +10,8 @@ import geopandas as gpd -from utils import annotate_country, get_administrative_data +from utils import annotate_country +from open_gira.admin import get_administrative_data from open_gira.assets import RailAssets from open_gira.io import write_empty_frames from open_gira.network import create_network diff --git a/workflow/transport/create_road_network.py b/workflow/transport/create_road_network.py index 6ddbe87f..44e07f35 100644 --- a/workflow/transport/create_road_network.py +++ b/workflow/transport/create_road_network.py @@ -13,7 +13,8 @@ import pandas as pd import snkit -from utils import annotate_country, cast, get_administrative_data, strip_suffix +from utils import annotate_country, cast, strip_suffix +from open_gira.admin import get_administrative_data from open_gira.assets import RoadAssets from open_gira.io import write_empty_frames from open_gira.network import create_network diff --git a/workflow/transport/utils.py b/workflow/transport/utils.py index decfc096..10557d9d 100644 --- a/workflow/transport/utils.py +++ b/workflow/transport/utils.py @@ -60,39 +60,6 @@ def cast(x: Any, *, casting_function: Callable, nullable: bool) -> Any: raise ValueError("Couldn't recast to non-nullable value") from casting_error -def get_administrative_data(file_path: str, to_epsg: int = None) -> gpd.GeoDataFrame: - """ - Read administrative data (country ISO, country geometry) from disk - - Arguments: - file_path (str): Location of file with country data: - containing an ISO three letter code as 'GID_0' and a geometry as - 'geometry' - to_epsg (int): EPSG code to project data to - - Returns: - gpd.GeoDataFrame: Table of country and geometry data with: - 'iso_a3' and 'geometry' columns - """ - - # read file - gdf = gpd.read_file(file_path) - - # check schema is as expected - expected_columns = {"GID_0", "geometry"} - assert expected_columns.issubset(set(gdf.columns.values)) - - # reproject if desired - if to_epsg is not None: - gdf = gdf.to_crs(epsg=to_epsg) - - # rename these columns first so we don't have to do this twice (to nodes and edges) later - gdf.rename(columns={"GID_0": "iso_a3"}, inplace=True) - - # subset, sort by iso_a3 and return - return gdf[["iso_a3", "geometry"]].sort_values(by=["iso_a3"], ascending=True) - - def annotate_country(network: snkit.network.Network, countries: gpd.GeoDataFrame) -> snkit.network.Network: """ Label network edges and nodes with their country ISO code