diff --git a/src/rastr/raster.py b/src/rastr/raster.py index 653f04c..dad9b65 100644 --- a/src/rastr/raster.py +++ b/src/rastr/raster.py @@ -46,7 +46,7 @@ from folium import Map from matplotlib.axes import Axes from matplotlib.image import AxesImage - from numpy.typing import ArrayLike, NDArray + from numpy.typing import ArrayLike, DTypeLike, NDArray from rasterio.io import BufferedDatasetWriter, DatasetReader, DatasetWriter from shapely.geometry.base import BaseGeometry from typing_extensions import Self @@ -341,6 +341,21 @@ def clamp( raster_meta=self.raster_meta, ) + def astype(self, dtype: DTypeLike) -> Self: + """Cast the raster array to a specified dtype. + + Returns a new raster with the array cast to the given dtype. The original + raster is not modified. + + Args: + dtype: Target data type (e.g. ``"float32"``, ``np.int16``). + + Returns: + A new Raster instance with the array cast to the specified dtype. + """ + cls = self.__class__ + return cls(arr=self.arr.astype(dtype), raster_meta=self.raster_meta) + def set_crs(self, crs: CRS | str, *, allow_override: bool = False) -> Self: """Set the CRS of the raster without reprojecting. diff --git a/tests/rastr/test_raster.py b/tests/rastr/test_raster.py index b5dd54e..17848b9 100644 --- a/tests/rastr/test_raster.py +++ b/tests/rastr/test_raster.py @@ -1183,6 +1183,82 @@ def test_preserves_dtype(self): assert result.arr.dtype == np.int32 assert result.raster_meta == raster_meta + class TestAstype: + def test_converts_to_float32(self): + # Arrange + raster_meta = RasterMeta( + crs=CRS.from_epsg(2193), + transform=Affine(1.0, 0.0, 0.0, 0.0, 1.0, 0.0), + ) + raster = Raster( + arr=np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64), + raster_meta=raster_meta, + ) + + # Act + result = raster.astype("float32") + + # Assert + assert result.arr.dtype == np.float32 + assert result.raster_meta == raster_meta + + def test_converts_to_int16(self): + # Arrange + raster_meta = RasterMeta( + crs=CRS.from_epsg(2193), + transform=Affine(1.0, 0.0, 0.0, 0.0, 1.0, 0.0), + ) + raster = Raster( + arr=np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64), + raster_meta=raster_meta, + ) + + # Act + result = raster.astype(np.int16) + + # Assert + assert result.arr.dtype == np.int16 + np.testing.assert_array_equal(result.arr, np.array([[1, 2], [3, 4]])) + + def test_preserves_values(self): + # Arrange + raster_meta = RasterMeta( + crs=CRS.from_epsg(2193), + transform=Affine(1.0, 0.0, 0.0, 0.0, 1.0, 0.0), + ) + raster = Raster( + arr=np.array([[1, 2], [3, 4]], dtype=np.int32), + raster_meta=raster_meta, + ) + + # Act + result = raster.astype(np.float64) + + # Assert + np.testing.assert_array_equal( + result.arr, np.array([[1.0, 2.0], [3.0, 4.0]]) + ) + + def test_subclass_return_type(self): + # Arrange + class MyRaster(Raster): + pass + + raster_meta = RasterMeta( + crs=CRS.from_epsg(2193), + transform=Affine(1.0, 0.0, 0.0, 0.0, 1.0, 0.0), + ) + raster = MyRaster( + arr=np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64), + raster_meta=raster_meta, + ) + + # Act + result = raster.astype("float32") + + # Assert + assert isinstance(result, MyRaster) + class TestSetCRS: def test_crs_object(self, example_raster: Raster) -> None: # Arrange