Skip to content

Commit

Permalink
Merge pull request #8 from hotosm/refactor/db-env-fallback
Browse files Browse the repository at this point in the history
Use httpcore unasync impl, fallback to env vars for NearestCity.connect(), add env var tests
  • Loading branch information
spwoodcock authored Feb 11, 2025
2 parents 2e632e4 + 2e70f1d commit 8a1700c
Show file tree
Hide file tree
Showing 17 changed files with 414 additions and 202 deletions.
12 changes: 1 addition & 11 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
repos:
# Dev: install async for next command
- repo: local
hooks:
- id: dev-deps
name: instal-dev-deps
language: system
entry: uv sync --group dev
always_run: true
pass_filenames: false

# Unasync: Convert async --> sync
- repo: local
hooks:
- id: unasync
name: unasync-all
language: system
entry: uv run python build_sync.py
entry: python unasync.py
always_run: true
pass_filenames: false

Expand Down
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ from pg_nearest_city import AsyncNearestCity
# Existing code to get db connection, say from API endpoint
db = await get_db_connection()

async with AsyncNearestCity.connect(db) as geocoder:
async with AsyncNearestCity(db) as geocoder:
location = await geocoder.query(40.7128, -74.0060)

print(location.city)
Expand All @@ -127,7 +127,7 @@ from pg_nearest_city import NearestCity
# Existing code to get db connection, say from API endpoint
db = get_db_connection()

with AsyncNearestCity.connect(db) as geocoder:
with NearestCity(db) as geocoder:
location = geocoder.query(40.7128, -74.0060)

print(location.city)
Expand All @@ -154,7 +154,7 @@ db_config = DbConfig(
port="5432",
)

async with AsyncNearestCity.connect(db_config) as geocoder:
async with AsyncNearestCity(db_config) as geocoder:
location = await geocoder.query(40.7128, -74.0060)
```

Expand All @@ -171,9 +171,9 @@ PGNEAREST_DB_PORT=5432
then

```python
from pg_nearest_city import DbConfig, AsyncNearestCity
from pg_nearest_city import AsyncNearestCity

async with AsyncNearestCity.connect() as geocoder:
async with AsyncNearestCity() as geocoder:
location = await geocoder.query(40.7128, -74.0060)
```

Expand Down
5 changes: 5 additions & 0 deletions compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ services:
- ./pg_nearest_city:/opt/python/lib/python3.10/site-packages/pg_nearest_city:ro
# Mount local tests
- ./tests:/data/tests:ro
environment:
- PGNEAREST_DB_HOST=db
- PGNEAREST_DB_USER=cities
- PGNEAREST_DB_PASSWORD=dummycipassword
- PGNEAREST_DB_NAME=cities
depends_on:
db:
condition: service_healthy
Expand Down
4 changes: 2 additions & 2 deletions pg_nearest_city/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""The main pg_nearest_city package."""

from .async_nearest_city import AsyncNearestCity
from ._async.nearest_city import AsyncNearestCity
from ._sync.nearest_city import NearestCity
from .base_nearest_city import DbConfig, Location
from .nearest_city import NearestCity

__all__ = ["NearestCity", "AsyncNearestCity", "DbConfig", "Location"]
1 change: 1 addition & 0 deletions pg_nearest_city/_async/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Async implementation."""
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import gzip
import importlib.resources
import logging
from contextlib import asynccontextmanager
from typing import Optional
from textwrap import dedent, fill

import psycopg
from psycopg import AsyncCursor

from pg_nearest_city import base_nearest_city
from pg_nearest_city.base_nearest_city import (
BaseNearestCity,
DbConfig,
InitializationStatus,
Location,
)
Expand All @@ -22,47 +22,24 @@
class AsyncNearestCity:
"""Reverse geocoding to the nearest city over 1000 population."""

@classmethod
@asynccontextmanager
async def connect(cls, db: psycopg.AsyncConnection | base_nearest_city.DbConfig):
"""Managed NearestCity instance with automatic initialization and cleanup.
Args:
db: Either a DbConfig for a new connection or an existing psycopg Connection
"""
is_external_connection = isinstance(db, psycopg.AsyncConnection)

conn: psycopg.AsyncConnection

if is_external_connection:
conn = db
else:
conn = await psycopg.AsyncConnection.connect(db.get_connection_string())

geocoder = cls(conn)

try:
await geocoder.initialize()
yield geocoder
finally:
if not is_external_connection:
await conn.close()

def __init__(
self,
connection: psycopg.AsyncConnection,
db: psycopg.AsyncConnection | DbConfig | None = None,
logger: Optional[logging.Logger] = None,
):
"""Initialize reverse geocoder with an existing AsyncConnection.
Args:
db: An existing psycopg AsyncConnection
connection: psycopg.AsyncConnection
logger: Optional custom logger. If not provided, uses package logger
logger: Optional custom logger. If not provided, uses package logger.
"""
# Allow users to provide their own logger while having a sensible default
self._logger = logger or logging.getLogger("pg_nearest_city")
self.connection = connection
self._db = db
self.connection: psycopg.AsyncConnection = None
self._is_external_connection = False
self._is_initialized = False

with importlib.resources.path(
"pg_nearest_city.data", "cities_1000_simple.txt.gz"
Expand All @@ -73,8 +50,43 @@ def __init__(
) as voronoi_path:
self.voronoi_file = voronoi_path

async def __aenter__(self):
"""Open the context manager."""
self.connection = await self.get_connection(self._db)
# Create the relevant tables and validate
await self.initialize()
self._is_initialized = True
return self

async def __aexit__(self, exc_type, exc_value, traceback):
"""Close the context manager."""
if self.connection and not self._is_external_connection:
await self.connection.close()
self._initialized = False

async def get_connection(
self,
db: Optional[psycopg.AsyncConnection | DbConfig] = None,
) -> psycopg.AsyncConnection:
"""Determine the database connection to use."""
self._is_external_connection = isinstance(db, psycopg.AsyncConnection)
is_db_config = isinstance(db, DbConfig)

if self._is_external_connection:
return db
elif is_db_config:
return await psycopg.AsyncConnection.connect(db.get_connection_string())
else:
# Fallback to env var extraction, or defaults for testing
return await psycopg.AsyncConnection.connect(
DbConfig().get_connection_string(),
)

async def initialize(self) -> None:
"""Initialize the geocoding database with validation checks."""
if not self.connection:
self._inform_user_if_not_context_manager()

try:
async with self.connection.cursor() as cur:
self._logger.info("Starting database initialization check")
Expand Down Expand Up @@ -126,6 +138,20 @@ async def initialize(self) -> None:
self._logger.error("Database initialization failed: %s", str(e))
raise RuntimeError(f"Database initialization failed: {str(e)}") from e

def _inform_user_if_not_context_manager(self):
"""Raise an error if the context manager was not used."""
if not self._is_initialized:
raise RuntimeError(
fill(
dedent("""
AsyncNearestCity must be used within 'async with' context.\n
For example:\n
async with AsyncNearestCity() as geocoder:\n
details = geocoder.query(lat, lon)
""")
)
)

async def query(self, lat: float, lon: float) -> Optional[Location]:
"""Find the nearest city to the given coordinates using Voronoi regions.
Expand All @@ -140,13 +166,16 @@ async def query(self, lat: float, lon: float) -> Optional[Location]:
ValueError: If coordinates are out of valid ranges
RuntimeError: If database query fails
"""
# Throw an error if not used in 'with' block
self._inform_user_if_not_context_manager()

# Validate coordinate ranges
BaseNearestCity.validate_coordinates(lon, lat)

try:
async with self.connection.cursor() as cur:
await cur.execute(
BaseNearestCity._get_reverse_geocoding_query(lon, lat)
BaseNearestCity._get_reverse_geocoding_query(lon, lat),
)
result = await cur.fetchone()

Expand All @@ -164,7 +193,8 @@ async def query(self, lat: float, lon: float) -> Optional[Location]:
raise RuntimeError(f"Reverse geocoding failed: {str(e)}") from e

async def _check_initialization_status(
self, cur: psycopg.AsyncCursor
self,
cur: psycopg.AsyncCursor,
) -> InitializationStatus:
"""Check the status and integrity of the geocoding database.
Expand Down Expand Up @@ -259,7 +289,7 @@ async def _import_voronoi_polygons(self, cur: AsyncCursor):

# Import the binary WKB data
async with cur.copy(
"COPY voronoi_import (city, country, wkb) FROM STDIN"
"COPY voronoi_import (city, country, wkb) FROM STDIN",
) as copy:
with gzip.open(self.voronoi_file, "rb") as f:
while data := f.read(8192):
Expand Down
1 change: 1 addition & 0 deletions pg_nearest_city/_sync/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Async implementation."""
Loading

0 comments on commit 8a1700c

Please sign in to comment.