diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..039c5fd3 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: patch + changes: + added: + - Output storage for household simulation calculations in simulations table diff --git a/policyengine_api/data/initialise.sql b/policyengine_api/data/initialise.sql index 117045cf..01cd0d7d 100644 --- a/policyengine_api/data/initialise.sql +++ b/policyengine_api/data/initialise.sql @@ -111,7 +111,10 @@ CREATE TABLE IF NOT EXISTS simulations ( -- VARCHAR(255) to accommodate both household IDs and geography codes population_id VARCHAR(255) NOT NULL, population_type VARCHAR(50) NOT NULL, - policy_id INT NOT NULL + policy_id INT NOT NULL, + -- output_json stores calculation results for household simulations only + -- For geography simulations, outputs are stored in report_outputs table + output_json JSON DEFAULT NULL ); CREATE TABLE IF NOT EXISTS report_outputs ( diff --git a/policyengine_api/data/initialise_local.sql b/policyengine_api/data/initialise_local.sql index 13bc68ff..42f78c35 100644 --- a/policyengine_api/data/initialise_local.sql +++ b/policyengine_api/data/initialise_local.sql @@ -120,7 +120,10 @@ CREATE TABLE IF NOT EXISTS simulations ( -- VARCHAR(255) to accommodate both household IDs and geography codes population_id VARCHAR(255) NOT NULL, population_type VARCHAR(50) NOT NULL, - policy_id INT NOT NULL + policy_id INT NOT NULL, + -- output_json stores calculation results for household simulations only + -- For geography simulations, outputs are stored in report_outputs table + output_json JSON DEFAULT NULL ); CREATE TABLE IF NOT EXISTS report_outputs ( diff --git a/policyengine_api/routes/simulation_routes.py b/policyengine_api/routes/simulation_routes.py index 6669ef89..986a97b1 100644 --- a/policyengine_api/routes/simulation_routes.py +++ b/policyengine_api/routes/simulation_routes.py @@ -133,3 +133,78 @@ def get_simulation(country_id: str, simulation_id: int) -> Response: status=200, mimetype="application/json", ) + + +@simulation_bp.route("//simulation", methods=["PATCH"]) +@validate_country +def update_simulation(country_id: str) -> Response: + """ + Update a simulation record with calculation output. + + Args: + country_id (str): The country ID. + + Request body can contain: + - id (int): The simulation ID. + - output_json (str): The calculation output as JSON string (for household simulations) + """ + + payload = request.json + if payload is None: + raise BadRequest("Payload missing from request") + + # Extract fields + simulation_id = payload.get("id") + output_json = payload.get("output_json") + print(f"Updating simulation #{simulation_id} for country {country_id}") + + # Validate that id is provided + if simulation_id is None: + raise BadRequest("id is required") + if not isinstance(simulation_id, int): + raise BadRequest("id must be an integer") + + # Validate that at least one field is being updated + if output_json is None: + raise BadRequest("output_json must be provided for update") + + try: + # First check if the simulation exists + existing_simulation = simulation_service.get_simulation( + country_id, simulation_id + ) + if existing_simulation is None: + raise NotFound(f"Simulation #{simulation_id} not found.") + + # Update the simulation + success = simulation_service.update_simulation_output( + country_id=country_id, + simulation_id=simulation_id, + output_json=output_json, + ) + + if not success: + raise BadRequest("No fields to update") + + # Get the updated record + updated_simulation = simulation_service.get_simulation( + country_id, simulation_id + ) + + response_body = dict( + status="ok", + message="Simulation updated successfully", + result=updated_simulation, + ) + + return Response( + json.dumps(response_body), + status=200, + mimetype="application/json", + ) + + except NotFound: + raise + except Exception as e: + print(f"Error updating simulation: {str(e)}") + raise BadRequest(f"Failed to update simulation: {str(e)}") diff --git a/policyengine_api/services/simulation_service.py b/policyengine_api/services/simulation_service.py index efa4d0b7..39027691 100644 --- a/policyengine_api/services/simulation_service.py +++ b/policyengine_api/services/simulation_service.py @@ -134,3 +134,58 @@ def get_simulation( f"Error fetching simulation #{simulation_id}. Details: {str(e)}" ) raise e + + def update_simulation_output( + self, + country_id: str, + simulation_id: int, + output_json: str | None = None, + ) -> bool: + """ + Update a simulation record with calculation output. + + Args: + country_id (str): The country ID. + simulation_id (int): The simulation ID. + output_json (str | None): The output as JSON string (for household simulations). + + Returns: + bool: True if update was successful. + """ + print(f"Updating simulation {simulation_id} with output") + # Automatically update api_version on every update to latest + api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id) + + try: + # Build the update query dynamically based on provided fields + update_fields = [] + update_values = [] + + if output_json is not None: + update_fields.append("output_json = ?") + # Output is already a JSON string from frontend + update_values.append(output_json) + + # Always update API version + update_fields.append("api_version = ?") + update_values.append(api_version) + + if not update_fields: + print("No fields to update") + return False + + # Add simulation_id to the end of values for WHERE clause + update_values.append(simulation_id) + + query = f"UPDATE simulations SET {', '.join(update_fields)} WHERE id = ?" + + database.query(query, tuple(update_values)) + + print(f"Successfully updated simulation #{simulation_id}") + return True + + except Exception as e: + print( + f"Error updating simulation #{simulation_id}. Details: {str(e)}" + ) + raise e