Skip to content

Commit

Permalink
Make QuantityTableCoordinate and TimeTableCoordinate stored tables as…
Browse files Browse the repository at this point in the history
… corners and centers.
  • Loading branch information
DanRyanIrish committed May 30, 2024
1 parent ce4d0b1 commit 96a9331
Showing 1 changed file with 79 additions and 20 deletions.
99 changes: 79 additions & 20 deletions ndcube/extra_coords/table_coord.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,30 +128,27 @@ def _generate_generic_frame(naxes, unit, names=None, physical_types=None):
axes_names=names, name=name, axis_physical_types=physical_types)


def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, **kwargs):
def _generate_tabular(table_points, lookup_table, interpolation='linear', points_unit=u.pix, **kwargs):
"""
Generate a Tabular model class and instance.
"""
if not isinstance(lookup_table, u.Quantity):
raise TypeError("lookup_table must be a Quantity.") # pragma: no cover
if not isinstance(table_points, u.Quantity):
raise TypeError("table_points must be a Quantity.") # pragma: no cover

ndim = lookup_table.ndim
TabularND = tabular_model(ndim, name=f"Tabular{ndim}D")

# The integer location is at the centre of the pixel.
points = [(np.arange(size) - 0) * points_unit for size in lookup_table.shape]
if len(points) == 1:
points = points[0]

kwargs = {'bounds_error': False,
'fill_value': np.nan,
'method': interpolation,
**kwargs}

if len(lookup_table) == 1:
t = Length1Tabular(points, lookup_table, **kwargs)
t = Length1Tabular(table_points, lookup_table, **kwargs)
else:
t = TabularND(points, lookup_table, **kwargs)
t = TabularND(table_points, lookup_table, **kwargs)

# TODO: Remove this when there is a new gWCS release
# Work around https://github.com/spacetelescope/gwcs/pull/331
Expand All @@ -160,13 +157,13 @@ def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, *
return t


def _generate_compound_model(*lookup_tables, mesh=True):
def _generate_compound_model(points, *lookup_tables, mesh=True):
"""
Takes a set of quantities and returns a ND compound model.
"""
model = _generate_tabular(lookup_tables[0])
for lt in lookup_tables[1:]:
model = model & _generate_tabular(lt)
model = _generate_tabular(points[0], lookup_tables[0])
for pts, lt in zip(points[1:], lookup_tables[1:]):
model = model & _generate_tabular(pts, lt)

if mesh:
return model
Expand All @@ -176,11 +173,11 @@ def _generate_compound_model(*lookup_tables, mesh=True):
return models.Mapping(mapping) | model


def _model_from_quantity(lookup_tables, mesh=False):
def _model_from_quantity(points, lookup_tables, mesh=False):
if len(lookup_tables) > 1:
return _generate_compound_model(*lookup_tables, mesh=mesh)
return _generate_compound_model(points, *lookup_tables, mesh=mesh)

return _generate_tabular(lookup_tables[0])
return _generate_tabular(points[0], lookup_tables[0])


class BaseTableCoordinate(abc.ABC):
Expand Down Expand Up @@ -301,7 +298,7 @@ class QuantityTableCoordinate(BaseTableCoordinate):
a physical type must be given for each component.
"""

def __init__(self, *tables, names=None, physical_types=None):
def __init__(self, *tables, names=None, physical_types=None, grid_points="centers"):
if not all([isinstance(t, u.Quantity) for t in tables]):
raise TypeError("All tables must be astropy Quantity objects")
if not all([t.unit.is_equivalent(tables[0].unit) for t in tables]):
Expand All @@ -313,6 +310,17 @@ def __init__(self, *tables, names=None, physical_types=None):
"Currently all tables must be 1-D. If you need >1D support, please "
"raise an issue at https://github.con/sunpy/ndcube/issues")

# lookup table must be stored as values at corners and centers.
# If input tables only represent centers or corners, linearly interpolate to get other values.
# If centers or corners provided, the following assumes the tables are 1-D.
if grid_points == "centers":
tables = _get_grid_from_centers(tables)
elif grid_points == "corners":
tables = _get_grid_from_corners(tables)
elif grid_points != "centers and corners":
raise ValueError(f"Unrecognized value for grid_points: {grid_points}. "
"Must be 'centers', 'corners', or 'centers and corners'.")

if isinstance(names, str):
names = [names]
if names is not None and len(names) != ndim:
Expand Down Expand Up @@ -396,7 +404,11 @@ def model(self):
"""
Generate the Astropy Model for this LookupTable.
"""
return _model_from_quantity(self.table, True)
points_unit = u.pix
points = [(np.arange(-1, table.shape[0] - 1) / 2) * points_unit if len(table.shape) == 1
else [(np.arange(-1, size - 1) / 2) * points_unit for size in table.shape]
for table in self.table]
return _model_from_quantity(points, self.table, True)

@property
def ndim(self):
Expand Down Expand Up @@ -577,6 +589,10 @@ def model(self):
"""
Generate the Astropy Model for this LookupTable.
"""
points_unit = u.pix
points = [np.arange(table.shape[0]) * points_unit if len(table.shape) == 1
else [np.arange(size) for size in table.shape] * points_unit
for table in self._sliced_components]
return _model_from_quantity(self._sliced_components, mesh=self.mesh)

@property
Expand Down Expand Up @@ -701,10 +717,24 @@ class TimeTableCoordinate(BaseTableCoordinate):
Default is first time coordinate in table input.
"""

def __init__(self, *tables, names=None, physical_types=None, reference_time=None):
def __init__(self, *tables, names=None, physical_types=None, reference_time=None, grid_points="centers"):
if not len(tables) == 1 and isinstance(tables[0], Time):
raise ValueError("TimeLookupTable can only be constructed from a single Time object.")

# lookup table must be stored as values at corners and centers.
# If input tables only represent centers or corners, linearly interpolate to get other values.
# If centers or corners provided, the following assumes the tables are 1-D.
table = tables[0]
mjd_table = table.mjd
if grid_points in {"centers", "corners"}:
mjd = _get_grid_from_centers([mjd_table]) if grid_points == "centers" else _get_grid_from_corners([mjd_table])
mjd = mjd[0]
t = Time(mjd, format="mjd", scale=table.scale)
tables = [Time(getattr(t, table.format), scale=table.scale)]
elif grid_points != "centers and corners":
raise ValueError(f"Unrecognized value for grid_points: {grid_points}. "
"Must be 'centers', 'corners', or 'centers and corners'.")

if isinstance(names, str):
names = [names]
if isinstance(physical_types, str):
Expand Down Expand Up @@ -752,8 +782,9 @@ def model(self):
"""
time = self.table
deltas = (time - self.reference_time).to(u.s)

return _model_from_quantity((deltas,), mesh=False)
points_unit = u.pix
points = ((np.arange(-1, self.table.shape[0] - 1) / 2) * points_unit,)
return _model_from_quantity(points, (deltas,), mesh=False)

def interpolate(self, new_array_grids, **kwargs):
"""
Expand Down Expand Up @@ -970,3 +1001,31 @@ def interpolate(self, new_array_grids, **kwargs):
new_obj = type(self)(*new_table_coordinates)
new_obj._dropped_coords = self._dropped_coords
return new_obj


def _get_grid_from_centers(tables):
new_tables = []
for table in tables:
new_table = np.zeros(len(table) * 2 + 1)
tv = table.value if isinstance(table, u.Quantity) else table
new_table[0] = tv[0] - (tv[1] - tv[0]) / 2
new_table[2:-1:2] = tv[:-1] + (tv[1:] - tv[:-1]) / 2
new_table[1::2] = tv
new_table[-1] = tv[-1] + (tv[-1] - tv[-2]) / 2
if isinstance(table, u.Quantity):
new_table *= table.unit
new_tables.append(new_table)
return tuple(new_tables)


def _get_grid_from_corners(tables):
new_tables = []
for table in tables:
new_table = np.zeros(len(table) * 2 - 1)
tv = table.value if isinstance(table, u.Quantity) else table
new_table[::2] = tv
new_table[1::2] = tv[:-1] + (tv[1:] - tv[:-1]) / 2
if isinstance(table, u.Quantity):
new_table *= table.unit
new_tables.append(new_table)
return tuple(new_tables)

0 comments on commit 96a9331

Please sign in to comment.