Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions src/rasterix/lib.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Shared library utilities for rasterix."""

import logging
from typing import NotRequired, TypedDict

from affine import Affine

Expand Down Expand Up @@ -104,3 +105,70 @@ def affine_from_stac_proj_metadata(metadata: dict) -> Affine | None:

a, b, c, d, e, f = transform[:6]
return Affine(a, b, c, d, e, f)


_ZarrConventionRegistration = TypedDict("_ZarrConventionRegistration", {"spatial:": str})

_ZarrSpatialMetadata = TypedDict(
"_ZarrSpatialMetadata",
{
"zarr_conventions": NotRequired[list[_ZarrConventionRegistration | dict]],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some fancy types there!

"spatial:transform": NotRequired[list[float]],
"spatial:transform_type": NotRequired[str],
"spatial:registration": NotRequired[str],
},
)


def _has_spatial_zarr_convention(metadata: _ZarrSpatialMetadata) -> bool:
zarr_conventions = metadata.get("zarr_conventions")
if not zarr_conventions:
return False
for entry in zarr_conventions:
if isinstance(entry, dict) and entry.get("name") == "spatial:":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(entry, dict) and entry.get("name") == "spatial:":
if isinstance(entry, dict) and entry.get("uuid") == "689b58e2-cf7b-45e0-9fff-9cfc0883d6b4":

If a convention specifies a uuid, it is meant to be the "primary" identifier. I'll check at the GeoZarr meeting tomorrow if implemented have been following this in practice

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

o lordy. that's unreadable!

return True
return False


def affine_from_spatial_zarr_convention(metadata: dict) -> Affine | None:
    """Build an Affine from attributes that follow the Zarr spatial convention.

    The full specification lives at https://github.com/zarr-conventions/spatial.

    Parameters
    ----------
    metadata : dict
        Attributes that may carry Zarr spatial convention metadata.

    Returns
    -------
    Affine or None
        Affine built from the first six ``spatial:transform`` values. None when
        the ``spatial:`` convention is not registered, or when
        ``spatial:transform`` is missing or empty.

    Raises
    ------
    ValueError
        If ``spatial:transform`` has fewer than 6 elements.
    NotImplementedError
        If ``spatial:transform_type`` is not ``"affine"`` or
        ``spatial:registration`` is not ``"pixel"``.

    Examples
    --------
    >>> ds: xr.Dataset = ...
    >>> affine = affine_from_spatial_zarr_convention(ds.attrs)
    """
    spatial_attrs: _ZarrSpatialMetadata = metadata  # type: ignore[assignment]

    # Without the convention registration the spatial:* keys are not trusted.
    if not _has_spatial_zarr_convention(spatial_attrs):
        return None

    coefficients = spatial_attrs.get("spatial:transform")
    if not coefficients:
        return None

    n_coefficients = len(coefficients)
    if n_coefficients < 6:
        raise ValueError(f"spatial:transform must have at least 6 elements, got {n_coefficients}")

    # Both settings fall back to the only supported value when absent.
    kind = spatial_attrs.get("spatial:transform_type", "affine")
    if kind != "affine":
        raise NotImplementedError(
            f"Unsupported spatial:transform_type {kind!r}; only 'affine' is supported."
        )

    cell_registration = spatial_attrs.get("spatial:registration", "pixel")
    if cell_registration != "pixel":
        raise NotImplementedError(
            f"Unsupported spatial:registration {cell_registration!r}; only 'pixel' is supported."
        )

    # Only the first six values are used; any extra elements are ignored.
    a, b, c, d, e, f = (float(value) for value in coefficients[:6])
    return Affine(a, b, c, d, e, f)
8 changes: 6 additions & 2 deletions src/rasterix/raster_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from rasterix.odc_compat import BoundingBox, bbox_intersection, bbox_union, maybe_int, snap_grid
from rasterix.rioxarray_compat import guess_dims
from rasterix.utils import get_affine
from rasterix.utils import get_affine, get_crs_from_proj_zarr_convention

T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset")

Expand Down Expand Up @@ -87,13 +87,17 @@ def assign_index(

affine = get_affine(obj, x_dim=x_dim, y_dim=y_dim, clear_transform=True)

detected_crs = obj.proj.crs if crs else None
if detected_crs is None:
detected_crs = get_crs_from_proj_zarr_convention(obj)

index = RasterIndex.from_transform(
affine,
width=obj.sizes[x_dim],
height=obj.sizes[y_dim],
x_dim=x_dim,
y_dim=y_dim,
crs=obj.proj.crs if crs else None,
crs=detected_crs,
)
coords = Coordinates.from_xindex(index)
return obj.assign_coords(coords)
Expand Down
71 changes: 69 additions & 2 deletions src/rasterix/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from typing import NotRequired, TypedDict

import xarray as xr
from affine import Affine
from pyproj import CRS

from rasterix.lib import affine_from_stac_proj_metadata, affine_from_tiepoint_and_scale, logger
from rasterix.lib import (
affine_from_spatial_zarr_convention,
affine_from_stac_proj_metadata,
affine_from_tiepoint_and_scale,
logger,
)


def get_grid_mapping_var(obj: xr.Dataset | xr.DataArray) -> xr.DataArray | None:
Expand Down Expand Up @@ -58,7 +66,7 @@ def get_affine(
del grid_mapping_var.attrs["GeoTransform"]
return Affine.from_gdal(*map(float, transform.split(" ")))

# Check for STAC and GeoTIFF metadata in DataArray attrs
# Check for STAC, GeoTIFF, or spatial zarr convention metadata in DataArray attrs
attrs = obj.attrs if isinstance(obj, xr.DataArray) else {}

# Try to extract affine from STAC proj:transform
Expand All @@ -80,6 +88,13 @@ def get_affine(

return affine

# Try to extract from spatial zarr convention attributes
if affine := affine_from_spatial_zarr_convention(attrs):
logger.trace("Creating affine from spatial zarr convention attributes")
if clear_transform:
del attrs["spatial:transform"]
return affine

# Fall back to computing from coordinate arrays
logger.trace(f"Creating affine from coordinate arrays {x_dim=!r} and {y_dim=!r}")
if x_dim not in obj.coords or y_dim not in obj.coords:
Expand All @@ -106,3 +121,55 @@ def get_affine(
return Affine.translation(
x[0].item() - dx / 2, (y[0] if dy < 0 else y[-1]).item() - dy / 2
) * Affine.scale(dx, dy)


_ZarrConventionRegistration = TypedDict("_ZarrConventionRegistration", {"proj:": str})

_ZarrProjMetadata = TypedDict(
"_ZarrProjMetadata",
{
"zarr_conventions": NotRequired[list[_ZarrConventionRegistration | dict]],
"proj:code": NotRequired[str],
"proj:wkt2": NotRequired[str],
"proj:projjson": NotRequired[object],
},
)


def _has_proj_zarr_convention(metadata: _ZarrProjMetadata) -> bool:
zarr_conventions = metadata.get("zarr_conventions")
if not zarr_conventions:
return False
for entry in zarr_conventions:
if isinstance(entry, dict) and entry.get("name") == "proj:":
return True
return False


def get_crs_from_proj_zarr_convention(obj: xr.Dataset | xr.DataArray) -> CRS | None:
    """Read a CRS from Zarr ``proj:`` convention attributes, when available.

    More on the convention: https://github.com/zarr-conventions/geo-proj.

    Parameters
    ----------
    obj: xr.Dataset or xr.DataArray
        The Xarray object whose ``attrs`` are inspected.

    Returns
    -------
    CRS or None
        CRS built from the first non-empty attribute among ``proj:code``,
        ``proj:wkt2`` and ``proj:projjson``, checked in that order. None when
        the ``proj:`` convention is not registered or none of them is set.
    """
    attrs = obj.attrs

    if not _has_proj_zarr_convention(attrs):  # type: ignore[arg-type]
        return None

    # Ordered by priority: the first attribute holding a truthy value wins.
    builders = (
        ("proj:code", CRS.from_string),
        ("proj:wkt2", CRS.from_wkt),
        ("proj:projjson", CRS.from_user_input),
    )
    for key, build in builders:
        value = attrs.get(key)
        if value:
            return build(value)
    return None
116 changes: 116 additions & 0 deletions tests/test_raster_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,75 @@ def test_assign_index_with_stac_proj_transform_9_elements():
assert actual_affine == expected_affine


def test_assign_index_with_spatial_zarr_convention():
    """assign_index builds a RasterIndex from spatial zarr convention attrs."""
    attrs = {
        "zarr_conventions": [{"name": "spatial:"}],
        "spatial:transform": [30.0, 0.0, 323400.0, 0.0, 30.0, 4268400.0],
    }
    da = xr.DataArray(np.ones((100, 100)), dims=("y", "x"), attrs=attrs)

    indexed = assign_index(da)

    # Both spatial dimensions must be backed by a RasterIndex
    for dim in ("x", "y"):
        assert isinstance(indexed.xindexes[dim], RasterIndex)

    # The index transform must equal the coefficients given in the attrs
    assert indexed.xindexes["x"].transform() == Affine(30.0, 0.0, 323400.0, 0.0, 30.0, 4268400.0)

    # The consumed attribute must not remain on the result
    assert "spatial:transform" not in indexed.attrs


def test_assign_index_with_spatual_zarr_convention_too_few_raises():
    """assign_index raises ValueError for a five-element spatial:transform."""
    attrs = {
        "zarr_conventions": [{"name": "spatial:"}],
        # One element short of the six that are required
        "spatial:transform": [30.0, 0.0, 323400.0, 0.0, 30.0],
    }
    da = xr.DataArray(np.ones((100, 100)), dims=("y", "x"), attrs=attrs)

    expected_message = "spatial:transform must have at least 6 elements"
    with pytest.raises(ValueError, match=expected_message):
        assign_index(da)


def test_assign_index_with_spatual_zarr_convention_transform_type_not_implemented():
    """assign_index raises NotImplementedError for a non-affine transform_type."""
    attrs = {
        "zarr_conventions": [{"name": "spatial:"}],
        "spatial:transform_type": "not_affine",
        "spatial:transform": [30.0, 0.0, 323400.0, 0.0, 30.0, 4268400.0],
    }
    da = xr.DataArray(np.ones((100, 100)), dims=("y", "x"), attrs=attrs)

    expected_message = "Unsupported spatial:transform_type"
    with pytest.raises(NotImplementedError, match=expected_message):
        assign_index(da)


def test_assign_index_with_spatual_zarr_convention_registration_not_implemented():
    """assign_index raises NotImplementedError for a non-pixel registration."""
    attrs = {
        "zarr_conventions": [{"name": "spatial:"}],
        "spatial:registration": "not_pixel",
        "spatial:transform": [30.0, 0.0, 323400.0, 0.0, 30.0, 4268400.0],
    }
    da = xr.DataArray(np.ones((100, 100)), dims=("y", "x"), attrs=attrs)

    expected_message = "Unsupported spatial:registration"
    with pytest.raises(NotImplementedError, match=expected_message):
        assign_index(da)


def test_assign_index_no_coords_no_metadata():
"""Test that assign_index raises error when coords are missing and no transform metadata."""
da = xr.DataArray(np.ones((10, 10)), dims=("y", "x"))
Expand Down Expand Up @@ -706,3 +775,50 @@ def test_raster_index_from_stac_proj_metadata_with_crs():
# Verify CRS was set
assert index.crs is not None
assert index.crs.to_epsg() == 32610


def test_assign_index_proj_zarr_convention_code():
    """assign_index picks up the CRS from a proj:code attribute."""
    attrs = {
        "zarr_conventions": [{"name": "proj:"}, {"name": "spatial:"}],
        "proj:code": "EPSG:4326",
        "spatial:transform": [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
    }
    da = xr.DataArray(np.ones((3, 4)), dims=("y", "x"), attrs=attrs)

    index_crs = assign_index(da).xindexes["x"].crs

    assert index_crs is not None
    assert index_crs.to_epsg() == 4326


def test_assign_index_proj_zarr_convention_wkt2():
    """assign_index picks up the CRS from a proj:wkt2 attribute."""
    source_crs = pyproj.CRS.from_epsg(3857)
    attrs = {
        "zarr_conventions": [{"name": "proj:"}, {"name": "spatial:"}],
        "proj:wkt2": source_crs.to_wkt(),
        "spatial:transform": [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
    }
    da = xr.DataArray(np.ones((3, 4)), dims=("y", "x"), attrs=attrs)

    index_crs = assign_index(da).xindexes["x"].crs

    assert index_crs is not None
    assert index_crs.to_epsg() == 3857


def test_assign_index_proj_zarr_convention_projjson():
    """assign_index picks up the CRS from a proj:projjson attribute."""
    source_crs = pyproj.CRS.from_epsg(32610)
    attrs = {
        "zarr_conventions": [{"name": "proj:"}, {"name": "spatial:"}],
        "proj:projjson": source_crs.to_json_dict(),
        "spatial:transform": [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
    }
    da = xr.DataArray(np.ones((3, 4)), dims=("y", "x"), attrs=attrs)

    index_crs = assign_index(da).xindexes["x"].crs

    assert index_crs is not None
    assert index_crs.to_epsg() == 32610
Loading