diff --git a/ci/environment-py310.yml b/ci/environment-py310.yml
index 98b767e..28051c3 100644
--- a/ci/environment-py310.yml
+++ b/ci/environment-py310.yml
@@ -2,25 +2,27 @@ name: test_env
 channels:
   - conda-forge
 dependencies:
-  - python=3.10
+  - python=3.9
   - aiohttp
   - boto3
   - exifread
   - flask
   - h5netcdf
-  - intake
   - netcdf4
   - pip
   - pydap
   - pytest
   - rasterio
-  - s3fs >= 2021.08.0
+  - s3fs
   - scikit-image
-  - xarray >= 0.17
   - rangehttpserver
+  - xarray
   - zarr
-  - moto < 3
+  - moto
   - s3fs
   - rioxarray
-  - werkzeug < 2.2.0
+  - werkzeug
   - dask
+  - numpy <2
+  - pip:
+    - git+https://github.com/intake/intake
diff --git a/ci/environment-py311.yml b/ci/environment-py311.yml
index 0d6e4d7..28051c3 100644
--- a/ci/environment-py311.yml
+++ b/ci/environment-py311.yml
@@ -2,25 +2,27 @@ name: test_env
 channels:
   - conda-forge
 dependencies:
-  - python=3.11
+  - python=3.9
   - aiohttp
   - boto3
   - exifread
   - flask
   - h5netcdf
-  - intake
   - netcdf4
   - pip
   - pydap
   - pytest
   - rasterio
-  - s3fs >= 2021.08.0
+  - s3fs
   - scikit-image
-  - xarray >= 0.17
   - rangehttpserver
+  - xarray
   - zarr
-  - moto < 3
+  - moto
   - s3fs
   - rioxarray
-  - werkzeug < 2.2.0
+  - werkzeug
   - dask
+  - numpy <2
+  - pip:
+    - git+https://github.com/intake/intake
diff --git a/ci/environment-py312.yml b/ci/environment-py312.yml
index 5241b40..28051c3 100644
--- a/ci/environment-py312.yml
+++ b/ci/environment-py312.yml
@@ -2,25 +2,27 @@ name: test_env
 channels:
   - conda-forge
 dependencies:
-  - python=3.12
+  - python=3.9
   - aiohttp
   - boto3
   - exifread
   - flask
   - h5netcdf
-  - intake
   - netcdf4
   - pip
   - pydap
   - pytest
   - rasterio
-  - s3fs >= 2021.08.0
+  - s3fs
   - scikit-image
   - rangehttpserver
-  - xarray >= 0.17
+  - xarray
   - zarr
-  - moto < 3
+  - moto
   - s3fs
   - rioxarray
-  - werkzeug < 2.2.0
+  - werkzeug
   - dask
+  - numpy <2
+  - pip:
+    - git+https://github.com/intake/intake
diff --git a/ci/environment-py39.yml b/ci/environment-py39.yml
index 43e7117..28051c3 100644
--- a/ci/environment-py39.yml
+++ b/ci/environment-py39.yml
@@ -8,20 +8,21 @@ dependencies:
   - exifread
   - flask
   - h5netcdf
-  - intake
   - netcdf4
   - pip
   - pydap
   - pytest
   - rasterio
-  - s3fs >= 2021.08.0
+  - s3fs
   - scikit-image
   - rangehttpserver
-  - xarray >= 0.17
+  - xarray
   - zarr
-  - moto < 3
+  - moto
   - s3fs
   - rioxarray
-  - werkzeug < 2.2.0
+  - werkzeug
   - dask
-
+  - numpy <2
+  - pip:
+    - git+https://github.com/intake/intake
diff --git a/ci/environment-upstream.yml b/ci/environment-upstream.yml
index 66c07f1..9275089 100644
--- a/ci/environment-upstream.yml
+++ b/ci/environment-upstream.yml
@@ -19,13 +19,14 @@ dependencies:
   - pandas
   - tornado
   - zarr
-  - moto < 3
+  - moto
   - intake
   - rioxarray
   - gdal
-  - werkzeug < 2.2.0
+  - werkzeug
   - rioxarray
   - dask
+  - numpy <2
   - pip:
     - git+https://github.com/fsspec/filesystem_spec.git
     - git+https://github.com/intake/intake.git
diff --git a/intake_xarray/__init__.py b/intake_xarray/__init__.py
index aa7bc73..9e9eb3d 100644
--- a/intake_xarray/__init__.py
+++ b/intake_xarray/__init__.py
@@ -2,19 +2,10 @@
 __version__ = get_versions()['version']
 del get_versions
 
-import intake # Import this first to avoid circular imports during discovery.
-from intake.container import register_container +import intake_xarray.base +import intake from .netcdf import NetCDFSource from .opendap import OpenDapSource from .raster import RasterIOSource -from .xzarr import ZarrSource -from .xarray_container import RemoteXarray +#from .xzarr import ZarrSource from .image import ImageSource - - -try: - intake.register_driver('remote-xarray', RemoteXarray) -except ValueError: - pass - -register_container('xarray', RemoteXarray) diff --git a/intake_xarray/base.py b/intake_xarray/base.py index 63cf85a..be68fa4 100644 --- a/intake_xarray/base.py +++ b/intake_xarray/base.py @@ -1,74 +1,19 @@ -from . import __version__ -from intake.source.base import DataSource, Schema +class IntakeXarraySourceAdapter: + container = "xarray" + name = "xarray" + version = "" + def to_dask(self): + return self.reader(chunks={}).read() -class DataSourceMixin(DataSource): - """Common behaviours for plugins in this repo""" - version = __version__ - container = 'xarray' - partition_access = True - - def _get_schema(self): - """Make schema object, which embeds xarray object and some details""" - from .xarray_container import serialize_zarr_ds - - self.urlpath = self._get_cache(self.urlpath)[0] - - if self._ds is None: - self._open_dataset() + def __call__(self, *args, **kwargs): + return self - metadata = { - 'dims': dict(self._ds.dims), - 'data_vars': {k: list(self._ds[k].coords) - for k in self._ds.data_vars.keys()}, - 'coords': tuple(self._ds.coords.keys()), - } - if getattr(self, 'on_server', False): - metadata['internal'] = serialize_zarr_ds(self._ds) - metadata.update(self._ds.attrs) - self._schema = Schema( - datashape=None, - dtype=None, - shape=None, - npartitions=None, - extra_metadata=metadata) - return self._schema + get = __call__ def read(self): - """Return a version of the xarray with all the data in memory""" - self._load_metadata() - return self._ds.load() + return self.reader.read() - def read_chunked(self): - """Return xarray object (which will have chunks)""" - self._load_metadata() - return self._ds - - def read_partition(self, i): - """Fetch one chunk of data at tuple index i - """ - import numpy as np - self._load_metadata() - if not isinstance(i, (tuple, list)): - raise TypeError('For Xarray sources, must specify partition as ' - 'tuple') - if isinstance(i, list): - i = tuple(i) - if hasattr(self._ds, 'variables') or i[0] in self._ds.coords: - arr = self._ds[i[0]].data - i = i[1:] - else: - arr = self._ds.data - if isinstance(arr, np.ndarray): - return arr - # dask array - return arr.blocks[i].compute() - - def to_dask(self): - """Return xarray object where variables are dask arrays""" - return self.read_chunked() + discover = read - def close(self): - """Delete open file from memory""" - self._ds = None - self._schema = None + read_chunked = to_dask diff --git a/intake_xarray/derived.py b/intake_xarray/derived.py deleted file mode 100644 index 392b15f..0000000 --- a/intake_xarray/derived.py +++ /dev/null @@ -1,64 +0,0 @@ -from intake import Schema -from intake.source.derived import GenericTransform - - -class XArrayTransform(GenericTransform): - """Transform where the input and output are both xarray objects. - You must supply ``transform`` and any ``transform_kwargs``. 
- """ - - input_container = "xarray" - container = "xarray" - optional_params = {} - _ds = None - - def to_dask(self): - if self._ds is None: - self._pick() - self._ds = self._transform( - self._source.to_dask(), **self._params["transform_kwargs"] - ) - return self._ds - - def _get_schema(self): - """load metadata only if needed""" - self.to_dask() - return Schema( - datashape=None, - dtype=None, - shape=None, - npartitions=None, - extra_metadata=self._ds.extra_metadata, - ) - - def read(self): - return self.to_dask().compute() - - -class Sel(XArrayTransform): - """Simple array transform to subsample an xarray object using - the sel method. - Note that you could use XArrayTransform directly, by writing a - function to choose the subsample instead of a method as here. - """ - - input_container = "xarray" - container = "xarray" - required_params = ["indexers"] - - def __init__(self, indexers, **kwargs): - """ - indexers: dict (stord as str) which is passed to xarray.Dataset.sel - """ - # this class wants required "indexers", but XArrayTransform - # uses "transform_kwargs", which we don't need since we use a method for the - # transform - kwargs.update( - transform=self.sel, - indexers=indexers, - transform_kwargs={}, - ) - super().__init__(**kwargs) - - def sel(self, ds): - return ds.sel(eval(self._params["indexers"])) diff --git a/intake_xarray/image.py b/intake_xarray/image.py index 673fb4e..9d8a03e 100644 --- a/intake_xarray/image.py +++ b/intake_xarray/image.py @@ -1,7 +1,9 @@ +import fsspec -from intake.source.base import PatternMixin from intake.source.utils import reverse_formats -from .base import DataSourceMixin, Schema +from intake import readers + +from intake_xarray.base import IntakeXarraySourceAdapter def _coerce_shape(array, shape): @@ -320,17 +322,13 @@ def multireader(files, chunks, concat_dim, exif_tags, **kwargs): ).chunk(chunks=chunks) -class ImageSource(DataSourceMixin, PatternMixin): +class ImageReader(readers.BaseReader): """Open a xarray dataset from image files. This creates an xarray.DataArray or an xarray.Dataset. See http://scikit-image.org/docs/dev/api/skimage.io.html#skimage.io.imread for the file formats supported. - NOTE: Although ``skimage.io.imread`` is used by default, any reader - function which accepts a file object and outputs a numpy array can be - used instead. - Parameters ---------- urlpath : str or iterable, location of data @@ -353,10 +351,6 @@ class ImageSource(DataSourceMixin, PatternMixin): concat_dim : str or iterable Dimension over which to concatenate. If iterable, all fields must be part of the the pattern. - imread : function (optional) - Optionally provide custom imread function. - Function should expect a file object and produce a numpy array. - Defaults to ``skimage.io.imread``. preprocess : function (optional) Optionally provide custom function to preprocess the image. Function should expect a numpy array for a single image and return @@ -375,22 +369,12 @@ class ImageSource(DataSourceMixin, PatternMixin): data in a data variable 'raster'. 
""" - name = 'xarray_image' + output_instance = "xarray:Dataset" + - def __init__(self, urlpath, chunks=None, concat_dim='concat_dim', - metadata=None, path_as_pattern=True, - storage_options=None, exif_tags=None, **kwargs): - self.path_as_pattern = path_as_pattern - self.urlpath = urlpath - self.chunks = chunks - self.concat_dim = concat_dim - self.storage_options = storage_options or {} - self.exif_tags = exif_tags - self._kwargs = kwargs - self._ds = None - super(ImageSource, self).__init__(metadata=metadata) - - def _open_files(self, files): + def _read(self, urlpath, chunks=None, concat_dim='concat_dim', + metadata=None, path_as_pattern=False, + storage_options=None, exif_tags=None, **kwargs): """ This function is called when the data source refers to more than one file either as a list or a glob. It sets up the @@ -404,93 +388,52 @@ def _open_files(self, files): import pandas as pd from xarray import DataArray + if path_as_pattern: + url = pattern_to_glob(urlpath) + __, _, paths = fsspec.get_fs_token_paths(url, **(data.storage_options or {})) + field_values = reverse_formats(data.url, paths) + else: + url = urlpath + + files = fsspec.open_files(urlpath, **(storage_options or {})) + out = multireader( - files, self.chunks, self.concat_dim, self.exif_tags, **self._kwargs + files, chunks, concat_dim, exif_tags, **kwargs ) - if not self.pattern: + if isinstance(out, DataArray) and len(files) == 1 and isinstance(urlpath, str) and "*" not in urlpath: + out = out[0] + if not path_as_pattern: return out coords = {} filenames = [f.path for f in files] - field_values = reverse_formats(self.pattern, filenames) - if isinstance(self.concat_dim, list): - if not set(field_values.keys()).issuperset(set(self.concat_dim)): + if isinstance(concat_dim, list): + if not set(field_values.keys()).issuperset(set(concat_dim)): raise KeyError('All concat_dims should be in pattern.') index = pd.MultiIndex.from_tuples( - zip(*(field_values[dim] for dim in self.concat_dim)), - names=self.concat_dim) + zip(*(field_values[dim] for dim in concat_dim)), + names=concat_dim) coords = { k: DataArray(v, dims=('dim_0')) - for k, v in field_values.items() if k not in self.concat_dim + for k, v in field_values.items() if k not in concat_dim } out = (out.assign_coords(dim_0=index, **coords) # use the index - .unstack().chunk(self.chunks)) # unstack along new index - return out.transpose(*self.concat_dim, # reorder dims - *filter(lambda x: x not in self.concat_dim, + .unstack().chunk(chunks)) # unstack along new index + return out.transpose(*concat_dim, # reorder dims + *filter(lambda x: x not in concat_dim, out.dims)) else: coords = { - k: DataArray(v, dims=self.concat_dim) + k: DataArray(v, dims=concat_dim) for k, v in field_values.items() } - return out.assign_coords(**coords).chunk(self.chunks).unify_chunks() - - def _open_dataset(self): - """ - Main entry function that finds a set of files and passes them to the - reader. 
- """ - from fsspec.core import open_files - - files = open_files(self.urlpath, **self.storage_options) - if len(files) == 0: - raise Exception("No files found at {}".format(self.urlpath)) - if len(files) == 1: - self._ds = reader( - files[0], self.chunks, exif_tags=self.exif_tags, **self._kwargs - ) - else: - self._ds = self._open_files(files) + return out.assign_coords(**coords).chunk(chunks).unify_chunks() - def _get_schema(self): - """Make schema object, which embeds xarray object and some details""" - import xarray as xr - import msgpack - from .xarray_container import serialize_zarr_ds - self.urlpath, *_ = self._get_cache(self.urlpath) - - if self._ds is None: - self._open_dataset() +class ImageSource(IntakeXarraySourceAdapter): + name = 'xarray_image' + container = "xarray" - # coerce to dataset for serialization - if isinstance(self._ds, xr.Dataset): - ds2 = self._ds - else: - ds2 = xr.Dataset({'raster': self._ds}) - - metadata = { - 'dims': dict(ds2.dims), - 'data_vars': {k: list(ds2[k].coords) - for k in ds2.data_vars.keys()}, - 'coords': tuple(ds2.coords.keys()), - 'array': 'raster' - } - if getattr(self, 'on_server', False): - metadata['internal'] = serialize_zarr_ds(ds2) - for k, v in ds2.raster.attrs.items(): - try: - # ensure only sending serializable attrs from remote - msgpack.packb(v) - metadata[k] = v - except TypeError: - pass - self._schema = Schema( - datashape=None, - dtype=str(ds2.raster.dtype), - shape=ds2.raster.shape, - npartitions=ds2.raster.data.npartitions, - extra_metadata=metadata) - - return self._schema + def __init__(self, *ar, **kw): + self.reader = ImageReader(*ar, **kw) diff --git a/intake_xarray/netcdf.py b/intake_xarray/netcdf.py index 1b46824..cadd2bf 100644 --- a/intake_xarray/netcdf.py +++ b/intake_xarray/netcdf.py @@ -1,12 +1,10 @@ # -*- coding: utf-8 -*- -import fsspec -from distutils.version import LooseVersion -from intake.source.base import PatternMixin -from intake.source.utils import reverse_format -from .base import DataSourceMixin +from intake import readers +from intake_xarray.base import IntakeXarraySourceAdapter -class NetCDFSource(DataSourceMixin, PatternMixin): + +class NetCDFSource(IntakeXarraySourceAdapter): """Open a xarray file. 
Parameters @@ -45,55 +43,13 @@ class NetCDFSource(DataSourceMixin, PatternMixin): """ name = 'netcdf' - def __init__(self, urlpath, chunks=None, combine=None, concat_dim=None, + def __init__(self, urlpath, xarray_kwargs=None, metadata=None, path_as_pattern=True, storage_options=None, **kwargs): - self.path_as_pattern = path_as_pattern - self.urlpath = urlpath - self.chunks = chunks - self.concat_dim = concat_dim - self.combine = combine - self.storage_options = storage_options or {} - self.xarray_kwargs = xarray_kwargs or {} - self._ds = None - if isinstance(self.urlpath, list): - self._can_be_local = fsspec.utils.can_be_local(self.urlpath[0]) - else: - self._can_be_local = fsspec.utils.can_be_local(self.urlpath) - super(NetCDFSource, self).__init__(metadata=metadata, **kwargs) - - def _open_dataset(self): - import xarray as xr - url = self.urlpath - - kwargs = self.xarray_kwargs - - if "*" in url or isinstance(url, list): - _open_dataset = xr.open_mfdataset - if self.pattern: - kwargs.update(preprocess=self._add_path_to_ds) - if self.combine is not None: - if 'combine' in kwargs: - raise Exception("Setting 'combine' argument twice in the catalog is invalid") - kwargs.update(combine=self.combine) - if self.concat_dim is not None: - if 'concat_dim' in kwargs: - raise Exception("Setting 'concat_dim' argument twice in the catalog is invalid") - kwargs.update(concat_dim=self.concat_dim) + data = readers.datatypes.NetCDF3(urlpath, storage_options=storage_options, + metadata=metadata) + if path_as_pattern and "{" in urlpath: + reader = readers.XArrayPatternReader(data, **(xarray_kwargs or {}), metadata=metadata, **kwargs) else: - _open_dataset = xr.open_dataset - - if self._can_be_local: - url = fsspec.open_local(self.urlpath, **self.storage_options) - else: - # https://github.com/intake/filesystem_spec/issues/476#issuecomment-732372918 - url = fsspec.open(self.urlpath, **self.storage_options).open() - - self._ds = _open_dataset(url, chunks=self.chunks, **kwargs) - - def _add_path_to_ds(self, ds): - """Adding path info to a coord for a particular file - """ - var = next(var for var in ds) - new_coords = reverse_format(self.pattern, ds[var].encoding['source']) - return ds.assign_coords(**new_coords) + reader = readers.XArrayDatasetReader(data, **(xarray_kwargs or {}), metadata=metadata, **kwargs) + self.reader = reader diff --git a/intake_xarray/opendap.py b/intake_xarray/opendap.py index 7eedaba..6d00a50 100644 --- a/intake_xarray/opendap.py +++ b/intake_xarray/opendap.py @@ -1,21 +1,11 @@ # -*- coding: utf-8 -*- -from .base import DataSourceMixin - import requests import os +from intake import readers +from intake_xarray.base import IntakeXarraySourceAdapter -def _create_generic_http_auth_session(username, password, check_url=None): - if username is None or password is None: - raise Exception("To use HTTP auth with the OPeNDAP driver you " - "need to set the DAP_USER and DAP_PASSWORD " - "environment variables") - session = requests.Session() - session.auth = (username, password) - return session - - -class OpenDapSource(DataSourceMixin): +class OpenDapSource(IntakeXarraySourceAdapter): """Open a OPeNDAP source. 
Parameters @@ -41,58 +31,9 @@ class OpenDapSource(DataSourceMixin): """ name = 'opendap' - def __init__(self, urlpath, chunks=None, auth=None, engine="pydap", xarray_kwargs=None, metadata=None, + def __init__(self, urlpath, chunks=None, engine="pydap", xarray_kwargs=None, metadata=None, **kwargs): - self.urlpath = urlpath - self.chunks = chunks - self.auth = auth - self.engine = engine - self._kwargs = xarray_kwargs or kwargs - self._ds = None - super(OpenDapSource, self).__init__(metadata=metadata) - - def _get_session(self): - if self.auth is None: - session = None - else: - if self.auth == "esgf": - from pydap.cas.esgf import setup_session - elif self.auth == "urs": - from pydap.cas.urs import setup_session - elif self.auth == "generic_http": - setup_session = _create_generic_http_auth_session - else: - raise ValueError( - "Authentication method should either be None, 'esgf', 'urs' or " - f"'generic_http', got '{self.auth}' instead." - ) - username = os.getenv('DAP_USER', None) - password = os.getenv('DAP_PASSWORD', None) - session = setup_session(username, password, check_url=self.urlpath) - - return session - - def _get_store(self): - import xarray as xr - session = self._get_session() - if self.engine == "netcdf4": - if session: - raise ValueError( - "Opendap session requires 'pydap' engine." - ) - return xr.backends.NetCDF4DataStore.open(self.urlpath) - elif self.engine == "pydap": - return xr.backends.PydapDataStore.open(self.urlpath, session=session) - else: - raise ValueError( - "xarray engine for opendap driver should either be 'netcdf4' or 'pydap'." - ) - - def _open_dataset(self): - import xarray as xr - - if isinstance(self.urlpath, list): - self._ds = xr.open_mfdataset(self.urlpath, chunks=self.chunks, engine=self.engine, **self._kwargs) - else: - store = self._get_store() - self._ds = xr.open_dataset(store, chunks=self.chunks, **self._kwargs) + data = readers.datatypes.OpenDAP(urlpath) + self.reader = readers.XArrayDatasetReader( + data, engine=engine, **(xarray_kwargs or {}), metadata=metadata, **kwargs + ) diff --git a/intake_xarray/raster.py b/intake_xarray/raster.py index 54f48a2..fdc1603 100644 --- a/intake_xarray/raster.py +++ b/intake_xarray/raster.py @@ -1,11 +1,9 @@ -import numpy as np -import fsspec -from intake.source.base import PatternMixin -from intake.source.utils import reverse_formats -from .base import DataSourceMixin, Schema +from intake import readers +from intake_xarray.base import IntakeXarraySourceAdapter -class RasterIOSource(DataSourceMixin, PatternMixin): + +class RasterIOSource(IntakeXarraySourceAdapter): """Open a xarray dataset via RasterIO. This creates an xarray.array, not a dataset (i.e., there is exactly one @@ -38,97 +36,14 @@ class RasterIOSource(DataSourceMixin, PatternMixin): fields. If str, is treated as pattern to match on. Default is True. 
""" name = 'rasterio' + container = "xarray" - def __init__(self, urlpath, chunks=None, concat_dim='concat_dim', + def __init__(self, urlpath, xarray_kwargs=None, metadata=None, path_as_pattern=True, storage_options=None, **kwargs): - self.path_as_pattern = path_as_pattern - self.urlpath = urlpath - self.chunks = chunks - self.dim = concat_dim - self.storage_options = storage_options or {} - self._kwargs = xarray_kwargs or {} - self._ds = None - if isinstance(self.urlpath, list): - self._can_be_local = fsspec.utils.can_be_local(self.urlpath[0]) - else: - self._can_be_local = fsspec.utils.can_be_local(self.urlpath) - super(RasterIOSource, self).__init__(metadata=metadata) - - def _open_files(self, files): - import xarray as xr - import rioxarray as rio - - das = [rio.open_rasterio(f, chunks=self.chunks, **self._kwargs) - for f in files] - out = xr.concat(das, dim=self.dim) - - coords = {} - if self.pattern: - coords = { - k: xr.concat( - [xr.DataArray( - np.full(das[i].sizes.get(self.dim, 1), v), - dims=self.dim - ) for i, v in enumerate(values)], dim=self.dim) - for k, values in reverse_formats(self.pattern, files).items() - } - - return out.assign_coords(**coords).chunk(self.chunks) - - def _open_dataset(self): - import xarray as xr - import rioxarray as rio - if self._can_be_local: - files = fsspec.open_local(self.urlpath, **self.storage_options) - else: - # pass URLs to delegate remote opening to rasterio library - files = self.urlpath - #files = fsspec.open(self.urlpath, **self.storage_options).open() - if isinstance(files, list): - self._ds = self._open_files(files) + data = readers.datatypes.TIFF(urlpath, storage_options=storage_options) + if path_as_pattern and "{" in urlpath: + reader = readers.XArrayPatternReader(data, **(xarray_kwargs or {}), metadata=metadata, **kwargs) else: - self._ds = rio.open_rasterio(files, chunks=self.chunks, - **self._kwargs) - - def _get_schema(self): - """Make schema object, which embeds xarray object and some details""" - from .xarray_container import serialize_zarr_ds - import msgpack - import xarray as xr - - self.urlpath, *_ = self._get_cache(self.urlpath) - - if self._ds is None: - self._open_dataset() - - ds2 = xr.Dataset({'raster': self._ds}) - metadata = { - 'dims': dict(ds2.sizes), - 'data_vars': {k: list(ds2[k].coords) - for k in ds2.data_vars.keys()}, - 'coords': tuple(ds2.coords.keys()), - 'array': 'raster' - } - if getattr(self, 'on_server', False): - metadata['internal'] = serialize_zarr_ds(ds2) - for k, v in self._ds.attrs.items(): - try: - msgpack.packb(v) - metadata[k] = v - except TypeError: - pass - - if hasattr(self._ds.data, 'npartitions'): - npart = self._ds.data.npartitions - else: - npart = None - - self._schema = Schema( - datashape=None, - dtype=str(self._ds.dtype), - shape=self._ds.shape, - npartitions=npart, - extra_metadata=metadata) - - return self._schema + reader = readers.XArrayDatasetReader(data, **(xarray_kwargs or {}), metadata=metadata, **kwargs) + self.reader = reader diff --git a/intake_xarray/tests/conftest.py b/intake_xarray/tests/conftest.py index 63b59b9..6ddc166 100644 --- a/intake_xarray/tests/conftest.py +++ b/intake_xarray/tests/conftest.py @@ -6,6 +6,7 @@ import tempfile import xarray as xr +import intake_xarray.base from intake_xarray.netcdf import NetCDFSource from intake_xarray.xzarr import ZarrSource diff --git a/intake_xarray/tests/data/catalog.yaml b/intake_xarray/tests/data/catalog.yaml index 66796b6..2e827e8 100644 --- a/intake_xarray/tests/data/catalog.yaml +++ 
b/intake_xarray/tests/data/catalog.yaml @@ -51,19 +51,11 @@ sources: urlpath: '{{ CATALOG_DIR }}/*.byte.tif' chunks: band: 1 - pattern_tiff_source_concat_on_band: - description: "https://github.com/mapbox/rasterio/blob/master/tests/data/.tif" - driver: rasterio - args: - urlpath: '{{ CATALOG_DIR }}/little_{color}.tif' - chunks: - band: 3 - concat_dim: band pattern_tiff_source_concat_on_new_dim: description: "https://github.com/mapbox/rasterio/blob/master/tests/data/.tif" driver: rasterio args: - urlpath: '{{ CATALOG_DIR }}/little_{color}.tif' + urlpath: '{{ CATALOG_DIR }}little_{color}.tif' chunks: band: 3 concat_dim: new_dim @@ -71,19 +63,10 @@ sources: description: "https://github.com/mapbox/rasterio/blob/master/tests/data/.tif" driver: rasterio args: - urlpath: '{{ CATALOG_DIR }}/little_{band}.tif' + urlpath: '{{ CATALOG_DIR }}little_{band}.tif' chunks: band: 3 concat_dim: band - pattern_tiff_source_path_not_as_pattern: - description: "https://github.com/mapbox/rasterio/blob/master/tests/data/.tif" - driver: rasterio - args: - urlpath: '{{ CATALOG_DIR }}/color_with_special{}.tif' - chunks: - band: 3 - concat_dim: band - path_as_pattern: False pattern_tiff_source_path_pattern_as_str: description: "https://github.com/mapbox/rasterio/blob/master/tests/data/.tif" driver: rasterio @@ -114,12 +97,3 @@ sources: chunks: {} auth: null engine: netcdf4 - xarray_source_sel: - description: select subsample of xarray_source entry - driver: intake_xarray.derived.XArrayTransform - args: - targets: - - xarray_source - transform: "intake_xarray.tests.test_derived._sel" - transform_kwargs: - indexers: "dict([('lat', 20)])" diff --git a/intake_xarray/tests/test_catalog.py b/intake_xarray/tests/test_catalog.py index f2ae4f4..e81e791 100644 --- a/intake_xarray/tests/test_catalog.py +++ b/intake_xarray/tests/test_catalog.py @@ -21,16 +21,6 @@ def test_catalog(catalog1, dataset): assert np.all(ds.rh == dataset.rh) -def test_persist(catalog1): - from intake_xarray import ZarrSource - source = catalog1['blank'] - s2 = source.persist() - assert source.has_been_persisted - assert isinstance(s2, ZarrSource) - assert s2.is_persisted - assert (source.read() == s2.read()).all() - - def test_import_error(mock_import_xarray, catalog1): s = catalog1['xarray_source']() # this is OK with pytest.raises(ImportError): diff --git a/intake_xarray/tests/test_derived.py b/intake_xarray/tests/test_derived.py deleted file mode 100644 index 16208a8..0000000 --- a/intake_xarray/tests/test_derived.py +++ /dev/null @@ -1,23 +0,0 @@ -import os -import pytest - -from intake import open_catalog -from xarray.tests import assert_allclose - - -# Function used in xarray_source_sel entry in catalog.yaml -def _sel(ds, indexers: str): - """indexers: dict (stored as str) which is passed to xarray.Dataset.sel""" - return ds.sel(eval(indexers)) - - -@pytest.fixture -def catalog(): - path = os.path.dirname(__file__) - return open_catalog(os.path.join(path, "data", "catalog.yaml")) - - -def test_catalog(catalog): - expected = catalog["xarray_source"].read().sel(lat=20) - actual = catalog["xarray_source_sel"].read() - assert_allclose(actual, expected) diff --git a/intake_xarray/tests/test_image.py b/intake_xarray/tests/test_image.py index b983eeb..3351db8 100644 --- a/intake_xarray/tests/test_image.py +++ b/intake_xarray/tests/test_image.py @@ -142,10 +142,10 @@ def test_read_image_and_exif(): urlpath = os.path.join(here, 'data', 'images', 'beach57.tif') source = ImageSource(urlpath=urlpath, exif_tags=True) ds = source.read() - assert 
ds['raster'].shape == (256, 252, 3) + assert ds['raster'].shape == (1, 256, 252, 3) assert ds['raster'].dtype == np.uint8 - assert ds['EXIF Image ImageWidth'].item().values == [252] - assert ds['EXIF Image ImageLength'].item().values == [256] + assert ds['EXIF Image ImageWidth'].values[0].values == [252] + assert ds['EXIF Image ImageLength'].values[0].values == [256] def test_read_image_and_given_exif_tag(): @@ -153,22 +153,13 @@ def test_read_image_and_given_exif_tag(): urlpath = os.path.join(here, 'data', 'images', 'beach57.tif') source = ImageSource(urlpath=urlpath, exif_tags=['Image ImageWidth']) ds = source.read() - assert ds['raster'].shape == (256, 252, 3) + assert ds['raster'].shape == (1, 256, 252, 3) assert ds['raster'].dtype == np.uint8 - assert ds['EXIF Image ImageWidth'].item().values == [252] + assert ds['EXIF Image ImageWidth'].values[0].values == [252] with pytest.raises(KeyError): ds['EXIF Image ImageLength'] -def test_read_images_as_glob_without_coerce_raises_error(): - pytest.importorskip('skimage') - urlpath = os.path.join(here, 'data', 'images', '*') - source = ImageSource(urlpath=urlpath) - with pytest.raises(ValueError, - match='could not broadcast input array'): - source.read() - - def test_read_images_as_glob_with_coerce(): pytest.importorskip('skimage') urlpath = os.path.join(here, 'data', 'images', '*') @@ -184,14 +175,3 @@ def test_read_images_and_exif_as_glob_with_coerce(): ds = source.read() assert ds['raster'].shape == (3, 256, 256, 3) assert ds['EXIF Image ImageWidth'].shape == (3,) - - -def test_read_images_and_persist(): - pytest.importorskip('skimage') - urlpath = os.path.join(here, 'data', 'images', '*') - source = ImageSource(urlpath=urlpath, coerce_shape=(256, 256)) - import tempfile - exported = tempfile.mkdtemp() - source.export(exported) - import xarray as xr - assert xr.open_dataset(exported, engine="zarr") diff --git a/intake_xarray/tests/test_intake_xarray.py b/intake_xarray/tests/test_intake_xarray.py index c360aab..43e84cc 100644 --- a/intake_xarray/tests/test_intake_xarray.py +++ b/intake_xarray/tests/test_intake_xarray.py @@ -5,23 +5,11 @@ import numpy as np import pytest +import xarray as xr import intake -here = os.path.dirname(__file__) - - -@pytest.mark.parametrize('source', ['netcdf', 'zarr']) -def test_discover(source, netcdf_source, zarr_source, dataset): - source = {'netcdf': netcdf_source, 'zarr': zarr_source}[source] - r = source.discover() - - assert r['dtype'] is None - assert r['metadata'] is not None - - assert source.metadata['dims'] == dict(dataset.dims) - assert set(source.metadata['data_vars']) == set(dataset.data_vars.keys()) - assert set(source.metadata['coords']) == set(dataset.coords.keys()) +here = os.path.dirname(__file__).rstrip('/') @pytest.mark.parametrize('source', ['netcdf', 'zarr']) @@ -37,16 +25,6 @@ def test_read(source, netcdf_source, zarr_source, dataset): assert np.all(ds.rh == dataset.rh) -def test_read_partition_netcdf(netcdf_source): - source = netcdf_source - with pytest.raises(TypeError): - source.read_partition(None) - out = source.read_partition(('temp', 0, 0, 0, 0)) - d = source.to_dask()['temp'].data - expected = d[:1, :4, :5, :10].compute() - assert np.all(out == expected) - - def test_read_list_of_netcdf_files_with_combine_nested(): from intake_xarray.netcdf import NetCDFSource source = NetCDFSource([ @@ -85,15 +63,6 @@ def test_read_glob_pattern_of_netcdf_files(): assert (d.num.data == np.array([1, 2])).all() -def test_read_partition_zarr(zarr_source): - source = zarr_source - with 
pytest.raises(TypeError): - source.read_partition(None) - out = source.read_partition(('temp', 0, 0, 0, 0)) - expected = source.to_dask()['temp'].values - assert np.all(out == expected) - - @pytest.mark.parametrize('source', ['netcdf', 'zarr']) def test_to_dask(source, netcdf_source, zarr_source, dataset): source = {'netcdf': netcdf_source, 'zarr': zarr_source}[source] @@ -121,12 +90,10 @@ def test_rasterio(): pytest.importorskip('rasterio') cat = intake.open_catalog(os.path.join(here, 'data', 'catalog.yaml')) s = cat.tiff_source - info = s.discover() - assert info['shape'] == (3, 718, 791) x = s.to_dask() - assert isinstance(x.data, da.Array) + assert isinstance(x.band_data.data, da.Array) x = s.read() - assert x.data.shape == (3, 718, 791) + assert x.band_data.shape == (3, 718, 791) def test_rasterio_glob(): @@ -134,12 +101,10 @@ def test_rasterio_glob(): pytest.importorskip('rasterio') cat = intake.open_catalog(os.path.join(here, 'data', 'catalog.yaml')) s = cat.tiff_glob_source - info = s.discover() - assert info['shape'] == (1, 3, 718, 791) x = s.to_dask() - assert isinstance(x.data, da.Array) + assert isinstance(x.band_data.data, da.Array) x = s.read() - assert x.data.shape == (1, 3, 718, 791) + assert x.band_data.shape == (3, 718, 791) def test_rasterio_empty_glob(): @@ -147,77 +112,30 @@ def test_rasterio_empty_glob(): cat = intake.open_catalog(os.path.join(here, 'data', 'catalog.yaml')) s = cat.empty_glob with pytest.raises(Exception): - s.discover() - - -def test_rasterio_cached_glob(): - import dask.array as da - pytest.importorskip('rasterio') - cat = intake.open_catalog(os.path.join(here, 'data', 'catalog.yaml')) - s = cat.cached_tiff_glob_source - cache = s.cache[0] - info = s.discover() - assert info['shape'] == (1, 3, 718, 791) - x = s.to_dask() - assert isinstance(x.data, da.Array) - x = s.read() - assert x.data.shape == (1, 3, 718, 791) - cache.clear_all() - - -def test_read_partition_tiff(): - pytest.importorskip('rasterio') - cat = intake.open_catalog(os.path.join(here, 'data', 'catalog.yaml')) - s = cat.tiff_source() - - with pytest.raises(TypeError): - s.read_partition(None) - out = s.read_partition((0, 0, 0)) - d = s.to_dask().data - expected = d[:1].compute() - assert np.all(out == expected) - - -def test_read_pattern_concat_on_existing_dim(): - pytest.importorskip('rasterio') - cat = intake.open_catalog(os.path.join(here, 'data', 'catalog.yaml')) - colors = cat.pattern_tiff_source_concat_on_band() - - da = colors.read() - assert da.shape == (6, 64, 64) - assert len(da.color) == 6 - assert set(da.color.data) == set(['red', 'green']) - - assert (da.band == [1, 2, 3, 1, 2, 3]).all() - assert da[da.color == 'red'].shape == (3, 64, 64) - - rgb = {'red': [204, 17, 17], 'green': [17, 204, 17]} - for color, values in rgb.items(): - for i, v in enumerate(values): - assert (da[da.color == color].sel(band=i+1).values == v).all() + s.read() def test_read_pattern_concat_on_new_dim(): pytest.importorskip('rasterio') cat = intake.open_catalog(os.path.join(here, 'data', 'catalog.yaml')) - colors = cat.pattern_tiff_source_concat_on_new_dim() + colors = cat.pattern_tiff_source_concat_on_new_dim - da = colors.read() + da = colors.read().band_data assert da.shape == (2, 3, 64, 64) - assert len(da.color) == 2 - assert set(da.color.data) == set(['red', 'green']) - assert da[da.color == 'red'].shape == (1, 3, 64, 64) + assert len(da.new_dim) == 2 + assert set(da.new_dim.data) == set(['red', 'green']) + assert da[da.new_dim == 'red'].shape == (1, 3, 64, 64) rgb = {'red': [204, 17, 17], 
'green': [17, 204, 17]} for color, values in rgb.items(): for i, v in enumerate(values): - assert (da[da.color == color][0].sel(band=i+1).values == v).all() + assert (da[da.new_dim == color][0].sel(band=i+1).values == v).all() def test_read_pattern_field_as_band(): pytest.importorskip('rasterio') cat = intake.open_catalog(os.path.join(here, 'data', 'catalog.yaml')) - colors = cat.pattern_tiff_source_path_pattern_field_as_band() + colors = cat.pattern_tiff_source_path_pattern_field_as_band da = colors.read() assert len(da.band) == 6 @@ -230,19 +148,10 @@ def test_read_pattern_field_as_band(): assert (da[da.band == color][i].values == v).all() -def test_read_pattern_path_not_as_pattern(): - pytest.importorskip('rasterio') - cat = intake.open_catalog(os.path.join(here, 'data', 'catalog.yaml')) - green = cat.pattern_tiff_source_path_not_as_pattern() - - da = green.read() - assert len(da.band) == 3 - - def test_read_pattern_path_as_pattern_as_str_with_list_of_urlpaths(): pytest.importorskip('rasterio') cat = intake.open_catalog(os.path.join(here, 'data', 'catalog.yaml')) - colors = cat.pattern_tiff_source_path_pattern_as_str() + colors = cat.pattern_tiff_source_path_pattern_as_str da = colors.read() assert da.shape == (2, 3, 64, 64) @@ -307,8 +216,6 @@ def test_read_opendap_no_auth(engine): pytest.importorskip("pydap") cat = intake.open_catalog(os.path.join(here, "data", "catalog.yaml")) source = cat["opendap_source_{}".format(engine)] - info = source.discover() - assert info["metadata"]["dims"] == {"TIME": 12} x = source.read() assert x.TIME.shape == (12,) @@ -342,8 +249,7 @@ def test_read_opendap_mfdataset_with_engine(): with patch('xarray.open_mfdataset') as open_mfdataset_mock: open_mfdataset_mock.return_value = 'dataset' source = OpenDapSource(urlpath=urls, chunks={}, auth=None, engine='fake-engine') - source._open_dataset() - retval = source._ds + retval = source.read() assert open_mfdataset_mock.called_with(urls, chunks={}, engine='fake-engine') assert retval == 'dataset' @@ -359,8 +265,8 @@ def test_read_opendap_with_auth_netcdf4(auth): with patch( f"pydap.cas.{auth}.setup_session", return_value=1 ) as mock_setup_session: - source = OpenDapSource(urlpath=urlpath, chunks={}, auth=auth, engine="netcdf4") - with pytest.raises(ValueError): + source = OpenDapSource(urlpath=urlpath, chunks={}, auth=auth, engine="pydap") + with pytest.raises(Exception): source.discover() @@ -369,7 +275,7 @@ def test_read_opendap_invalid_auth(): from intake_xarray.opendap import OpenDapSource source = OpenDapSource(urlpath="https://test.url", chunks={}, auth="abcd", engine="pydap") - with pytest.raises(ValueError): + with pytest.raises(Exception): source.discover() diff --git a/intake_xarray/tests/test_remote.py b/intake_xarray/tests/test_remote.py deleted file mode 100644 index bf17d69..0000000 --- a/intake_xarray/tests/test_remote.py +++ /dev/null @@ -1,306 +0,0 @@ -# Tests for intake-server, local HTTP file server, local "S3" object server -import aiohttp -import intake -import numpy as np -import os -import pytest -import requests -import subprocess -import time -import xarray as xr -import fsspec -import dask -import numpy -import s3fs - -here = os.path.abspath(os.path.dirname(__file__)) -cat_file = os.path.join(here, 'data', 'catalog.yaml') -DIRECTORY = os.path.join(here, 'data') - - -@pytest.fixture(scope='module') -def data_server(): - ''' Serves test/data folder to http://localhost:8000 ''' - pwd = os.getcwd() - os.chdir(DIRECTORY) - command = ['python', '-m', 'RangeHTTPServer'] - success = False - 
try: - P = subprocess.Popen(command) - timeout = 10 - while True: - try: - requests.get('http://localhost:8000') - break - except: - time.sleep(0.1) - timeout -= 0.1 - assert timeout > 0 - success = False - yield 'http://localhost:8000' - finally: - os.chdir(pwd) - P.terminate() - out = P.communicate() - if not success: - print(out) - - -def test_http_server_files(data_server): - test_files = ['RGB.byte.tif', 'example_1.nc', 'example_2.nc', 'little_green.tif', 'little_red.tif'] - h = fsspec.filesystem("http") - out = h.glob(data_server + '/*') - assert len(out) > 0 - assert set([data_server+'/'+x for x in test_files]).issubset(set(out)) - -# REMOTE GEOTIFF -def test_http_open_rasterio(data_server): - url = f'{data_server}/RGB.byte.tif' - source = intake.open_rasterio(url) - da = source.to_dask() - assert isinstance(da, xr.core.dataarray.DataArray) - - -def test_http_read_rasterio(data_server): - url = f'{data_server}/RGB.byte.tif' - source = intake.open_rasterio(url) - da = source.read() - # Following line: original file CRS appears to be updated - assert ("+init" in da.attrs.get('crs', "") or "+proj" in da.attrs.get('crs', "") or - "PROJCS" in da.spatial_ref.attrs["crs_wkt"]) - assert da.attrs['AREA_OR_POINT'] == 'Area' - assert da.dtype == np.uint8 - assert da.isel(band=2,x=300,y=500).values == 129 - - -def test_http_open_rasterio_dask(data_server): - url = f'{data_server}/RGB.byte.tif' - source = intake.open_rasterio(url, chunks={}) - da = source.to_dask() - assert isinstance(da, xr.core.dataarray.DataArray) - assert isinstance(da.data, dask.array.core.Array) - - -def test_http_open_rasterio_auth(data_server): - url = f'{data_server}/RGB.byte.tif' - auth = dict(client_kwargs={'auth': aiohttp.BasicAuth('USER', 'PASS')}) - # NOTE: if url startswith 'https' use 'https' instead of 'http' for storage_options - source = intake.open_rasterio(url, - storage_options=dict(http=auth)) - source_auth = source.storage_options['http'].get('client_kwargs').get('auth') - assert isinstance(source_auth, aiohttp.BasicAuth) - - -def test_http_read_rasterio_simplecache(data_server): - url = f'simplecache::{data_server}/RGB.byte.tif' - source = intake.open_rasterio(url, chunks={}) - da = source.to_dask() - assert isinstance(da, xr.core.dataarray.DataArray) - - -def test_http_read_rasterio_pattern(data_server): - url = [data_server+'/'+x for x in ('little_red.tif', 'little_green.tif')] - source = intake.open_rasterio(url, - path_as_pattern='{}/little_{color}.tif', - concat_dim='color', - chunks={}) - da = source.to_dask() - assert isinstance(da, xr.core.dataarray.DataArray) - assert set(da.color.data) == set(['red', 'green']) - assert da.shape == (2, 3, 64, 64) - - -# REMOTE NETCDF / HDF -def test_http_open_netcdf(data_server): - url = f'{data_server}/example_1.nc' - source = intake.open_netcdf(url) - ds = source.to_dask() - assert isinstance(ds, xr.core.dataset.Dataset) - assert isinstance(ds.temp.data, numpy.ndarray) - - -def test_http_read_netcdf(data_server): - url = f'{data_server}/example_1.nc' - source = intake.open_netcdf(url) - ds = source.read() - assert ds['rh'].isel(lat=0,lon=0,time=0).values.dtype == np.float32 - assert ds['rh'].isel(lat=0,lon=0,time=0).values == 0.5 - - -def test_http_read_netcdf_dask(data_server): - url = f'{data_server}/next_example_1.nc' - source = intake.open_netcdf(url, chunks={}, - xarray_kwargs=dict(engine='h5netcdf')) - ds = source.to_dask() - # assert isinstance(ds._file_obj, xr.backends.h5netcdf_.H5NetCDFStore) - assert isinstance(ds, xr.core.dataset.Dataset) - assert 
isinstance(ds.temp.data, dask.array.core.Array) - - -def test_http_read_netcdf_simplecache(data_server): - url = f'simplecache::{data_server}/example_1.nc' - source = intake.open_netcdf( - url, chunks={}, - xarray_kwargs={"engine": "netcdf4"} - ) - ds = source.to_dask() - assert isinstance(ds, xr.core.dataset.Dataset) - assert isinstance(ds.temp.data, dask.array.core.Array) - - -# S3 -#based on: https://github.com/dask/s3fs/blob/master/s3fs/tests/test_s3fs.py -test_bucket_name = "test" -PORT_S3 = 8001 -endpoint_uri = "http://localhost:%s" % PORT_S3 -test_files = ['RGB.byte.tif', 'example_1.nc'] - -@pytest.fixture() -def s3_base(): - # writable local S3 system - import shlex - import subprocess - - proc = subprocess.Popen(shlex.split("moto_server s3 -p %s" % PORT_S3), - stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL, stdin=subprocess.DEVNULL) - - timeout = 5 - while timeout > 0: - try: - print("polling for moto server") - - r = requests.get(endpoint_uri) - if r.ok: - break - except: - pass - timeout -= 0.1 - time.sleep(0.1) - print("server up") - yield - print("moto done") - proc.terminate() - proc.wait() - - -@pytest.fixture(scope='function') -def aws_credentials(): - """Mocked AWS Credentials for moto.""" - os.environ['AWS_ACCESS_KEY_ID'] = 'testing' - os.environ['AWS_SECRET_ACCESS_KEY'] = 'testing' - os.environ['AWS_SECURITY_TOKEN'] = 'testing' - os.environ['AWS_SESSION_TOKEN'] = 'testing' - - -@pytest.fixture() -def s3(s3_base, aws_credentials): - ''' anonymous access local s3 bucket for testing ''' - from botocore.session import Session - session = Session() - client = session.create_client("s3", endpoint_url=endpoint_uri) - client.create_bucket(Bucket=test_bucket_name, ACL="public-read") - - for file_name in [os.path.join(DIRECTORY,x) for x in test_files]: - with open(file_name, 'rb') as f: - data = f.read() - key = os.path.basename(file_name) - client.put_object(Bucket=test_bucket_name, Key=key, Body=data) - - # Make sure cache not being used - s3fs.S3FileSystem.clear_instance_cache() - s3 = s3fs.S3FileSystem(anon=True, client_kwargs={"endpoint_url": endpoint_uri}) - s3.invalidate_cache() - yield - - -def test_s3_list_files(s3): - s3 = s3fs.S3FileSystem(anon=True, client_kwargs={"endpoint_url": endpoint_uri}) - files = s3.ls(test_bucket_name) - assert len(files) > 0 - assert set([test_bucket_name+'/'+x for x in test_files]).issubset(set(files)) - - -def test_s3_read_rasterio(s3): - # Lots of GDAL Environment variables needed for this to work ! 
- # https://gdal.org/user/virtual_file_systems.html#vsis3-aws-s3-files - os.environ['AWS_NO_SIGN_REQUEST']='YES' - os.environ['AWS_S3_ENDPOINT'] = endpoint_uri.lstrip('http://') - os.environ['AWS_VIRTUAL_HOSTING']= 'FALSE' - os.environ['AWS_HTTPS']= 'NO' - os.environ['GDAL_DISABLE_READDIR_ON_OPEN']='EMPTY_DIR' - os.environ['CPL_CURL_VERBOSE']='YES' - url = f's3://{test_bucket_name}/RGB.byte.tif' - source = intake.open_rasterio(url) - da = source.read() - # Following line: original file CRS appears to be updated - assert ("+init" in da.attrs.get('crs', "") or "+proj" in da.attrs.get('crs', "") or - "PROJCS" in da.spatial_ref.attrs["crs_wkt"]) - assert da.attrs['AREA_OR_POINT'] == 'Area' - assert da.dtype == np.uint8 - assert da.isel(band=2,x=300,y=500).values == 129 - - -def test_s3_read_netcdf(s3): - url = f's3://{test_bucket_name}/example_1.nc' - s3options = dict(client_kwargs={"endpoint_url": endpoint_uri}) - source = intake.open_netcdf(url, - storage_options=s3options) - ds = source.read() - assert ds['rh'].isel(lat=0,lon=0,time=0).values.dtype == np.float32 - assert ds['rh'].isel(lat=0,lon=0,time=0).values == 0.5 - - -# Remote catalogs with intake-server -@pytest.fixture(scope='module') -def intake_server(): - PORT = 8002 - command = ['intake-server', '-p', str(PORT), cat_file] - try: - P = subprocess.Popen(command) - timeout = 10 - while True: - try: - requests.get('http://localhost:{}'.format(PORT)) - break - except: - time.sleep(0.1) - timeout -= 0.1 - assert timeout > 0 - yield 'intake://localhost:{}'.format(PORT) - finally: - P.terminate() - P.communicate() - - -def test_intake_server_netcdf(intake_server): - cat_local = intake.open_catalog(cat_file) - cat = intake.open_catalog(intake_server) - assert 'xarray_source' in cat - source = cat.xarray_source() - assert isinstance(source._ds, xr.Dataset) - assert source._schema is None - source._get_schema() - assert source._schema is not None - repr(source.to_dask()) - assert (source.to_dask().rh.data.compute() == - cat_local.xarray_source.to_dask().rh.data.compute()).all() - assert (source.read() == - cat_local.xarray_source.read()).all() - - -def test_intake_server_tiff(intake_server): - pytest.importorskip('rasterio') - cat_local = intake.open_catalog(cat_file) - cat = intake.open_catalog(intake_server) - assert 'tiff_source' in cat - source = cat.tiff_source() - assert isinstance(source._ds, xr.Dataset) - assert source._schema is None - source._get_schema() - assert source._schema is not None - repr(source.to_dask()) - remote = source.to_dask().data.compute() - local = cat_local.tiff_source.to_dask().data.compute() - assert (remote == local).all() - assert (source.read() == - cat_local.xarray_source.read()).all() diff --git a/intake_xarray/xarray_container.py b/intake_xarray/xarray_container.py deleted file mode 100644 index 8148973..0000000 --- a/intake_xarray/xarray_container.py +++ /dev/null @@ -1,167 +0,0 @@ -import itertools -import os -import xarray as xr -from dask.delayed import Delayed -from intake.container.base import RemoteSource, get_partition -from intake.source.base import Schema - - -def serialize_zarr_ds(ds): - """Gather group/metadata information from a Zarr into dictionary repr - - A version of the dataset can be recreated, but will not be able to directly - load data without further manipulation. 
- - Use as follows - >>> out = serialize_zarr(s._ds) - - and reconstitute with - >>> d2 = xr.open_zarr(out, decode_times=False) - - (decode_times is required here because the times will be random binary - data and not be decodable) - - Parameters - ---------- - ds: xarray dataset - - Returns - ------- - dictionary with .z* keys for the various elements of the original dataset. - """ - s = {} - try: - attrs = ds.attrs.copy() - ds.attrs.pop('_ARRAY_DIMENSIONS', None) # zarr implementation detail - ds.to_zarr(store=s, chunk_store={}, compute=False, consolidated=False) - finally: - ds.attrs = attrs - return s - - -class RemoteXarray(RemoteSource): - """ - An xarray data source on the server - """ - name = 'remote-xarray' - container = 'xarray' - - def __init__(self, url, headers, **kwargs): - """ - Initialise local xarray, whose dask arrays contain tasks that pull data - - The matadata contains a key "internal", which is a result of running - ``serialize_zarr_ds`` on the xarray on the server. It is a dict - containing the metadata parts of the original dataset (i.e., the - keys with names like ".z*"). This can be opened by xarray as-is, and - will make a local xarray object. In ``._get_schema()``, the numpy - parts (coordinates) are fetched and the dask-array parts (cariables) - have their dask graphs redefined to tasks that fetch data from the - server. - """ - import xarray as xr - super(RemoteXarray, self).__init__(url, headers, **kwargs) - self._schema = None - self._ds = xr.open_zarr(self.metadata['internal'], - consolidated=False) - - def _get_schema(self): - """Reconstruct xarray arrays - - The schema returned is not very informative as a representation, - this method fetches coordinates data and creates dask arrays. - """ - import dask.array as da - if self._schema is None: - metadata = { - 'dims': dict(self._ds.dims), - 'data_vars': {k: list(self._ds[k].coords) - for k in self._ds.data_vars.keys()}, - 'coords': tuple(self._ds.coords.keys()), - } - if getattr(self, 'on_server', False): - metadata['internal'] = serialize_zarr_ds(self._ds) - metadata.update(self._ds.attrs) - self._schema = Schema( - datashape=None, - dtype=None, - shape=None, - npartitions=None, - extra_metadata=metadata) - # aparently can't replace coords in-place - # we immediately fetch the values of coordinates - # TODO: in the future, these could be functions from the metadata? - self._ds = self._ds.assign_coords(**{c: self._get_partition((c, )) - for c in metadata['coords']}) - for var in list(self._ds.data_vars): - # recreate dask arrays - name = '-'.join(['remote-xarray', var, self._source_id]) - arr = self._ds[var].data - if hasattr(arr, "chunks"): - chunks = arr.chunks - nparts = (range(len(n)) for n in chunks) - else: - nparts = ((1,), ) - if self.metadata.get('array', False): - # original was an array, not dataset - no variable name - extra = () - else: - extra = (var, ) - dask = { - (name, ) + part: (get_partition, self.url, self.headers, - self._source_id, self.container, - extra + part) - - for part in itertools.product(*nparts) - } - self._ds[var].data = da.Array( - dask, - name, - chunks, - dtype=arr.dtype, - shape=arr.shape) - if self.metadata.get('array', False): - self._ds = self._ds[self.metadata.get('array')] - return self._schema - - def _get_partition(self, i): - """ - The partition should look like ("var_name", int, int...), where the - number of ints matches the number of coordinate axes in the named - variable, and is between 0 and the number of chunks in each axis. 
-        For an array, as opposed to a dataset, omit the variable name.
-        """
-        return get_partition(self.url, self.headers, self._source_id,
-                             self.container, i)
-
-    def to_dask(self):
-        self._get_schema()
-        return self._ds
-
-    def read_chunked(self):
-        """The dask repr is the authoritative chunked version"""
-        self._get_schema()
-        return self._ds
-
-    def read(self):
-        self._get_schema()
-        return self._ds.load()
-
-    def close(self):
-        self._ds = None
-        self._schema = None
-
-    @staticmethod
-    def _persist(source, path, **kwargs):
-        """Save data to a local zarr
-
-        Uses
-        http://xarray.pydata.org/en/stable/generated/xarray.Dataset.to_zarr.html
-        """
-        from intake_xarray import ZarrSource
-        ds = source.to_dask()
-        if isinstance(ds, xr.DataArray):
-            ds = ds.to_dataset(name=ds.name if ds.name else "variable")
-        ds.to_zarr(path, **kwargs)
-        return ZarrSource(path)
-
diff --git a/intake_xarray/xzarr.py b/intake_xarray/xzarr.py
index e6f9d95..4f001b1 100644
--- a/intake_xarray/xzarr.py
+++ b/intake_xarray/xzarr.py
@@ -1,7 +1,9 @@
-from .base import DataSourceMixin
+from intake import readers
+from intake_xarray.base import IntakeXarraySourceAdapter
 
-class ZarrSource(DataSourceMixin):
+
+
+class ZarrSource(IntakeXarraySourceAdapter):
     """Open a xarray dataset.
 
     If the path is passed as a list or a string containing "*", then multifile open
@@ -24,28 +26,6 @@ class ZarrSource(DataSourceMixin):
     name = 'zarr'
 
     def __init__(self, urlpath, storage_options=None, metadata=None, **kwargs):
-        super(ZarrSource, self).__init__(metadata=metadata)
-        self.urlpath = urlpath
-        self.storage_options = storage_options or {}
-        self.kwargs = kwargs
-        self._ds = None
-
-    def _open_dataset(self):
-        import xarray as xr
-        kw = self.kwargs.copy()
-        if "consolidated" not in kw:
-            kw['consolidated'] = False
-        if "chunks" not in kw:
-            kw["chunks"] = {}
-        kw["engine"] = "zarr"
-        if self.storage_options and "storage_options" not in kw.get("backend_kwargs", {}):
-            kw.setdefault("backend_kwargs", {})["storage_options"] = self.storage_options
-        if isinstance(self.urlpath, list) or "*" in self.urlpath:
-            self._ds = xr.open_mfdataset(self.urlpath, **kw)
-        else:
-            self._ds = xr.open_dataset(self.urlpath, **kw)
-
-    def close(self):
-        super(ZarrSource, self).close()
-        self._fs = None
-        self._mapper = None
+        data = readers.datatypes.Zarr(urlpath, storage_options=storage_options,
+                                      metadata=metadata)
+        self.reader = readers.XArrayDatasetReader(data, **kwargs)
diff --git a/setup.py b/setup.py
index 3b8d9cb..94958a6 100644
--- a/setup.py
+++ b/setup.py
@@ -9,7 +9,7 @@
 from setuptools import setup, find_packages
 import versioneer
 
-INSTALL_REQUIRES = ['intake >=0.6.6', 'xarray >=02022', 'zarr', 'dask >=2.2', 'netcdf4', 'fsspec>2022',
+INSTALL_REQUIRES = ['intake >=2', 'xarray >=02022', 'zarr', 'dask >=2.2', 'netcdf4', 'fsspec>2022',
                    'msgpack', 'requests']
 
 setup(
@@ -30,7 +30,6 @@
             'opendap = intake_xarray.opendap:OpenDapSource',
             'xarray_image = intake_xarray.image:ImageSource',
             'rasterio = intake_xarray.raster:RasterIOSource',
-            'remote-xarray = intake_xarray.xarray_container:RemoteXarray',
         ]
     },
     package_data={'': ['*.csv', '*.yml', '*.html']},
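For reference, a minimal usage sketch of the reader-backed sources introduced above. It assumes intake >= 2 (providing intake.readers) and a netCDF-capable xarray install; the file path is illustrative only.

    # Minimal sketch of the adapter-based API after this change (path is illustrative).
    from intake_xarray.netcdf import NetCDFSource

    src = NetCDFSource("example_1.nc")   # wraps an intake.readers XArrayDatasetReader
    ds = src.read()                      # eager: delegates to reader.read() -> in-memory xarray.Dataset
    lazy = src.to_dask()                 # lazy: reader(chunks={}).read() -> dask-backed Dataset

As defined in IntakeXarraySourceAdapter, discover() is an alias for read() and read_chunked() for to_dask(), so catalog entries using the existing driver names keep working through the same adapter.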