Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 46 additions & 10 deletions dsa110_continuum/calibration/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,10 @@
import time
from pathlib import Path


import astropy.units as u # noqa: E402
from dsa110_continuum.adapters import casa_tables as tb # noqa: E402
import numpy as np # noqa: E402
from astropy.coordinates import SkyCoord # noqa: E402
from dsa110_continuum.adapters import casa_tables as tb # noqa: E402

# Set up logger
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -444,20 +443,53 @@ def _calculate_manual_model_data(
print(f" Processing {nselected:,} selected rows...")
sys.stdout.flush()

# Read DATA shape to create MODEL_DATA with matching shape
# Read DATA shape to create MODEL_DATA with matching shape.
# CASA tables can store DATA as either (nchan, npol) or (npol, nchan)
# depending on writer convention; identify the channel axis by matching
# against the SPW channel count rather than assuming an order.
print(" Reading DATA shape...")
sys.stdout.flush()
data_sample = main_tb.getcell("DATA", 0)
data_shape = data_sample.shape # In CASA: (nchan, npol)
nchan, npol = data_shape[0], data_shape[1]
logger.debug(f"Data shape: {nchan} channels, {npol} polarizations")
print(f" DATA shape: {nchan} channels × {npol} pols")
data_shape = data_sample.shape
if len(data_shape) != 2:
raise ValueError(
f"Unsupported DATA cell shape for MODEL_DATA calculation: {data_shape}"
)

expected_nchan = int(chan_freq.shape[1])
if data_shape[0] == expected_nchan:
chan_axis = 0
corr_axis = 1
elif data_shape[1] == expected_nchan:
corr_axis = 0
chan_axis = 1
else:
raise ValueError(
"Cannot identify channel axis for MODEL_DATA calculation: "
f"DATA cell shape={data_shape}, expected channel count={expected_nchan}"
)

nchan = data_shape[chan_axis]
npol = data_shape[corr_axis]
logger.debug(
"Data shape: %s, channel axis=%d (%d channels), correlation axis=%d (%d correlations)",
data_shape,
chan_axis,
nchan,
corr_axis,
npol,
)
print(
f" DATA shape: {data_shape} (channel axis={chan_axis}, correlation axis={corr_axis})"
)
sys.stdout.flush()

# Initialize MODEL_DATA array with correct shape (nrows, nchan, npol)
print(f" Allocating MODEL_DATA array ({nrows * nchan * npol * 8 / 1e9:.2f} GB)...")
sys.stdout.flush()
model_data = np.zeros((nrows, nchan, npol), dtype=np.complex64)
# Allocate MODEL_DATA with the per-row shape that matches DATA exactly,
# so axis order matches whatever the writer used.
model_data = np.zeros((nrows, *data_shape), dtype=np.complex64)
logger.debug(f"Allocated MODEL_DATA array: {model_data.nbytes / 1e9:.2f} GB")
print(" MODEL_DATA allocated")
sys.stdout.flush()
Expand Down Expand Up @@ -565,8 +597,12 @@ def _calculate_manual_model_data(
# Create complex model: amplitude * exp(i*phase)
model_complex = amplitude * (np.cos(phase) + 1j * np.sin(phase))

# Broadcast to all polarizations and store in output array
model_data[chunk_indices, :, :] = model_complex[:, :, np.newaxis]
# Broadcast the channel-dependent point-source model to every
# correlation while preserving the MS DATA column's axis order.
if chan_axis == 0:
model_data[chunk_indices, :, :] = model_complex[:, :, np.newaxis]
else:
model_data[chunk_indices, :, :] = model_complex[:, np.newaxis, :]

# Progress output
if (chunk_idx + 1) % 3 == 0 or chunk_idx == n_chunks - 1:
Expand Down
2 changes: 2 additions & 0 deletions dsa110_continuum/evidence/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"""Evidence workflows for pipeline validation artifacts."""

Loading
Loading