import base64
import textwrap
from collections.abc import Sequence
from enum import StrEnum
from io import BytesIO
from pathlib import Path
from typing import Annotated, Optional

import matplotlib.pyplot as plt
import numpy as np


class NZGMDB_Versions(StrEnum):  # noqa: N801
    """Enum for NZGMDB versions."""

    # Members are string-valued, so they compare equal to "4p3" / "4p4".
    V4p3 = "4p3"
    V4p4 = "4p4"


def _figure_to_base64_png(fig) -> str:
    """Serialise a matplotlib figure to a base64 PNG string and always close it."""
    try:
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Close even when rendering fails so figures do not accumulate in pyplot.
        plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def plot_site_table_image(
    full_df: pd.DataFrame,
    quality_df: pd.DataFrame,
    skipped_df: pd.DataFrame,
    wrap_width: int = 20,
    quality_threshold: int = 5,
) -> tuple[str, int, int]:
    """
    Render a table-like plot showing, per site:

    - number of records in ``full_df``
    - number of records in ``quality_df``
    - count per skipped reason from ``skipped_df``

    Only sites whose quality count is greater than zero and below
    ``quality_threshold`` are drawn. Sites with a quality count of zero are
    summarised in the returned counters instead. Nothing is written to disk;
    the image is returned as a base64 string.

    Parameters
    ----------
    full_df : pandas.DataFrame
        Full dataset containing a ``sta`` column (station identifier).
    quality_df : pandas.DataFrame
        Quality-filtered dataset containing a ``sta`` column.
    skipped_df : pandas.DataFrame
        DataFrame of skipped records. Must contain ``sta`` and ``reason`` columns
        unless it is empty.
    wrap_width : int, optional
        Maximum characters before wrapping long column labels for display. Default is 20.
    quality_threshold : int, optional
        Sites are shown when ``0 < quality_count < quality_threshold``. Default is 5.

    Returns
    -------
    img_base64 : str
        Base64-encoded PNG image of the rendered table, or a small placeholder
        image when no site falls inside the low-quality range.
    n_zero_quality_sites : int
        Number of sites where the quality count is zero.
    total_zero_quality_records : int
        Total number of records from ``full_df`` for sites with zero quality records.
    """
    # Per-site record counts
    full_counts = full_df["sta"].value_counts().rename("full_count")
    quality_counts = quality_df["sta"].value_counts().rename("quality_count")

    # Pivot skipped reasons into columns (one column per reason)
    if skipped_df.empty:
        skipped_pivot = pd.DataFrame(index=full_counts.index)
    else:
        skipped_pivot = (
            skipped_df.groupby(["sta", "reason"]).size().unstack(fill_value=0)
        )

    # Combine into a single frame. Naming the index explicitly avoids relying
    # on whether pandas labels the value_counts index "sta" or leaves it unnamed.
    combined = (
        pd.concat([full_counts, quality_counts, skipped_pivot], axis=1)
        .fillna(0)
        .astype(int)
        .rename_axis(index="site")
        .reset_index()
    )

    # Column order: site, full_count, quality_count, then reasons alphabetically
    fixed_cols = ["site", "full_count", "quality_count"]
    reason_cols = sorted(c for c in combined.columns if c not in fixed_cols)
    combined = combined[fixed_cols + reason_cols]

    # Zero-quality sites are reported as totals rather than drawn in the table
    zero_quality = combined[combined["quality_count"] == 0]
    n_zero_quality_sites = zero_quality.shape[0]
    total_zero_quality_records = int(zero_quality["full_count"].sum())

    # Keep only sites with at least one but fewer than `quality_threshold` records
    low_quality = combined[
        (combined["quality_count"] > 0)
        & (combined["quality_count"] < quality_threshold)
    ]

    if low_quality.empty:
        # ax.table cannot be built from an empty cell list, so draw a short
        # message instead of letting matplotlib raise.
        fig, ax = plt.subplots(figsize=(6, 2))
        ax.axis("off")
        ax.text(
            0.5,
            0.5,
            f"No sites with fewer than {quality_threshold} "
            "(but at least one) quality records",
            ha="center",
            va="center",
            fontsize=10,
            transform=ax.transAxes,
        )
        return (
            _figure_to_base64_png(fig),
            n_zero_quality_sites,
            total_zero_quality_records,
        )

    # Drop reason columns that are zero for every remaining site
    low_quality = low_quality.loc[:, (low_quality != 0).any(axis=0)]

    # Wrap long column labels for display; str() guards non-string reason labels
    display_cols = [textwrap.fill(str(c), wrap_width) for c in low_quality.columns]

    # Figure size heuristics: width per column, height per row, both capped
    n_rows, n_cols = low_quality.shape
    width = min(40, max(6, n_cols * 2.0))
    height = min(40, max(2, n_rows * 0.35 + 1.5))

    fig, ax = plt.subplots(figsize=(width, height))
    try:
        ax.axis("off")

        # Convert all values to strings for consistent display
        table = ax.table(
            cellText=low_quality.astype(str).values.tolist(),
            colLabels=display_cols,
            cellLoc="center",
            loc="center",
        )

        # Fixed font size, shrunk for very long tables
        table.auto_set_font_size(False)
        font_size = 10 if n_rows < 40 else max(6, int(200 / (n_rows + n_cols)))
        table.set_fontsize(font_size)
        table.scale(1, 1.2)
    except Exception:
        # Do not leak the figure when table construction fails.
        plt.close(fig)
        raise

    return (
        _figure_to_base64_png(fig),
        n_zero_quality_sites,
        total_zero_quality_records,
    )
""" html_parts = [] # Start of HTML @@ -1364,8 +1511,9 @@ def generate_report( accepted_lengths_new = [ len( pd.read_csv( - new_flatfiles_dir - / file_structure.PreFlatfileNames.STATION_MAGNITUDE_TABLE_GEONET + new_flatfiles_dir / "station_magnitude_table_geonet.csv" + if new_version == NZGMDB_Versions.V4p3 + else file_structure.PreFlatfileNames.STATION_MAGNITUDE_TABLE_EXTRACTION ) ) / 3, @@ -1382,8 +1530,9 @@ def generate_report( [ len( pd.read_csv( - old_flatifles_dir - / file_structure.PreFlatfileNames.STATION_MAGNITUDE_TABLE_GEONET + old_flatifles_dir / "station_magnitude_table_geonet.csv" + if old_version == NZGMDB_Versions.V4p3 + else file_structure.PreFlatfileNames.STATION_MAGNITUDE_TABLE_EXTRACTION ) ) / 3, @@ -1523,6 +1672,54 @@ def generate_report( html_parts.append(f'') html_parts.append("") + # Add New section for low quality site records + html_parts.append("

New NZGMDB - Low Quality Site Records

") + # Add sta col to skipped_reasons by splitting record_id + new_quality_skipped["sta"] = new_quality_skipped["record_id"].str.split( + "_", expand=True + )[1] + img_base64, n_zero_quality_sites, total_zero_quality_records = ( + plot_site_table_image(full_new, quality_new, new_quality_skipped) + ) + html_parts.append("
") + html_parts.append(f'') + html_parts.append("
") + + # Add text for the zero quality sites + html_parts.append("

Zero Quality Sites Summary

") + html_parts.append("") + + if compare_version_directory: + html_parts.append("

Old NZGMDB - Low Quality Site Records

") + # Add sta col to skipped_reasons by splitting record_id + old_quality_skipped["sta"] = old_quality_skipped["record_id"].str.split( + "_", expand=True + )[1] + img_base64_old, n_zero_quality_sites_old, total_zero_quality_records_old = ( + plot_site_table_image(full_old, quality_old, old_quality_skipped) + ) + html_parts.append("
") + html_parts.append(f'') + html_parts.append("
") + + # Add text for the zero quality sites + html_parts.append("

Zero Quality Sites Summary

") + html_parts.append("") + # Add Category Column Comparisons html_parts.append("

Event Column Comparisons

")