Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 88 additions & 32 deletions dsa110_continuum/calibration/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,9 +474,8 @@ def _check_flag_fraction(
) -> float:
"""Check if flag fraction in calibration table exceeds threshold.

This function excludes fully-flagged antennas (dead antennas) from the
calculation, since they don't indicate calibration problems - just
non-working hardware.
Excludes fully-flagged (antenna, receptor) pairs from the calculation,
since those indicate dead hardware rather than calibration problems.

Parameters
----------
Expand All @@ -489,7 +488,7 @@ def _check_flag_fraction(

Returns
-------
Actual flag fraction (0.0 to 1.0) for working antennas only
Actual flag fraction (0.0 to 1.0) for working receptors only

Raises
------
Expand All @@ -505,59 +504,54 @@ def _check_flag_fraction(
return 0.0

flags = tb.getcol("FLAG")
# casacore returns shape (nrow, nchan, npol) for bandpass tables
# where nrow = number of antennas

# Calculate raw flag fraction
total = flags.size
flagged = int(np.sum(flags))
raw_flag_fraction = flagged / total if total > 0 else 0.0

# Calculate flag fraction excluding fully-dead antennas
# Dead antennas have 100% of their solutions flagged
# For bandpass tables: shape is (nant, nchan, npol)
if flags.ndim == 3:
nant, nchan, npol = flags.shape
# Sum over channels and polarizations to get per-antenna flag count
per_ant_flags = np.sum(flags, axis=(1, 2))
max_flags_per_ant = nchan * npol

# Count dead antennas (>=99% flagged)
# Using a threshold handles cases where an antenna is effectively dead
# but has a few unflagged solutions (e.g. due to edge effects or gaps)
dead_ant_mask = per_ant_flags >= 0.99 * max_flags_per_ant
n_dead = int(np.sum(dead_ant_mask))
n_working = nant - n_dead

if n_working > 0:
# Calculate flag fraction for working antennas only
working_flags = np.sum(flags[~dead_ant_mask, :, :])
working_total = n_working * nchan * npol
effective_flag_fraction = working_flags / working_total
else:
effective_flag_fraction = raw_flag_fraction
# Calculate flag fraction excluding fully-dead antenna receptors.
# A CASA calibration table stores one row per antenna/SPW, while the
# FLAG cell stores receptor/polarization and channel axes. Treating the
# first array axis as "antenna" incorrectly counts rows or channels as
# dead antennas after the casa_tables row-axis normalization.
antenna_ids = None
if "ANTENNA1" in tb.colnames():
antenna_ids = tb.getcol("ANTENNA1")

if flags.ndim == 3 and antenna_ids is not None and len(antenna_ids) == flags.shape[0]:
flag_stats = _flag_fraction_excluding_dead_receptors(flags, antenna_ids)
effective_flag_fraction = flag_stats["effective_flag_fraction"]
n_dead = flag_stats["dead_receptor_count"]
Comment on lines +524 to +525
n_dead_antennas = flag_stats["dead_antenna_count"]

logger.info(
f"Flag fraction in {cal_type} table: {raw_flag_fraction * 100:.1f}% raw "
f"({flagged:,}/{total:,} solutions flagged)"
)
if n_dead > 0:
logger.info(
f" Excluding {n_dead} fully-flagged (dead) antennas: "
f" Excluding {n_dead} fully-flagged (dead) antenna receptors "
f"across {n_dead_antennas} antennas: "
f"effective flag fraction = {effective_flag_fraction * 100:.1f}% "
f"({n_working} working antennas)"
f"({flag_stats['working_receptor_count']} working receptors)"
)
else:
# Fallback for other table shapes
effective_flag_fraction = raw_flag_fraction
n_dead = 0
n_dead_antennas = 0
logger.info(
f"Flag fraction in {cal_type} table: {raw_flag_fraction * 100:.1f}% "
f"({flagged:,}/{total:,} solutions flagged)"
)

if effective_flag_fraction > max_flag_fraction:
dead_info = f" (excluding {n_dead} dead antennas)" if n_dead > 0 else ""
dead_info = (
f" (excluding {n_dead} dead antenna receptors across {n_dead_antennas} antennas)"
if n_dead > 0
else ""
)
raise ValueError(
f"{cal_type.upper()} SOLVE FAILED: Excessive flagging detected{dead_info}.\n"
f" Effective flag fraction: {effective_flag_fraction * 100:.1f}% (threshold: {max_flag_fraction * 100:.0f}%)\n"
Expand All @@ -573,6 +567,68 @@ def _check_flag_fraction(
return effective_flag_fraction


def _flag_fraction_excluding_dead_receptors(
    flags: Any,
    antenna_ids: Any,
    *,
    dead_threshold: float = 0.99,
) -> dict[str, Any]:
    """Compute caltable flag fraction after removing dead antenna receptors.

    Solutions are grouped by (antenna, receptor) pair across every row that
    carries the same antenna ID. A pair whose flagged fraction is at least
    ``dead_threshold`` is treated as dead and left out of the effective
    flag fraction.

    Parameters
    ----------
    flags
        Boolean-like 3D array with table rows on axis 0. The two remaining
        axes hold the receptor and channel dimensions in either order.
    antenna_ids
        1D sequence giving the antenna ID of each row of ``flags``.
    dead_threshold
        Minimum flagged fraction at which an (antenna, receptor) pair is
        counted as dead.

    Returns
    -------
    dict with keys ``effective_flag_fraction`` (float), ``dead_receptor_count``,
    ``dead_antenna_count``, ``working_receptor_count``, ``working_flagged``
    and ``working_total`` (all int). An empty ``flags`` array yields zeros
    for every entry. When every pair is dead, ``effective_flag_fraction``
    falls back to the mean of the whole ``flags`` array.

    Raises
    ------
    ValueError
        If ``flags`` is not 3D, ``antenna_ids`` is not 1D, or their row
        counts differ.
    """
    import numpy as np

    flags = np.asarray(flags, dtype=bool)
    antenna_ids = np.asarray(antenna_ids)
    if flags.ndim != 3:
        raise ValueError(f"Expected row-major 3D FLAG array, found shape {flags.shape}")
    if antenna_ids.ndim != 1:
        raise ValueError(f"Expected 1D ANTENNA1 array, found shape {antenna_ids.shape}")
    if flags.shape[0] != antenna_ids.shape[0]:
        raise ValueError(
            f"Expected one ANTENNA1 value per FLAG row, found {antenna_ids.shape[0]} "
            f"antenna IDs for {flags.shape[0]} rows"
        )

    # Nothing to assess: report zeros instead of the NaN that the mean of an
    # empty array would produce (NaN silently passes any ">" threshold check).
    if flags.size == 0:
        return {
            "effective_flag_fraction": 0.0,
            "dead_receptor_count": 0,
            "dead_antenna_count": 0,
            "working_receptor_count": 0,
            "working_flagged": 0,
            "working_total": 0,
        }

    # Pick the cell axis whose size is <=4 (smallest first) as the receptor
    # axis; the other cell axis is treated as channels.
    # NOTE(review): if both cell axes are <=4 long (e.g. a single-channel
    # table), the smaller one wins and may actually be the channel axis --
    # confirm callers only pass tables with more than 4 channels.
    cell_axes = flags.shape[1:]
    receptor_axis = 1 + min(
        range(len(cell_axes)),
        key=lambda idx: (cell_axes[idx] > 4, cell_axes[idx]),
    )
    receptor_count = flags.shape[receptor_axis]
    # With exactly two cell axes (1 and 2), the non-receptor axis is 3 - axis.
    samples_per_row = flags.shape[3 - receptor_axis]

    # Flagged-sample count per (row, receptor).
    row_flagged = np.moveaxis(flags, receptor_axis, 1).sum(axis=2, dtype=np.int64)

    # Group rows by antenna in one pass instead of masking once per antenna.
    unique_antennas, row_to_antenna = np.unique(antenna_ids, return_inverse=True)
    antenna_count = int(unique_antennas.size)
    flagged = np.zeros((antenna_count, receptor_count), dtype=np.int64)
    np.add.at(flagged, row_to_antenna, row_flagged)
    rows_per_antenna = np.bincount(row_to_antenna, minlength=antenna_count)
    totals = np.broadcast_to(
        (rows_per_antenna * samples_per_row)[:, np.newaxis],
        flagged.shape,
    )

    # Every total is > 0 here: flags is non-empty and each unique antenna
    # owns at least one row.
    dead = (flagged / totals) >= dead_threshold
    working = ~dead
    working_flagged = int(flagged[working].sum())
    working_total = int(totals[working].sum())

    if working_total:
        effective_flag_fraction = working_flagged / working_total
    else:
        # Every pair is dead: fall back to the raw fraction of the table.
        effective_flag_fraction = float(flags.mean())

    dead_receptor_count = int(dead.sum())
    return {
        "effective_flag_fraction": float(effective_flag_fraction),
        "dead_receptor_count": dead_receptor_count,
        "dead_antenna_count": int(dead.any(axis=1).sum()),
        "working_receptor_count": antenna_count * receptor_count - dead_receptor_count,
        "working_flagged": working_flagged,
        "working_total": working_total,
    }


def _print_bandpass_solution_summary(
caltable_path: str,
ms: str,
Expand Down
97 changes: 97 additions & 0 deletions tests/test_calibration_flag_fraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""Tests for receptor-aware FLAG fraction QA helper."""

from __future__ import annotations

import numpy as np
import pytest

from dsa110_continuum.calibration.calibration import (
_flag_fraction_excluding_dead_receptors,
)
Comment on lines +8 to +10


# Dimensions of the synthetic calibration table shared by the tests below.
N_ANTENNAS = 117  # number of distinct antenna IDs
N_SPWS = 16  # number of rows generated for each antenna ID
N_ROWS = N_ANTENNAS * N_SPWS  # 1872
N_RECEPTORS = 2  # length of FLAG axis 1 in the synthetic array
N_CHANNELS = 48  # length of FLAG axis 2 in the synthetic array


def _antenna_ids() -> np.ndarray:
    """Return one antenna ID per row, each ID repeated ``N_SPWS`` times consecutively."""
    distinct_ids = np.arange(N_ANTENNAS)
    return np.repeat(distinct_ids, N_SPWS)


def _empty_flags() -> np.ndarray:
    """Return an all-False boolean array of shape (rows, receptors, channels)."""
    shape = (N_ROWS, N_RECEPTORS, N_CHANNELS)
    return np.zeros(shape, dtype=bool)


class TestFlagFractionExcludingDeadReceptors:
    """Checks for ``_flag_fraction_excluding_dead_receptors`` on a synthetic table."""

    def test_realistic_bandpass_shape_no_dead_receptors(self):
        """A table with nothing flagged reports no dead receptors."""
        stats = _flag_fraction_excluding_dead_receptors(_empty_flags(), _antenna_ids())

        assert stats["effective_flag_fraction"] == 0.0
        assert stats["dead_receptor_count"] == 0
        assert stats["dead_antenna_count"] == 0
        assert stats["working_receptor_count"] == N_ANTENNAS * N_RECEPTORS

    def test_one_dead_receptor_excluded(self):
        """Flagging receptor 0 of antenna 3 everywhere removes just that pair."""
        table = _empty_flags()
        ids = _antenna_ids()
        table[ids == 3, 0, :] = True

        stats = _flag_fraction_excluding_dead_receptors(table, ids)

        assert stats["dead_receptor_count"] == 1
        assert stats["dead_antenna_count"] == 1
        assert stats["effective_flag_fraction"] == pytest.approx(0.0)
        assert stats["working_receptor_count"] == N_ANTENNAS * N_RECEPTORS - 1

    def test_both_receptors_dead_one_antenna(self):
        """Flagging every solution of antenna 5 yields two dead receptors, one antenna."""
        table = _empty_flags()
        ids = _antenna_ids()
        table[ids == 5] = True

        stats = _flag_fraction_excluding_dead_receptors(table, ids)

        assert stats["dead_receptor_count"] == 2
        assert stats["dead_antenna_count"] == 1
        assert stats["working_receptor_count"] == N_ANTENNAS * N_RECEPTORS - 2
        assert stats["effective_flag_fraction"] == pytest.approx(0.0)

    def test_does_not_misread_axis_0_as_antennas(self):
        """Fully flagging single rows must not be reported as dead antennas."""
        table = _empty_flags()
        ids = _antenna_ids()
        # First row of each antenna's block of N_SPWS rows.
        first_row_per_antenna = np.arange(0, N_ROWS, N_SPWS)
        table[first_row_per_antenna[:116]] = True

        stats = _flag_fraction_excluding_dead_receptors(table, ids)

        assert stats["dead_antenna_count"] == 0
        assert stats["dead_receptor_count"] == 0

    def test_partial_flagging_in_one_receptor(self):
        """Half-flagged channels in one receptor count as working, flagged data."""
        half_band = N_CHANNELS // 2
        table = _empty_flags()
        ids = _antenna_ids()
        table[ids == 7, 0, :half_band] = True

        stats = _flag_fraction_excluding_dead_receptors(table, ids)

        assert stats["dead_receptor_count"] == 0
        assert stats["dead_antenna_count"] == 0
        assert stats["working_receptor_count"] == N_ANTENNAS * N_RECEPTORS

        samples_per_receptor = N_SPWS * N_CHANNELS
        expected_flagged = N_SPWS * half_band
        expected_total = (N_ANTENNAS * N_RECEPTORS) * samples_per_receptor
        assert stats["effective_flag_fraction"] == pytest.approx(
            expected_flagged / expected_total
        )
        assert stats["working_flagged"] == expected_flagged
        assert stats["working_total"] == N_ANTENNAS * N_RECEPTORS * samples_per_receptor
Loading