diff --git a/app/src/app/components/sidebar/charts/HorizontalBarChart.tsx b/app/src/app/components/sidebar/charts/HorizontalBarChart.tsx index 4decb5d44..8fa91d975 100644 --- a/app/src/app/components/sidebar/charts/HorizontalBarChart.tsx +++ b/app/src/app/components/sidebar/charts/HorizontalBarChart.tsx @@ -63,7 +63,6 @@ export const HorizontalBar = () => { useMemo(() => { if (mapMetrics) { - console.log(numDistricts, idealPopulation); const chartObject = calculateChartObject(); setTotalExpectedBars(chartObject); } @@ -88,7 +87,7 @@ export const HorizontalBar = () => { return ( - Population by District + Population by district str: + """ + Download a file from S3 to the local volume path. + + Args: + s3: S3 client + url (ParseResult): URL of the file to download + replace (bool): If True, replace the file if it already exists + + Returns the path to the downloaded file. + """ + if not s3: + raise ValueError("S3 client is not available") + + file_name = url.path.lstrip("/") + logger.info("File name: %s", file_name) + object_information = s3.head_object(Bucket=url.netloc, Key=file_name) + + if object_information["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise ValueError( + f"GeoPackage file {file_name} not found in S3 bucket {url.netloc}" + ) + + logger.info("Downloading GerryDB view. Got response:\n%s", object_information) + + path = os.path.join(settings.VOLUME_PATH, file_name) + + if os.path.exists(path) and not replace: + logger.info("File already exists. Skipping download.") + else: + logger.info("Downloading file...") + s3.download_file(url.netloc, file_name, path) + + return path diff --git a/backend/cli.py b/backend/cli.py index 5aab2a6fa..47f469eb8 100644 --- a/backend/cli.py +++ b/backend/cli.py @@ -2,10 +2,10 @@ import click import logging -from app.main import get_session +from app.core.db import engine from app.core.config import settings import subprocess -from urllib.parse import urlparse, ParseResult +from urllib.parse import urlparse from sqlalchemy import text from app.constants import GERRY_DB_SCHEMA from app.utils import ( @@ -14,51 +14,53 @@ create_parent_child_edges as _create_parent_child_edges, add_extent_to_districtrmap as _add_extent_to_districtrmap, add_available_summary_stats_to_districtrmap as _add_available_summary_stats_to_districtrmap, + update_districtrmap as _update_districtrmap, + download_file_from_s3, ) +from functools import wraps +from contextlib import contextmanager +from sqlmodel import Session +from typing import Callable, TypeVar, Any logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) -@click.group() -def cli(): - pass +T = TypeVar("T") -def download_file_from_s3(s3, url: ParseResult, replace=False) -> str: - """ - Download a file from S3 to the local volume path. +@contextmanager +def session_scope(): + """Provide a transactional scope around a series of operations.""" + session = Session(engine) + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() - Args: - s3: S3 client - url (ParseResult): URL of the file to download - replace (bool): If True, replace the file if it already exists - Returns the path to the downloaded file. +def with_session(f: Callable[..., T]) -> Callable[..., T]: + """ + Decorator that handles database session creation and cleanup. + Compatible with Click commands. """ - if not s3: - raise ValueError("S3 client is not available") - - file_name = url.path.lstrip("/") - logger.info("File name: %s", file_name) - object_information = s3.head_object(Bucket=url.netloc, Key=file_name) - - if object_information["ResponseMetadata"]["HTTPStatusCode"] != 200: - raise ValueError( - f"GeoPackage file {file_name} not found in S3 bucket {url.netloc}" - ) - logger.info("Downloading GerryDB view. Got response:\n%s", object_information) + @wraps(f) + def decorator(*args: Any, **kwargs: Any) -> T: + with session_scope() as session: + kwargs["session"] = session + return f(*args, **kwargs) - path = os.path.join(settings.VOLUME_PATH, file_name) + return decorator - if os.path.exists(path) and not replace: - logger.info("File already exists. Skipping download.") - else: - logger.info("Downloading file...") - s3.download_file(url.netloc, file_name, path) - return path +@click.group() +def cli(): + pass @cli.command("import-gerrydb-view") @@ -66,7 +68,10 @@ def download_file_from_s3(s3, url: ParseResult, replace=False) -> str: @click.option("--gpkg", "-g", help="Path or URL to GeoPackage file", required=True) @click.option("--replace", "-f", help="Replace the file if it exists", is_flag=True) @click.option("--rm", "-r", help="Delete file after loading to postgres", is_flag=True) -def import_gerrydb_view(layer: str, gpkg: str, replace: bool, rm: bool): +@with_session +def import_gerrydb_view( + session: Session, layer: str, gpkg: str, replace: bool, rm: bool +): logger.info("Importing GerryDB view...") url = urlparse(gpkg) @@ -110,9 +115,6 @@ def import_gerrydb_view(layer: str, gpkg: str, replace: bool, rm: bool): logger.info("GerryDB view imported successfully") - _session = get_session() - session = next(_session) - upsert_query = text( """ INSERT INTO gerrydbtable (uuid, name, updated_at) @@ -123,50 +125,39 @@ def import_gerrydb_view(layer: str, gpkg: str, replace: bool, rm: bool): """ ) - try: - session.execute( - upsert_query, - { - "name": layer, - }, - ) - session.commit() - logger.info("GerryDB view upserted successfully.") - except Exception as e: - session.rollback() - logger.error("Failed to upsert GerryDB view. Got %s", e) - raise ValueError(f"Failed to upsert GerryDB view. Got {e}") - - session.close() + session.execute( + upsert_query, + { + "name": layer, + }, + ) + logger.info("GerryDB view upserted successfully.") @cli.command("create-parent-child-edges") @click.option("--districtr-map", "-d", help="Districtr map name", required=True) -def create_parent_child_edges(districtr_map: str): +@with_session +def create_parent_child_edges(session: Session, districtr_map: str): logger.info("Creating parent-child edges...") - session = next(get_session()) stmt = text( "SELECT uuid FROM districtrmap WHERE gerrydb_table_name = :districtrmap_name" ) (districtr_map_uuid,) = session.execute( stmt, params={"districtrmap_name": districtr_map} ).one() - print(f"Found districtmap uuid: {districtr_map_uuid}") + logger.info(f"Found districtmap uuid: {districtr_map_uuid}") + _create_parent_child_edges(session=session, districtr_map_uuid=districtr_map_uuid) - session.commit() logger.info("Parent-child relationship upserted successfully.") - session.close() - @cli.command("delete-parent-child-edges") @click.option("--districtr-map", "-d", help="Districtr map name", required=True) -def delete_parent_child_edges(districtr_map: str): +@with_session +def delete_parent_child_edges(session: Session, districtr_map: str): logger.info("Deleting parent-child edges...") - session = next(get_session()) - delete_query = text( """ DELETE FROM parentchildedges @@ -179,11 +170,8 @@ def delete_parent_child_edges(districtr_map: str): "districtr_map": districtr_map, }, ) - session.commit() logger.info("Parent-child relationship upserted successfully.") - session.close() - @cli.command("create-districtr-map") @click.option("--name", help="Name of the districtr map", required=True) @@ -204,7 +192,9 @@ def delete_parent_child_edges(districtr_map: str): default=None, nargs=4, ) +@with_session def create_districtr_map( + session: Session, name: str, parent_layer_name: str, child_layer_name: str | None, @@ -215,7 +205,6 @@ def create_districtr_map( bounds: list[float] | None = None, ): logger.info("Creating districtr map...") - session = next(get_session()) (districtr_map_uuid,) = _create_districtr_map( session=session, name=name, @@ -238,28 +227,88 @@ def create_districtr_map( session=session, districtr_map_uuid=districtr_map_uuid ) - session.commit() logger.info(f"Districtr map created successfully {districtr_map_uuid}") +@cli.command("update-districtr-map") +@click.option( + "--gerrydb-table-name", + "-n", + help="Name of the GerryDB table", + type=str, + required=True, +) +@click.option("--name", help="Name of the districtr map", type=str, required=False) +@click.option( + "--parent-layer-name", help="Parent gerrydb layer name", type=str, required=False +) +@click.option( + "--child-layer-name", help="Child gerrydb layer name", type=str, required=False +) +@click.option("--num-districts", help="Number of districts", type=str, required=False) +@click.option( + "--tiles-s3-path", help="S3 path to the tileset", type=str, required=False +) +@click.option("--visibility", "-v", help="Visibility", type=bool, required=False) +@click.option( + "--bounds", + "-b", + help="Bounds of the extent as `--bounds x_min y_min x_max y_max`", + required=False, + type=float, + default=None, + nargs=4, +) +@with_session +def update_districtr_map( + session: Session, + gerrydb_table_name: str, + name: str | None, + parent_layer_name: str | None, + child_layer_name: str | None, + num_districts: int | None, + tiles_s3_path: str | None, + visibility: bool = False, + bounds: list[float] | None = None, +): + logger.info("Updating districtr map...") + + _bounds = None + if bounds and len(bounds) == 4: + _bounds = bounds + + result = _update_districtrmap( + session=session, + gerrydb_table_name=gerrydb_table_name, + name=name, + parent_layer=parent_layer_name, + child_layer=child_layer_name, + num_districts=num_districts, + tiles_s3_path=tiles_s3_path, + visible=visibility, + bounds=_bounds, + ) + logger.info(f"Districtr map updated successfully {result}") + + @cli.command("create-shatterable-districtr-view") @click.option("--parent-layer-name", help="Parent gerrydb layer name", required=True) @click.option("--child-layer-name", help="Child gerrydb layer name", required=False) @click.option("--gerrydb-table-name", help="Name of the GerryDB table", required=False) +@with_session def create_shatterable_gerrydb_view( + session: Session, parent_layer_name: str, child_layer_name: str, gerrydb_table_name: str, ): logger.info("Creating materialized shatterable gerrydb view...") - session = next(get_session()) inserted_uuid = _create_shatterable_gerrydb_view( session=session, parent_layer_name=parent_layer_name, child_layer_name=child_layer_name, gerrydb_table_name=gerrydb_table_name, ) - session.commit() logger.info( f"Materialized shatterable gerrydb view created successfully {inserted_uuid}" ) @@ -276,10 +325,12 @@ def create_shatterable_gerrydb_view( default=None, nargs=4, ) -def add_extent_to_districtr_map(districtr_map: str, bounds: list[float] | None = None): +@with_session +def add_extent_to_districtr_map( + session: Session, districtr_map: str, bounds: list[float] | None = None +): logger.info(f"User provided bounds: {bounds}") - session = next(get_session()) stmt = text( "SELECT uuid FROM districtrmap WHERE gerrydb_table_name = :districtrmap_name" ) @@ -291,16 +342,13 @@ def add_extent_to_districtr_map(districtr_map: str, bounds: list[float] | None = _add_extent_to_districtrmap( session=session, districtr_map_uuid=districtr_map_uuid, bounds=bounds ) - session.commit() logger.info("Updated extent successfully.") - session.close() - @cli.command("add-available-summary-stats-to-districtr-map") @click.option("--districtr-map", "-d", help="Districtr map name", required=True) -def add_available_summary_stats_to_districtr_map(districtr_map: str): - session = next(get_session()) +@with_session +def add_available_summary_stats_to_districtr_map(session: Session, districtr_map: str): stmt = text( "SELECT uuid FROM districtrmap WHERE gerrydb_table_name = :districtrmap_name" ) @@ -313,9 +361,7 @@ def add_available_summary_stats_to_districtr_map(districtr_map: str): session=session, districtr_map_uuid=districtr_map_uuid ) - session.commit() logger.info("Updated available summary stats successfully.") - session.close() if __name__ == "__main__": diff --git a/backend/tests/test_utils.py b/backend/tests/test_utils.py index 7e03b154a..e730ff424 100644 --- a/backend/tests/test_utils.py +++ b/backend/tests/test_utils.py @@ -6,10 +6,12 @@ create_parent_child_edges, add_extent_to_districtrmap, get_available_summary_stats, + update_districtrmap, ) from sqlmodel import Session import subprocess from app.constants import GERRY_DB_SCHEMA +from app.models import DistrictrMap from tests.constants import OGR2OGR_PG_CONNECTION_STRING, FIXTURES_PATH from sqlalchemy import text @@ -193,6 +195,36 @@ def test_create_districtr_map_some_nulls(session: Session, simple_parent_geos_ge session.commit() +@pytest.fixture(name="simple_parent_geos_districtrmap") +def simple_parent_geos_districtrmap_fixture( + session: Session, simple_parent_geos_gerrydb, simple_child_geos_gerrydb +): + gerrydb_name = "simple_geos_test" + (inserted_districtr_map,) = create_districtr_map( + session, + name="Simple shatterable layer", + gerrydb_table_name=gerrydb_name, + num_districts=10, + tiles_s3_path="tilesets/simple_shatterable_layer.pmtiles", + parent_layer_name="simple_parent_geos", + child_layer_name="simple_child_geos", + visibility=True, + ) + session.commit() + return gerrydb_name + + +def test_update_districtr_map(session: Session, simple_parent_geos_districtrmap): + result = update_districtrmap( + session=session, + gerrydb_table_name=simple_parent_geos_districtrmap, + visible=False, + ) + session.commit() + districtr_map = DistrictrMap.model_validate(result) + assert not districtr_map.visible + + def test_add_extent_to_districtrmap(session: Session, simple_parent_geos_gerrydb): (inserted_districtr_map,) = create_districtr_map( session, diff --git a/pipelines/simple_elt/main.py b/pipelines/simple_elt/main.py index 2dffe60ca..8327edb28 100644 --- a/pipelines/simple_elt/main.py +++ b/pipelines/simple_elt/main.py @@ -332,20 +332,22 @@ def load_districtr_v1_places(replace: bool = False) -> None: s3_client.upload_file(districtr_places, settings.S3_BUCKET, key) -@cli.command("load-districtr-v1-problems") -@click.option("--replace", is_flag=True, help="Replace existing files", default=False) -def load_districtr_v1_problems(replace: bool = False) -> None: +def upsert_places_and_problems(): """ - load problems definition json file for states from districtr_v1 and store in s3 bucket + Upsert places and problems from districtr_v1. + WIP/not functional port of load_dv1_places_and_problems_problems in #167 """ + + raise NotImplementedError + s3_client = settings.get_s3_client() - # check if the districtr_places object exists in s3; if not, download it using load_districtr_v1_places key = f"{S3_PREFIX}/districtr_places/districtr_v1_places.json" - if not exists_in_s3(s3_client, settings.S3_BUCKET, key): - load_districtr_v1_places() + districtr_places = download_file_from_s3(s3_client, urlparse(key)) - districtr_places = download_file_from_s3(s3_client, urlparse(key), replace) + if not districtr_places: + LOGGER.error("Failed to download districtr_v1_places.json") + return with open(districtr_places, "r") as file: places = json.load(file) @@ -366,24 +368,25 @@ def load_districtr_v1_problems(replace: bool = False) -> None: LOGGER.error(f"Failed to download problems for {url}: {e}") continue - -if __name__ == "__main__": - cli() + # load districtr_v1 places and problems + load_districtr_v1_places() + load_districtr_v1_problems() -def upsert_places_and_problems(): +@cli.command("load-districtr-v1-problems") +@click.option("--replace", is_flag=True, help="Replace existing files", default=False) +def load_districtr_v1_problems(replace: bool = False) -> None: """ - Upsert places and problems from districtr_v1. - WIP/not functional port of load_dv1_places_and_problems_problems in #167 + load problems definition json file for states from districtr_v1 and store in s3 bucket """ s3_client = settings.get_s3_client() + # check if the districtr_places object exists in s3; if not, download it using load_districtr_v1_places key = f"{S3_PREFIX}/districtr_places/districtr_v1_places.json" - districtr_places = download_file_from_s3(s3_client, urlparse(key)) + if not exists_in_s3(s3_client, settings.S3_BUCKET, key): + load_districtr_v1_places() - if not districtr_places: - LOGGER.error("Failed to download districtr_v1_places.json") - return + districtr_places = download_file_from_s3(s3_client, urlparse(key), replace) with open(districtr_places, "r") as file: places = json.load(file) @@ -404,6 +407,6 @@ def upsert_places_and_problems(): LOGGER.error(f"Failed to download problems for {url}: {e}") continue - # load districtr_v1 places and problems - load_districtr_v1_places() - load_districtr_v1_problems() + +if __name__ == "__main__": + cli()