From 3b3e59c4e2fe83124fd1ba57e036c086065231de Mon Sep 17 00:00:00 2001 From: Romain Hugonnet Date: Tue, 9 Apr 2024 13:49:18 -0800 Subject: [PATCH] Update to `Raster.dtype` refactorisation and `cast_nodata` logic from GeoUtils (#498) --- dev-environment.yml | 4 ++-- environment.yml | 4 ++-- requirements.txt | 4 ++-- tests/test_coreg/test_biascorr.py | 4 ++-- xdem/dem.py | 13 +++++++++++-- 5 files changed, 19 insertions(+), 10 deletions(-) diff --git a/dev-environment.yml b/dev-environment.yml index 759c9c95..c8dddc43 100644 --- a/dev-environment.yml +++ b/dev-environment.yml @@ -9,11 +9,11 @@ dependencies: - matplotlib=3.* - pyproj>=3.4,<4 - rasterio>=1.3,<2 - - scipy=1.* + - scipy>=1.0,<1.13 - tqdm - scikit-image=0.* - scikit-gstat>=1.0 - - geoutils>=0.1.2 + - geoutils>=0.1.4,<0.2 # Development-specific, to mirror manually in setup.cfg [options.extras_require]. - pip diff --git a/environment.yml b/environment.yml index ac0f12f1..29c7780f 100644 --- a/environment.yml +++ b/environment.yml @@ -9,11 +9,11 @@ dependencies: - matplotlib=3.* - pyproj>=3.4,<4 - rasterio>=1.3,<2 - - scipy=1.* + - scipy>=1.0,<1.13 - tqdm - scikit-image=0.* - scikit-gstat>=1.0 - - geoutils>=0.1.2 + - geoutils>=0.1.4,<0.2 - pip # To run CI against latest GeoUtils diff --git a/requirements.txt b/requirements.txt index 28413caa..a9f1cc7e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,9 +7,9 @@ numpy==1.* matplotlib==3.* pyproj>=3.4,<4 rasterio>=1.3,<2 -scipy==1.* +scipy>=1.0,<1.13 tqdm scikit-image==0.* scikit-gstat>=1.0 -geoutils>=0.1.2 +geoutils>=0.1.4,<0.2 pip diff --git a/tests/test_coreg/test_biascorr.py b/tests/test_coreg/test_biascorr.py index 25d6cdbd..5139c329 100644 --- a/tests/test_coreg/test_biascorr.py +++ b/tests/test_coreg/test_biascorr.py @@ -560,10 +560,10 @@ def test_deramp__synthetic(self, fit_args, order: int) -> None: deramp = biascorr.Deramp(poly_order=order) elev_fit_args = fit_args.copy() if isinstance(elev_fit_args["to_be_aligned_elev"], gpd.GeoDataFrame): - bias_elev = bias_dem.to_pointcloud(data_column_name="z", subsample=20000).ds + bias_elev = bias_dem.to_pointcloud(data_column_name="z", subsample=30000).ds else: bias_elev = bias_dem - deramp.fit(elev_fit_args["reference_elev"], to_be_aligned_elev=bias_elev, subsample=10000, random_state=42) + deramp.fit(elev_fit_args["reference_elev"], to_be_aligned_elev=bias_elev, subsample=20000, random_state=42) # Check high-order fit parameters are the same within 10% fit_params = deramp._meta["fit_params"] diff --git a/xdem/dem.py b/xdem/dem.py index cd5e4251..697859a2 100644 --- a/xdem/dem.py +++ b/xdem/dem.py @@ -162,6 +162,7 @@ def from_array( nodata: int | float | None = None, area_or_point: Literal["Area", "Point"] | None = None, tags: dict[str, Any] = None, + cast_nodata: bool = True, vcrs: Literal["Ellipsoid"] | Literal["EGM08"] | Literal["EGM96"] @@ -180,13 +181,21 @@ def from_array( :param nodata: Nodata value. :param area_or_point: Pixel interpretation of the raster, will be stored in AREA_OR_POINT metadata. :param tags: Metadata stored in a dictionary. + :param cast_nodata: Automatically cast nodata value to the default nodata for the new array type if not + compatible. If False, will raise an error when incompatible. :param vcrs: Vertical coordinate reference system. :returns: DEM created from the provided array and georeferencing. """ # We first apply the from_array of the parent class rast = SatelliteImage.from_array( - data=data, transform=transform, crs=crs, nodata=nodata, area_or_point=area_or_point, tags=tags + data=data, + transform=transform, + crs=crs, + nodata=nodata, + area_or_point=area_or_point, + tags=tags, + cast_nodata=cast_nodata, ) # Then add the vcrs to the class call (that builds on top of the parent class) return cls(filename_or_dataset=rast, vcrs=vcrs) @@ -300,7 +309,7 @@ def to_vcrs( zz_trans = _transform_zz(crs_from=src_ccrs, crs_to=dst_ccrs, xx=xx, yy=yy, zz=zz) # Update DEM - self._data = zz_trans.astype(self.dtypes[0]) # type: ignore + self._data = zz_trans.astype(self.dtype) # type: ignore # Update vcrs (which will update ccrs if called) self.set_vcrs(new_vcrs=vcrs)