diff --git a/dsa110_continuum/calibration/calibration.py b/dsa110_continuum/calibration/calibration.py index 648f2ad..7fd436d 100644 --- a/dsa110_continuum/calibration/calibration.py +++ b/dsa110_continuum/calibration/calibration.py @@ -474,9 +474,8 @@ def _check_flag_fraction( ) -> float: """Check if flag fraction in calibration table exceeds threshold. - This function excludes fully-flagged antennas (dead antennas) from the - calculation, since they don't indicate calibration problems - just - non-working hardware. + Excludes fully-flagged (antenna, receptor) pairs from the calculation, + since those indicate dead hardware rather than calibration problems. Parameters ---------- @@ -489,7 +488,7 @@ def _check_flag_fraction( Returns ------- - Actual flag fraction (0.0 to 1.0) for working antennas only + Actual flag fraction (0.0 to 1.0) for working receptors only Raises ------ @@ -505,37 +504,26 @@ def _check_flag_fraction( return 0.0 flags = tb.getcol("FLAG") - # casacore returns shape (nrow, nchan, npol) for bandpass tables - # where nrow = number of antennas # Calculate raw flag fraction total = flags.size flagged = int(np.sum(flags)) raw_flag_fraction = flagged / total if total > 0 else 0.0 - # Calculate flag fraction excluding fully-dead antennas - # Dead antennas have 100% of their solutions flagged - # For bandpass tables: shape is (nant, nchan, npol) - if flags.ndim == 3: - nant, nchan, npol = flags.shape - # Sum over channels and polarizations to get per-antenna flag count - per_ant_flags = np.sum(flags, axis=(1, 2)) - max_flags_per_ant = nchan * npol - - # Count dead antennas (>=99% flagged) - # Using a threshold handles cases where an antenna is effectively dead - # but has a few unflagged solutions (e.g. due to edge effects or gaps) - dead_ant_mask = per_ant_flags >= 0.99 * max_flags_per_ant - n_dead = int(np.sum(dead_ant_mask)) - n_working = nant - n_dead - - if n_working > 0: - # Calculate flag fraction for working antennas only - working_flags = np.sum(flags[~dead_ant_mask, :, :]) - working_total = n_working * nchan * npol - effective_flag_fraction = working_flags / working_total - else: - effective_flag_fraction = raw_flag_fraction + # Calculate flag fraction excluding fully-dead antenna receptors. + # A CASA calibration table stores one row per antenna/SPW, while the + # FLAG cell stores receptor/polarization and channel axes. Treating the + # first array axis as "antenna" incorrectly counts rows or channels as + # dead antennas after the casa_tables row-axis normalization. + antenna_ids = None + if "ANTENNA1" in tb.colnames(): + antenna_ids = tb.getcol("ANTENNA1") + + if flags.ndim == 3 and antenna_ids is not None and len(antenna_ids) == flags.shape[0]: + flag_stats = _flag_fraction_excluding_dead_receptors(flags, antenna_ids) + effective_flag_fraction = flag_stats["effective_flag_fraction"] + n_dead = flag_stats["dead_receptor_count"] + n_dead_antennas = flag_stats["dead_antenna_count"] logger.info( f"Flag fraction in {cal_type} table: {raw_flag_fraction * 100:.1f}% raw " @@ -543,21 +531,27 @@ def _check_flag_fraction( ) if n_dead > 0: logger.info( - f" Excluding {n_dead} fully-flagged (dead) antennas: " + f" Excluding {n_dead} fully-flagged (dead) antenna receptors " + f"across {n_dead_antennas} antennas: " f"effective flag fraction = {effective_flag_fraction * 100:.1f}% " - f"({n_working} working antennas)" + f"({flag_stats['working_receptor_count']} working receptors)" ) else: # Fallback for other table shapes effective_flag_fraction = raw_flag_fraction n_dead = 0 + n_dead_antennas = 0 logger.info( f"Flag fraction in {cal_type} table: {raw_flag_fraction * 100:.1f}% " f"({flagged:,}/{total:,} solutions flagged)" ) if effective_flag_fraction > max_flag_fraction: - dead_info = f" (excluding {n_dead} dead antennas)" if n_dead > 0 else "" + dead_info = ( + f" (excluding {n_dead} dead antenna receptors across {n_dead_antennas} antennas)" + if n_dead > 0 + else "" + ) raise ValueError( f"{cal_type.upper()} SOLVE FAILED: Excessive flagging detected{dead_info}.\n" f" Effective flag fraction: {effective_flag_fraction * 100:.1f}% (threshold: {max_flag_fraction * 100:.0f}%)\n" @@ -573,6 +567,68 @@ def _check_flag_fraction( return effective_flag_fraction +def _flag_fraction_excluding_dead_receptors( + flags: Any, + antenna_ids: Any, + *, + dead_threshold: float = 0.99, +) -> dict[str, Any]: + """Compute caltable flag fraction after removing dead antenna receptors.""" + import numpy as np + + flags = np.asarray(flags, dtype=bool) + antenna_ids = np.asarray(antenna_ids) + if flags.ndim != 3: + raise ValueError(f"Expected row-major 3D FLAG array, found shape {flags.shape}") + if flags.shape[0] != antenna_ids.shape[0]: + raise ValueError( + f"Expected one ANTENNA1 value per FLAG row, found {antenna_ids.shape[0]} " + f"antenna IDs for {flags.shape[0]} rows" + ) + + # CASA polarization/receptor axes are always small (≤4); channel axes are + # always >4. Pick the cell axis whose size is ≤4 as the receptor axis. + cell_axes = flags.shape[1:] + receptor_axis_in_cell = min( + range(len(cell_axes)), + key=lambda idx: (cell_axes[idx] > 4, cell_axes[idx]), + ) + receptor_axis = receptor_axis_in_cell + 1 + receptor_count = flags.shape[receptor_axis] + + unique_antennas = sorted(set(antenna_ids.tolist())) + dead_receptors: list[tuple[int, int]] = [] + working_flagged = 0 + working_total = 0 + + for antenna_id in unique_antennas: + antenna_mask = antenna_ids == antenna_id + antenna_flags = flags[antenna_mask] + for receptor_idx in range(receptor_count): + receptor_flags = np.take(antenna_flags, receptor_idx, axis=receptor_axis) + receptor_total = int(receptor_flags.size) + receptor_flagged = int(np.sum(receptor_flags)) + receptor_fraction = receptor_flagged / receptor_total if receptor_total else 0.0 + if receptor_fraction >= dead_threshold: + dead_receptors.append((int(antenna_id), receptor_idx)) + else: + working_flagged += receptor_flagged + working_total += receptor_total + + effective_flag_fraction = ( + working_flagged / working_total if working_total else float(np.mean(flags)) + ) + dead_antennas = {antenna_id for antenna_id, _ in dead_receptors} + return { + "effective_flag_fraction": float(effective_flag_fraction), + "dead_receptor_count": len(dead_receptors), + "dead_antenna_count": len(dead_antennas), + "working_receptor_count": len(unique_antennas) * receptor_count - len(dead_receptors), + "working_flagged": int(working_flagged), + "working_total": int(working_total), + } + + def _print_bandpass_solution_summary( caltable_path: str, ms: str, diff --git a/tests/test_calibration_flag_fraction.py b/tests/test_calibration_flag_fraction.py new file mode 100644 index 0000000..417876b --- /dev/null +++ b/tests/test_calibration_flag_fraction.py @@ -0,0 +1,97 @@ +"""Tests for receptor-aware FLAG fraction QA helper.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from dsa110_continuum.calibration.calibration import ( + _flag_fraction_excluding_dead_receptors, +) + + +N_ANTENNAS = 117 +N_SPWS = 16 +N_ROWS = N_ANTENNAS * N_SPWS # 1872 +N_RECEPTORS = 2 +N_CHANNELS = 48 + + +def _antenna_ids() -> np.ndarray: + return np.repeat(np.arange(N_ANTENNAS), N_SPWS) + + +def _empty_flags() -> np.ndarray: + return np.zeros((N_ROWS, N_RECEPTORS, N_CHANNELS), dtype=bool) + + +class TestFlagFractionExcludingDeadReceptors: + def test_realistic_bandpass_shape_no_dead_receptors(self): + flags = _empty_flags() + antenna_ids = _antenna_ids() + + result = _flag_fraction_excluding_dead_receptors(flags, antenna_ids) + + assert result["effective_flag_fraction"] == 0.0 + assert result["dead_receptor_count"] == 0 + assert result["dead_antenna_count"] == 0 + assert result["working_receptor_count"] == N_ANTENNAS * N_RECEPTORS + + def test_one_dead_receptor_excluded(self): + flags = _empty_flags() + antenna_ids = _antenna_ids() + rows_for_ant_3 = np.where(antenna_ids == 3)[0] + flags[rows_for_ant_3, 0, :] = True + + result = _flag_fraction_excluding_dead_receptors(flags, antenna_ids) + + assert result["dead_receptor_count"] == 1 + assert result["dead_antenna_count"] == 1 + assert result["effective_flag_fraction"] == pytest.approx(0.0) + assert result["working_receptor_count"] == N_ANTENNAS * N_RECEPTORS - 1 + + def test_both_receptors_dead_one_antenna(self): + flags = _empty_flags() + antenna_ids = _antenna_ids() + rows_for_ant_5 = np.where(antenna_ids == 5)[0] + flags[rows_for_ant_5, :, :] = True + + result = _flag_fraction_excluding_dead_receptors(flags, antenna_ids) + + assert result["dead_receptor_count"] == 2 + assert result["dead_antenna_count"] == 1 + assert result["working_receptor_count"] == N_ANTENNAS * N_RECEPTORS - 2 + assert result["effective_flag_fraction"] == pytest.approx(0.0) + + def test_does_not_misread_axis_0_as_antennas(self): + flags = _empty_flags() + antenna_ids = _antenna_ids() + rows_one_per_antenna = np.arange(N_ANTENNAS) * N_SPWS + flagged_rows = rows_one_per_antenna[:116] + flags[flagged_rows, :, :] = True + + result = _flag_fraction_excluding_dead_receptors(flags, antenna_ids) + + assert result["dead_antenna_count"] == 0 + assert result["dead_receptor_count"] == 0 + + def test_partial_flagging_in_one_receptor(self): + flags = _empty_flags() + antenna_ids = _antenna_ids() + rows_for_ant_7 = np.where(antenna_ids == 7)[0] + flags[rows_for_ant_7, 0, : N_CHANNELS // 2] = True + + result = _flag_fraction_excluding_dead_receptors(flags, antenna_ids) + + assert result["dead_receptor_count"] == 0 + assert result["dead_antenna_count"] == 0 + assert result["working_receptor_count"] == N_ANTENNAS * N_RECEPTORS + + flagged_per_partial_receptor = N_SPWS * (N_CHANNELS // 2) + total_per_receptor = N_SPWS * N_CHANNELS + expected_fraction = flagged_per_partial_receptor / ( + (N_ANTENNAS * N_RECEPTORS) * total_per_receptor + ) + assert result["effective_flag_fraction"] == pytest.approx(expected_fraction) + assert result["working_flagged"] == flagged_per_partial_receptor + assert result["working_total"] == N_ANTENNAS * N_RECEPTORS * total_per_receptor