Skip to content

Commit

Permalink
Add CLI command and utility function to update a DistrictrMap (#189)
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaellaude authored Nov 19, 2024
1 parent 69fb2d2 commit 8fc81fb
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 78 deletions.
11 changes: 11 additions & 0 deletions backend/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,17 @@ class DistrictrMapPublic(BaseModel):
available_summary_stats: list[str] | None = None


class DistrictrMapUpdate(BaseModel):
gerrydb_table_name: str
name: str | None = None
parent_layer: str | None = None
child_layer: str | None = None
tiles_s3_path: str | None = None
num_districts: int | None = None
visible: bool | None = None
available_summary_stats: list[str] | None = None


class GerryDBTable(TimeStampMixin, SQLModel, table=True):
uuid: str = Field(sa_column=Column(UUIDType, unique=True, primary_key=True))
# Must correspond to the layer name in the tileset
Expand Down
78 changes: 76 additions & 2 deletions backend/app/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from sqlalchemy import text
from sqlalchemy import text, update
from sqlalchemy import bindparam, Integer, String, Text
from sqlalchemy.types import UUID
from sqlmodel import Session, Float, Boolean
import logging
from urllib.parse import ParseResult
import os
from app.core.config import settings


from app.models import SummaryStatisticType, UUIDType
from app.models import SummaryStatisticType, UUIDType, DistrictrMap, DistrictrMapUpdate

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -74,6 +77,41 @@ def create_districtr_map(
return inserted_map_uuid # pyright: ignore


def update_districtrmap(
session: Session,
gerrydb_table_name: str,
**kwargs,
):
"""
Update a districtr map.
Args:
session: The database session.
gerrydb_table_name: The name of the gerrydb table.
**kwargs: The fields to update.
Returns:
The updated districtr map.
"""
data = DistrictrMapUpdate(gerrydb_table_name=gerrydb_table_name, **kwargs)
update_districtrmap = data.model_dump(
exclude_unset=True, exclude={"gerrydb_table_name"}, exclude_none=True
)

if not update_districtrmap.keys():
raise KeyError("No fields to update")

stmt = (
update(DistrictrMap)
.where(DistrictrMap.gerrydb_table_name == data.gerrydb_table_name) # pyright: ignore
.values(update_districtrmap)
.returning(DistrictrMap)
)
(updated_districtrmap,) = session.execute(stmt).one()

return updated_districtrmap


def create_shatterable_gerrydb_view(
session: Session,
parent_layer_name: str,
Expand Down Expand Up @@ -273,3 +311,39 @@ def add_available_summary_stats_to_districtrmap(
f"Updated available summary stats for districtr map {districtr_map_uuid} to {available_summary_stats}"
)
return available_summary_stats


def download_file_from_s3(s3, url: ParseResult, replace=False) -> str:
"""
Download a file from S3 to the local volume path.
Args:
s3: S3 client
url (ParseResult): URL of the file to download
replace (bool): If True, replace the file if it already exists
Returns the path to the downloaded file.
"""
if not s3:
raise ValueError("S3 client is not available")

file_name = url.path.lstrip("/")
logger.info("File name: %s", file_name)
object_information = s3.head_object(Bucket=url.netloc, Key=file_name)

if object_information["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise ValueError(
f"GeoPackage file {file_name} not found in S3 bucket {url.netloc}"
)

logger.info("Downloading GerryDB view. Got response:\n%s", object_information)

path = os.path.join(settings.VOLUME_PATH, file_name)

if os.path.exists(path) and not replace:
logger.info("File already exists. Skipping download.")
else:
logger.info("Downloading file...")
s3.download_file(url.netloc, file_name, path)

return path
Loading

0 comments on commit 8fc81fb

Please sign in to comment.