Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 202 additions & 5 deletions nzgmdb/scripts/generate_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
"""

import base64
import textwrap
from collections.abc import Sequence
from enum import StrEnum
from io import BytesIO
from pathlib import Path
from typing import Annotated
from typing import Annotated, Optional

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -32,6 +33,13 @@ class TectonicType(StrEnum):
UNKNOWN = "Undetermined"


class NZGMDB_Versions(StrEnum): # noqa: N801
"""Enum for NZGMDB versions."""

V4p3 = "4p3"
V4p4 = "4p4"


OQ_INPUT_COLUMNS = [
"vs30",
"rrup",
Expand Down Expand Up @@ -953,6 +961,125 @@ def mag_rrup_scatter(
return base64.b64encode(buf.read()).decode("utf-8")


def _site_record_counts(
    full_df: pd.DataFrame,
    quality_df: pd.DataFrame,
    skipped_df: pd.DataFrame,
) -> pd.DataFrame:
    """Build one row per site with full, quality and per-skip-reason record counts."""
    full_counts = full_df["sta"].value_counts().rename("full_count")
    quality_counts = quality_df["sta"].value_counts().rename("quality_count")

    # Pivot skipped reasons into columns (one column per reason)
    if skipped_df.empty:
        skipped_pivot = pd.DataFrame(index=full_counts.index)
    else:
        skipped_pivot = (
            skipped_df.groupby(["sta", "reason"]).size().unstack(fill_value=0)
        )

    # Sites missing from one of the inputs get a count of 0 for that input.
    # Naming the index before reset_index yields a "site" column regardless
    # of what the concatenated index happened to be called.
    counts = (
        pd.concat([full_counts, quality_counts, skipped_pivot], axis=1)
        .fillna(0)
        .astype(int)
        .rename_axis("site")
        .reset_index()
    )

    # Column order: site, full_count, quality_count, then reasons (sorted)
    fixed_cols = ["site", "full_count", "quality_count"]
    reason_cols = sorted(c for c in counts.columns if c not in fixed_cols)
    return counts[fixed_cols + reason_cols]


def plot_site_table_image(
    full_df: pd.DataFrame,
    quality_df: pd.DataFrame,
    skipped_df: pd.DataFrame,
    wrap_width: int = 20,
) -> tuple[str, int, int]:
    """
    Render a table image of sites that have only a few quality records.

    For every site the following are counted:

    - number of records in ``full_df``
    - number of records in ``quality_df``
    - number of records per skipped reason in ``skipped_df``

    Only sites with between 1 and 4 quality records are drawn in the table,
    and columns that are zero for all of those sites are dropped. Sites with
    zero quality records are not drawn; they are summarised by the two
    integer return values instead.

    Parameters
    ----------
    full_df : pandas.DataFrame
        Full dataset containing a ``sta`` column (station identifier).
    quality_df : pandas.DataFrame
        Quality-filtered dataset containing a ``sta`` column.
    skipped_df : pandas.DataFrame
        DataFrame of skipped records. Must contain ``sta`` and ``reason``
        columns unless it is empty.
    wrap_width : int, optional
        Maximum characters before wrapping long column labels for display.
        Default is 20.

    Returns
    -------
    img_base64 : str
        Base64-encoded PNG image of the rendered table, or a small
        placeholder image when no site has between 1 and 4 quality records.
    n_zero_quality_sites : int
        Number of sites where the quality count is zero.
    total_zero_quality_records : int
        Total number of records from ``full_df`` for sites with zero quality
        records.
    """
    counts = _site_record_counts(full_df, quality_df, skipped_df)

    # Summary of sites without any quality record (taken before filtering)
    zero_quality = counts[counts["quality_count"] == 0]
    n_zero_quality_sites = int(zero_quality.shape[0])
    total_zero_quality_records = int(zero_quality["full_count"].sum())

    # Keep only sites with 0 < quality_count < 5
    low_quality = counts[
        (counts["quality_count"] > 0) & (counts["quality_count"] < 5)
    ]
    # Drop columns that are zero for every remaining site
    low_quality = low_quality.loc[:, (low_quality != 0).any(axis=0)]

    if low_quality.empty:
        # ax.table() cannot be built from zero rows (it indexes cellText[0]),
        # so draw a short message instead of a table.
        fig, ax = plt.subplots(figsize=(6, 2))
        ax.axis("off")
        ax.text(
            0.5,
            0.5,
            "No sites with 1 to 4 quality records.",
            ha="center",
            va="center",
            fontsize=10,
        )
    else:
        # Wrap long column labels for display
        display_cols = [
            textwrap.fill(str(col), wrap_width) for col in low_quality.columns
        ]

        # Figure size heuristics: width per column, height per row
        n_rows, n_cols = low_quality.shape
        width = min(40, max(6, n_cols * 2.0))
        height = min(40, max(2, n_rows * 0.35 + 1.5))

        fig, ax = plt.subplots(figsize=(width, height))
        ax.axis("off")

        # Build table; convert all values to strings for consistent display
        table = ax.table(
            cellText=low_quality.astype(str).values.tolist(),
            colLabels=display_cols,
            cellLoc="center",
            loc="center",
        )

        # Styling: shrink the font once the table gets large
        table.auto_set_font_size(False)
        font_size = 10 if n_rows < 40 else max(6, int(200 / (n_rows + n_cols)))
        table.set_fontsize(font_size)
        table.scale(1, 1.2)

    buf = BytesIO()
    try:
        fig.tight_layout()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even if rendering fails
        plt.close(fig)

    return (
        base64.b64encode(buf.getvalue()).decode("utf-8"),
        n_zero_quality_sites,
        total_zero_quality_records,
    )


@cli.from_docstring(app)
def generate_report(
new_version_directory: Annotated[
Expand All @@ -973,6 +1100,22 @@ def generate_report(
file_okay=False,
),
] = None,
new_version: Annotated[
Optional[NZGMDB_Versions],
typer.Option(
None,
help="The version for the new database (choose from the enum).",
case_sensitive=False,
),
] = NZGMDB_Versions.V4p3,
old_version: Annotated[
Optional[NZGMDB_Versions],
typer.Option(
None,
help="The version for the old database (choose from the enum).",
case_sensitive=False,
),
] = NZGMDB_Versions.V4p3,
):
"""
Generate a HTML report comparing the new version of the database to a previous version.
Expand All @@ -987,6 +1130,10 @@ def generate_report(
compare_version_directory : Path | None
The Top Level directory containing the previous version of the database to compare against.
If None, a summary of the new version will be generated instead and comparison plots will not be generated.
new_version : NZGMDB_Versions | None
The version for the new database (choose from the enum). Default is NZGMDB_Versions.V4p3.
old_version : NZGMDB_Versions | None
The version for the old database (choose from the enum). Default is NZGMDB_Versions.V4p3.
"""
html_parts = []
# Start of HTML
Expand Down Expand Up @@ -1364,8 +1511,9 @@ def generate_report(
accepted_lengths_new = [
len(
pd.read_csv(
new_flatfiles_dir
/ file_structure.PreFlatfileNames.STATION_MAGNITUDE_TABLE_GEONET
new_flatfiles_dir / "station_magnitude_table_geonet.csv"
if new_version == NZGMDB_Versions.V4p3
else file_structure.PreFlatfileNames.STATION_MAGNITUDE_TABLE_EXTRACTION
)
)
/ 3,
Expand All @@ -1382,8 +1530,9 @@ def generate_report(
[
len(
pd.read_csv(
old_flatifles_dir
/ file_structure.PreFlatfileNames.STATION_MAGNITUDE_TABLE_GEONET
old_flatifles_dir / "station_magnitude_table_geonet.csv"
if old_version == NZGMDB_Versions.V4p3
else file_structure.PreFlatfileNames.STATION_MAGNITUDE_TABLE_EXTRACTION
)
)
/ 3,
Expand Down Expand Up @@ -1523,6 +1672,54 @@ def generate_report(
html_parts.append(f'<img src="data:image/png;base64,{img}">')
html_parts.append("</div>")

# Add New section for low quality site records
html_parts.append("<h2>New NZGMDB - Low Quality Site Records</h2>")
# Add sta col to skipped_reasons by splitting record_id
new_quality_skipped["sta"] = new_quality_skipped["record_id"].str.split(
"_", expand=True
)[1]
img_base64, n_zero_quality_sites, total_zero_quality_records = (
plot_site_table_image(full_new, quality_new, new_quality_skipped)
)
html_parts.append("<div class='fig-single'>")
html_parts.append(f'<img src="data:image/png;base64,{img_base64}">')
html_parts.append("</div>")

# Add text for the zero quality sites
html_parts.append("<h3>Zero Quality Sites Summary</h3>")
html_parts.append("<ul>")
html_parts.append(
f"<li>Number of sites with zero quality records: {n_zero_quality_sites}</li>"
)
html_parts.append(
f"<li>Total number of records for these sites in the full database: {total_zero_quality_records}</li>"
)
html_parts.append("</ul>")

if compare_version_directory:
html_parts.append("<h2>Old NZGMDB - Low Quality Site Records</h2>")
# Add sta col to skipped_reasons by splitting record_id
old_quality_skipped["sta"] = old_quality_skipped["record_id"].str.split(
"_", expand=True
)[1]
img_base64_old, n_zero_quality_sites_old, total_zero_quality_records_old = (
plot_site_table_image(full_old, quality_old, old_quality_skipped)
)
html_parts.append("<div class='fig-single'>")
html_parts.append(f'<img src="data:image/png;base64,{img_base64_old}">')
html_parts.append("</div>")

# Add text for the zero quality sites
html_parts.append("<h3>Zero Quality Sites Summary</h3>")
html_parts.append("<ul>")
html_parts.append(
f"<li>Number of sites with zero quality records: {n_zero_quality_sites_old}</li>"
)
html_parts.append(
f"<li>Total number of records for these sites in the full database: {total_zero_quality_records_old}</li>"
)
html_parts.append("</ul>")

# Add Category Column Comparisons
html_parts.append("<h2>Event Column Comparisons</h2>")

Expand Down
Loading