Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
rvankoert committed Jun 26, 2024
2 parents 33ab059 + 05192de commit 7cc5931
Show file tree
Hide file tree
Showing 9 changed files with 367 additions and 275 deletions.
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
flask==3.0.2
gunicorn==22.0.0
numpy==1.26.4
editdistance==0.8.1
tensorflow==2.14.1
Expand All @@ -15,3 +13,6 @@ xlsxwriter==3.2.0
six
Pillow==10.3.0
h5py==3.10.0
fastapi==0.111.0
uvicorn==0.30.1
typing-extensions==4.12.2
181 changes: 122 additions & 59 deletions src/api/app.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,142 @@
# Imports

# > Standard library
import asyncio
from contextlib import asynccontextmanager
import socket
import multiprocessing as mp

# > Third-party dependencies
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from uvicorn.config import Config
from uvicorn.server import Server

# > Local dependencies
import errors
from routes import main
from app_utils import setup_logging, get_env_variable, start_workers
from simple_security import SimpleSecurity
from app_utils import (setup_logging, get_env_variable,
start_workers, stop_workers)
from routes import create_router

# > Third-party dependencies
from flask import Flask
# Set up logging
logging_level = get_env_variable("LOGGING_LEVEL", "INFO")
logger = setup_logging(logging_level)

# Get Loghi-HTR options from environment variables
logger.info("Getting Loghi-HTR options from environment variables")
batch_size = int(get_env_variable("LOGHI_BATCH_SIZE", "256"))
model_path = get_env_variable("LOGHI_MODEL_PATH")
output_path = get_env_variable("LOGHI_OUTPUT_PATH")
max_queue_size = int(get_env_variable("LOGHI_MAX_QUEUE_SIZE", "10000"))
patience = float(get_env_variable("LOGHI_PATIENCE", "0.5"))

def create_app() -> Flask:
"""
Create and configure a Flask app for image prediction.
# Get GPU options from environment variables
logger.info("Getting GPU options from environment variables")
gpus = get_env_variable("LOGHI_GPUS", "0")

This function initializes a Flask app, sets up necessary configurations,
starts image preparation and batch prediction processes, and returns the
configured app instance.

Returns
-------
Flask
Configured Flask app instance ready for serving.
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Manage the lifespan of the FastAPI application.
Parameters
----------
app : FastAPI
The FastAPI application instance.
Side Effects
------------
- Initializes and starts preparation, prediction, and decoding processes.
- Logs various messages regarding the app and process initialization.
Yields
------
None
"""
# Create a stop event
stop_event = mp.Event()

# Set up logging
logging_level = get_env_variable("LOGGING_LEVEL", "INFO")
logger = setup_logging(logging_level)

# Get Loghi-HTR options from environment variables
logger.info("Getting Loghi-HTR options from environment variables")
batch_size = int(get_env_variable("LOGHI_BATCH_SIZE", "256"))
model_path = get_env_variable("LOGHI_MODEL_PATH")
output_path = get_env_variable("LOGHI_OUTPUT_PATH")
max_queue_size = int(get_env_variable("LOGHI_MAX_QUEUE_SIZE", "10000"))
patience = float(get_env_variable("LOGHI_PATIENCE", "0.5"))

# Get GPU options from environment variables
logger.info("Getting GPU options from environment variables")
gpus = get_env_variable("LOGHI_GPUS", "0")

# Create Flask app
logger.info("Creating Flask app")
app = Flask(__name__)

# Register error handler
app.register_error_handler(ValueError, errors.handle_invalid_usage)
app.register_error_handler(405, errors.method_not_allowed)

# Add security to app
security_config = \
{"enabled": get_env_variable("SECURITY_ENABLED", "False"),
"key_user_json": get_env_variable("API_KEY_USER_JSON_STRING", "{}")}
security = SimpleSecurity(app, security_config)
logger.info(f"Security enabled: {security.enabled}")

# Start the worker processes
# Startup: Start the worker processes
logger.info("Starting worker processes")
workers, queues = start_workers(batch_size, max_queue_size, output_path,
gpus, model_path, patience)
gpus, model_path, patience, stop_event)
# Add request queue and stop event to the app
app.state.request_queue = queues["Request"]
app.state.stop_event = stop_event
app.state.workers = workers

yield

# Add request queue to the app
app.request_queue = queues["Request"]
# Shutdown: Stop all workers and join them
logger.info("Shutting down worker processes")
stop_workers(app.state.workers, app.state.stop_event)
logger.info("All workers have been stopped and joined")

# Add the workers to the app
app.workers = workers

# Register blueprints
app.register_blueprint(main)
def create_app() -> FastAPI:
"""
Create and configure the FastAPI application.
Returns
-------
FastAPI
The configured FastAPI application instance.
"""
app = FastAPI(
title="Loghi-HTR API",
description="API for Loghi-HTR",
lifespan=lifespan
)

# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)

# Include the router
router = create_router(app)
app.include_router(router)

return app


app = create_app()


async def run_server():
"""
Run the FastAPI server.
Returns
-------
None
"""
host = get_env_variable("UVICORN_HOST", "127.0.0.1")
port = int(get_env_variable("UVICORN_PORT", "5000"))

# Attempt to resolve the hostname
try:
socket.gethostbyname(host)
except socket.gaierror:
logger.error(
f"Unable to resolve hostname: {host}. Falling back to localhost.")
host = "127.0.0.1"

config = Config("app:app", host=host, port=port, workers=1)
server = Server(config=config)

try:
await server.serve()
except OSError as e:
logger.error(f"Error starting server: {e}")
if e.errno == 98: # Address already in use
logger.error(
f"Port {port} is already in use. Try a different port.")
elif e.errno == 13: # Permission denied
logger.error(
f"Permission denied when trying to bind to port {port}. Try a "
"port number > 1024 or run with sudo.")
except Exception as e:
logger.error(f"Unexpected error occurred: {e}")

if __name__ == "__main__":
asyncio.run(run_server())
Loading

0 comments on commit 7cc5931

Please sign in to comment.