From fdbfc84d8ab466049092fdfca164b694cb8f29ac Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Thu, 2 Dec 2021 11:48:31 +0100 Subject: [PATCH 001/107] Implement base SQLAlchemy raster driver --- terracotta/drivers/common.py | 344 +++++++++++++++++++++++++++++++++++ 1 file changed, 344 insertions(+) create mode 100644 terracotta/drivers/common.py diff --git a/terracotta/drivers/common.py b/terracotta/drivers/common.py new file mode 100644 index 00000000..c5750706 --- /dev/null +++ b/terracotta/drivers/common.py @@ -0,0 +1,344 @@ +from abc import ABC, abstractmethod +from collections import OrderedDict +import contextlib +import functools +import json +import re +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union +import urllib.parse as urlparse +import numpy as np + +import sqlalchemy as sqla +from sqlalchemy.engine.base import Connection +from terracotta import exceptions +import terracotta +from terracotta.drivers.base import requires_connection +from terracotta.drivers.raster_base import RasterDriver + + +class RelationalDriver(RasterDriver, ABC): + + SQL_DRIVER_TYPE: str # The actual DB driver, eg pymysql, psycopg2, etc + + SQLA_REAL = functools.partial(sqla.types.Float, precision=8) + SQLA_TEXT = sqla.types.Text + SQLA_BLOB = sqla.types.LargeBinary + _METADATA_COLUMNS: Tuple[Tuple[str, sqla.types.TypeEngine], ...] 
= ( + ('bounds_north', SQLA_REAL()), + ('bounds_east', SQLA_REAL()), + ('bounds_south', SQLA_REAL()), + ('bounds_west', SQLA_REAL()), + ('convex_hull', SQLA_TEXT()), + ('valid_percentage', SQLA_REAL()), + ('min', SQLA_REAL()), + ('max', SQLA_REAL()), + ('mean', SQLA_REAL()), + ('stdev', SQLA_REAL()), + ('percentiles', SQLA_BLOB()), + ('metadata', SQLA_TEXT()) + ) + + def __init__(self, path: str) -> None: + settings = terracotta.get_settings() + db_connection_timeout: int = settings.DB_CONNECTION_TIMEOUT + + assert self.SQL_DRIVER_TYPE is not None + cp = urlparse.urlparse(path) + connection_string = f'{cp.scheme}+{self.SQL_DRIVER_TYPE}://{cp.netloc}{cp.path}' + + self.sqla_engine = sqla.create_engine( + connection_string, + echo=True, + future=True, + connect_args={'timeout': db_connection_timeout} + ) + self.sqla_metadata = sqla.MetaData() + + self._db_keys: Optional[OrderedDict] = None + + self.connection: Connection + self.connected: bool = False + self.db_version_verified: bool = False + + # use normalized path to make sure username and password don't leak into __repr__ + qualified_path = self._normalize_path(path) + super().__init__(qualified_path) + + @contextlib.contextmanager + def connect(self) -> contextlib.AbstractContextManager: + if not self.connected: + with self.sqla_engine.connect() as connection: + self.connection = connection + self.connected = True + self._verify_db_version() + yield + self.connected = False + self.connection = None + else: + yield + + def _verify_db_version(self) -> None: + if not self.db_version_verified: + # check for version compatibility + def version_tuple(version_string: str) -> Sequence[str]: + return version_string.split('.') + + db_version = self.db_version + current_version = terracotta.__version__ + + if version_tuple(db_version)[:2] != version_tuple(current_version)[:2]: + raise exceptions.InvalidDatabaseError( + f'Version conflict: database was created in v{db_version}, ' + f'but this is v{current_version}' + ) + 
self.db_version_verified = True + + @property + @requires_connection + def db_version(self) -> str: + """Terracotta version used to create the database""" + terracotta_table = sqla.Table('terracotta', self.sqla_metadata, autoload_with=self.sqla_engine) + stmt = sqla.select(terracotta_table.c.version) + version = self.connection.execute(stmt).scalar() + return version + + def create(self, keys: Sequence[str], key_descriptions: Mapping[str, str] = None) -> None: + """Create and initialize database with empty tables. + + This must be called before opening the first connection. The MySQL database must not + exist already. + + Arguments: + + keys: Key names to use throughout the Terracotta database. + key_descriptions: Optional (but recommended) full-text description for some keys, + in the form of ``{key_name: description}``. + + """ + self._create_database() + self._initialize_database() + + @abstractmethod + def _create_database(self, database_name: str) -> None: + # This might be made abstract, for each subclass to implement specifically + # Note that some subclasses may not actually create any database here, as + # it may already exist for some vendors + pass + + @requires_connection + def _initialize_database(self, keys: Sequence[str], key_descriptions: Mapping[str, str] = None) -> None: + if key_descriptions is None: + key_descriptions = {} + else: + key_descriptions = dict(key_descriptions) + + if not all(k in keys for k in key_descriptions.keys()): + raise exceptions.InvalidKeyError('key description dict contains unknown keys') + + if not all(re.match(r'^\w+$', key) for key in keys): + raise exceptions.InvalidKeyError('key names must be alphanumeric') + + if any(key in self._RESERVED_KEYS for key in keys): + raise exceptions.InvalidKeyError(f'key names cannot be one of {self._RESERVED_KEYS!s}') + + terracotta_table = sqla.Table( + 'terracotta', self.sqla_metadata, + sqla.Column('version', sqla.types.String(255), primary_key=True) + ) + key_names_table = 
sqla.Table( + 'key_names', self.sqla_metadata, + sqla.Column('key_name', sqla.types.String, primary_key=True), + sqla.Column('description', sqla.types.String(8000)) + ) + datasets_table = sqla.Table( + 'datasets', self.sqla_metadata, + *[sqla.Column(key, sqla.types.String, primary_key=True) for key in keys], + sqla.Column('filepath', sqla.types.String(8000)) + ) + metadata_table = sqla.Table( + 'metadata', self.sqla_metadata, + *[sqla.Column(key, sqla.types.String, primary_key=True) for key in keys], + *self._METADATA_COLUMNS + ) + self.sqla_metadata.create_all(self.sqla_engine) + + self.connection.execute( + terracotta_table.insert().values(version=terracotta.__version__) + ) + self.connection.execute( + key_names_table.insert(), + [dict(key_name=key, description=key_descriptions.get(key, '')) for key in keys] + ) + self.connection.commit() + + # invalidate key cache # TODO: Is that actually necessary? + self._db_keys = None + + @requires_connection + def get_keys(self) -> OrderedDict: + keys_table = sqla.Table('key_names', self.sqla_metadata, autoload_with=self.sqla_engine) + result = self.connection.execute(keys_table.select()) + return OrderedDict(result.all()) + + @property + def key_names(self) -> Tuple[str]: + """Names of all keys defined by the database""" + if self._db_keys is None: + self._db_keys = self.get_keys() + return tuple(self._db_keys.keys()) + + @requires_connection + def get_datasets(self, where: Mapping[str, Union[str, List[str]]] = None, page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], str]: + # Ensure standardized structure of where items + if where is None: + where = {} + for key, value in where.items(): + if not isinstance(value, list): + where[key] = [value] + + datasets_table = sqla.Table('datasets', self.sqla_metadata, autoload_with=self.sqla_engine) + stmt = datasets_table \ + .select() \ + .where(* + [ + sqla.or_(*[datasets_table.c.get(column) == value for value in values]) + for column, values in where.items() + ] + ) \ + 
.order_by(*datasets_table.c.values()) \ + .limit(limit) \ + .offset(page * limit if limit is not None else None) + + result = self.connection.execute(stmt) + + def keytuple(row: Dict[str, Any]) -> Tuple[str, ...]: + return tuple(row[key] for key in self.key_names) + + datasets = {keytuple(row): row['filepath'] for row in result} + return datasets + + @requires_connection + def get_metadata(self, keys: Union[Sequence[str], Mapping[str, str]]) -> Dict[str, Any]: + keys = tuple(self._key_dict_to_sequence(keys)) + if len(keys) != len(self.key_names): + raise exceptions.InvalidKeyError( + f'Got wrong number of keys (available keys: {self.key_names})' + ) + + metadata_table = sqla.Table('metadata', self.sqla_metadata, autoload_with=self.sqla_engine) + stmt = metadata_table \ + .select() \ + .where(* + [ + metadata_table.c.get(key) == value + for key, value in zip(self.key_names, keys) + ] + ) + + row = self.connection.execute(stmt).first() + + if not row: # support lazy loading + filepath = self.get_datasets(dict(zip(self.key_names, keys))) + if not filepath: + raise exceptions.DatasetNotFoundError(f'No dataset found for given keys {keys}') + assert len(filepath) == 1 + + # compute metadata and try again + self.insert(keys, filepath[keys], skip_metadata=False) + row = self.connection.execute(stmt).first() + + assert row + + data_columns, _ = zip(*self._METADATA_COLUMNS) + encoded_data = {col: row[col] for col in self.key_names + data_columns} + return self._decode_data(encoded_data) + + @requires_connection + def insert( + self, + keys: Union[Sequence[str], Mapping[str, str]], + filepath: str, *, + metadata: Mapping[str, Any] = None, + skip_metadata: bool = False, + override_path: str = None + ) -> None: + if len(keys) != len(self.key_names): + raise exceptions.InvalidKeyError( + f'Got wrong number of keys (available keys: {self.key_names})' + ) + + if override_path is None: + override_path = filepath + + keys = self._key_dict_to_sequence(keys) + key_dict = 
dict(zip(self.key_names, keys)) + + datasets_table = sqla.Table('datasets', self.sqla_metadata, autoload_with=self.sqla_engine) + metadata_table = sqla.Table('metadata', self.sqla_metadata, autoload_with=self.sqla_engine) + + self.connection.execute(datasets_table.delete().where(*[datasets_table.c.get(column) == value for column, value in key_dict.items()])) + self.connection.execute(datasets_table.insert().values(**key_dict, filepath=override_path)) + + if metadata is None and not skip_metadata: + metadata = self.compute_metadata(filepath) + + if metadata is not None: + encoded_data = self._encode_data(metadata) + self.connection.execute(metadata_table.delete().where(*[metadata_table.c.get(column) == value for column, value in key_dict.items()])) + self.connection.execute(metadata_table.insert().values(**key_dict, **encoded_data)) + + self.connection.commit() + + @requires_connection + def delete(self, keys: Union[Sequence[str], Mapping[str, str]], silent=False) -> None: + if len(keys) != len(self.key_names): + raise exceptions.InvalidKeyError( + f'Got wrong number of keys (available keys: {self.key_names})' + ) + + keys = self._key_dict_to_sequence(keys) + key_dict = dict(zip(self.key_names, keys)) + + if not self.get_datasets(key_dict): + raise exceptions.DatasetNotFoundError(f'No dataset found with keys {keys}') + + datasets_table = sqla.Table('datasets', self.sqla_metadata, autoload_with=self.sqla_engine) + metadata_table = sqla.Table('metadata', self.sqla_metadata, autoload_with=self.sqla_engine) + + self.connection.execute(datasets_table.delete().where(*[datasets_table.c.get(column) == value for column, value in key_dict.items()])) + self.connection.execute(metadata_table.delete().where(*[metadata_table.c.get(column) == value for column, value in key_dict.items()])) + self.connection.commit() + + @staticmethod + def _encode_data(decoded: Mapping[str, Any]) -> Dict[str, Any]: + """Transform from internal format to database representation""" + encoded = { + 
'bounds_north': decoded['bounds'][0], + 'bounds_east': decoded['bounds'][1], + 'bounds_south': decoded['bounds'][2], + 'bounds_west': decoded['bounds'][3], + 'convex_hull': json.dumps(decoded['convex_hull']), + 'valid_percentage': decoded['valid_percentage'], + 'min': decoded['range'][0], + 'max': decoded['range'][1], + 'mean': decoded['mean'], + 'stdev': decoded['stdev'], + 'percentiles': np.array(decoded['percentiles'], dtype='float32').tobytes(), + 'metadata': json.dumps(decoded['metadata']) + } + return encoded + + @staticmethod + def _decode_data(encoded: Mapping[str, Any]) -> Dict[str, Any]: + """Transform from database format to internal representation""" + decoded = { + 'bounds': tuple([encoded[f'bounds_{d}'] for d in ('north', 'east', 'south', 'west')]), + 'convex_hull': json.loads(encoded['convex_hull']), + 'valid_percentage': encoded['valid_percentage'], + 'range': (encoded['min'], encoded['max']), + 'mean': encoded['mean'], + 'stdev': encoded['stdev'], + 'percentiles': np.frombuffer(encoded['percentiles'], dtype='float32').tolist(), + 'metadata': json.loads(encoded['metadata']) + } + return decoded From b6602fa637b8e2009a32f05419ae525410a7a79f Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Thu, 2 Dec 2021 11:53:31 +0100 Subject: [PATCH 002/107] Add String size on table creation --- terracotta/drivers/common.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/terracotta/drivers/common.py b/terracotta/drivers/common.py index c5750706..45307a3a 100644 --- a/terracotta/drivers/common.py +++ b/terracotta/drivers/common.py @@ -19,6 +19,7 @@ class RelationalDriver(RasterDriver, ABC): SQL_DRIVER_TYPE: str # The actual DB driver, eg pymysql, psycopg2, etc + SQL_KEY_SIZE: int SQLA_REAL = functools.partial(sqla.types.Float, precision=8) SQLA_TEXT = sqla.types.Text @@ -147,17 +148,17 @@ def _initialize_database(self, keys: Sequence[str], key_descriptions: Mapping[st ) key_names_table = sqla.Table( 'key_names', self.sqla_metadata, - 
sqla.Column('key_name', sqla.types.String, primary_key=True), + sqla.Column('key_name', sqla.types.String(self.SQL_KEY_SIZE), primary_key=True), sqla.Column('description', sqla.types.String(8000)) ) datasets_table = sqla.Table( 'datasets', self.sqla_metadata, - *[sqla.Column(key, sqla.types.String, primary_key=True) for key in keys], + *[sqla.Column(key, sqla.types.String(self.SQL_KEY_SIZE), primary_key=True) for key in keys], sqla.Column('filepath', sqla.types.String(8000)) ) metadata_table = sqla.Table( 'metadata', self.sqla_metadata, - *[sqla.Column(key, sqla.types.String, primary_key=True) for key in keys], + *[sqla.Column(key, sqla.types.String(self.SQL_KEY_SIZE), primary_key=True) for key in keys], *self._METADATA_COLUMNS ) self.sqla_metadata.create_all(self.sqla_engine) From 4465d0ee5d108441328d729817278e610413f6a5 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Thu, 2 Dec 2021 14:27:39 +0100 Subject: [PATCH 003/107] Use readonly properties for db_version and key_names --- terracotta/drivers/base.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/terracotta/drivers/base.py b/terracotta/drivers/base.py index 37cd738b..58b67084 100644 --- a/terracotta/drivers/base.py +++ b/terracotta/drivers/base.py @@ -28,8 +28,15 @@ class Driver(ABC): """ _RESERVED_KEYS = ('limit', 'page') - db_version: str #: Terracotta version used to create the database - key_names: Tuple[str] #: Names of all keys defined by the database + @property + @abstractmethod + def db_version(self) -> str: + ... # Terracotta version used to create the database + + @property + @abstractmethod + def key_names(self) -> Tuple[str, ...]: + ... 
# Names of all keys defined by the database @abstractmethod def __init__(self, url_or_path: str) -> None: From abf5e85134c51e387d38b2b009587e6d304df01c Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Thu, 2 Dec 2021 14:28:32 +0100 Subject: [PATCH 004/107] Adhere flake8 and mypy, and introduce DB_SCHEME --- terracotta/drivers/common.py | 187 +++++++++++++++++++++++------------ 1 file changed, 126 insertions(+), 61 deletions(-) diff --git a/terracotta/drivers/common.py b/terracotta/drivers/common.py index 45307a3a..58abff41 100644 --- a/terracotta/drivers/common.py +++ b/terracotta/drivers/common.py @@ -1,24 +1,28 @@ -from abc import ABC, abstractmethod -from collections import OrderedDict import contextlib import functools import json import re -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union import urllib.parse as urlparse -import numpy as np +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import (Any, Dict, Iterator, List, Mapping, Optional, Sequence, + Tuple, Union) +import numpy as np import sqlalchemy as sqla +import terracotta from sqlalchemy.engine.base import Connection from terracotta import exceptions -import terracotta from terracotta.drivers.base import requires_connection from terracotta.drivers.raster_base import RasterDriver +from terracotta.profile import trace class RelationalDriver(RasterDriver, ABC): + # NOTE: `convert_exceptions` decorators are NOT added yet - SQL_DRIVER_TYPE: str # The actual DB driver, eg pymysql, psycopg2, etc + SQL_DATABASE_SCHEME: str # The database flavour, eg mysql, sqlite, etc + SQL_DRIVER_TYPE: str # The actual database driver, eg pymysql, sqlite3, etc SQL_KEY_SIZE: int SQLA_REAL = functools.partial(sqla.types.Float, precision=8) @@ -44,7 +48,8 @@ def __init__(self, path: str) -> None: db_connection_timeout: int = settings.DB_CONNECTION_TIMEOUT assert self.SQL_DRIVER_TYPE is not None - cp = urlparse.urlparse(path) + self._CONNECTION_PARAMETERS = 
self._parse_connection_string(path) + cp = self._CONNECTION_PARAMETERS connection_string = f'{cp.scheme}+{self.SQL_DRIVER_TYPE}://{cp.netloc}{cp.path}' self.sqla_engine = sqla.create_engine( @@ -64,9 +69,23 @@ def __init__(self, path: str) -> None: # use normalized path to make sure username and password don't leak into __repr__ qualified_path = self._normalize_path(path) super().__init__(qualified_path) - + + @classmethod + def _parse_connection_string(cls, connection_string: str) -> urlparse.ParseResult: + con_params = urlparse.urlparse(connection_string) + + if not con_params.hostname: + con_params = urlparse.urlparse(f'{cls.SQL_DATABASE_SCHEME}://{connection_string}') + + assert con_params.hostname is not None + + if con_params.scheme != cls.SQL_DATABASE_SCHEME: + raise ValueError(f'unsupported URL scheme "{con_params.scheme}"') + + return con_params + @contextlib.contextmanager - def connect(self) -> contextlib.AbstractContextManager: + def connect(self) -> Iterator: if not self.connected: with self.sqla_engine.connect() as connection: self.connection = connection @@ -83,7 +102,7 @@ def _verify_db_version(self) -> None: # check for version compatibility def version_tuple(version_string: str) -> Sequence[str]: return version_string.split('.') - + db_version = self.db_version current_version = terracotta.__version__ @@ -94,11 +113,15 @@ def version_tuple(version_string: str) -> Sequence[str]: ) self.db_version_verified = True - @property + @property # type: ignore @requires_connection def db_version(self) -> str: """Terracotta version used to create the database""" - terracotta_table = sqla.Table('terracotta', self.sqla_metadata, autoload_with=self.sqla_engine) + terracotta_table = sqla.Table( + 'terracotta', + self.sqla_metadata, + autoload_with=self.sqla_engine + ) stmt = sqla.select(terracotta_table.c.version) version = self.connection.execute(stmt).scalar() return version @@ -117,17 +140,21 @@ def create(self, keys: Sequence[str], key_descriptions: 
Mapping[str, str] = None """ self._create_database() - self._initialize_database() + self._initialize_database(keys, key_descriptions) @abstractmethod - def _create_database(self, database_name: str) -> None: + def _create_database(self) -> None: # This might be made abstract, for each subclass to implement specifically # Note that some subclasses may not actually create any database here, as # it may already exist for some vendors pass @requires_connection - def _initialize_database(self, keys: Sequence[str], key_descriptions: Mapping[str, str] = None) -> None: + def _initialize_database( + self, + keys: Sequence[str], + key_descriptions: Mapping[str, str] = None + ) -> None: if key_descriptions is None: key_descriptions = {} else: @@ -141,7 +168,7 @@ def _initialize_database(self, keys: Sequence[str], key_descriptions: Mapping[st if any(key in self._RESERVED_KEYS for key in keys): raise exceptions.InvalidKeyError(f'key names cannot be one of {self._RESERVED_KEYS!s}') - + terracotta_table = sqla.Table( 'terracotta', self.sqla_metadata, sqla.Column('version', sqla.types.String(255), primary_key=True) @@ -151,14 +178,15 @@ def _initialize_database(self, keys: Sequence[str], key_descriptions: Mapping[st sqla.Column('key_name', sqla.types.String(self.SQL_KEY_SIZE), primary_key=True), sqla.Column('description', sqla.types.String(8000)) ) - datasets_table = sqla.Table( + _ = sqla.Table( 'datasets', self.sqla_metadata, - *[sqla.Column(key, sqla.types.String(self.SQL_KEY_SIZE), primary_key=True) for key in keys], + *[ + sqla.Column(key, sqla.types.String(self.SQL_KEY_SIZE), primary_key=True) for key in keys], # noqa: E501 sqla.Column('filepath', sqla.types.String(8000)) ) - metadata_table = sqla.Table( + _ = sqla.Table( 'metadata', self.sqla_metadata, - *[sqla.Column(key, sqla.types.String(self.SQL_KEY_SIZE), primary_key=True) for key in keys], + *[sqla.Column(key, sqla.types.String(self.SQL_KEY_SIZE), primary_key=True) for key in keys], # noqa: E501 
*self._METADATA_COLUMNS ) self.sqla_metadata.create_all(self.sqla_engine) @@ -171,7 +199,7 @@ def _initialize_database(self, keys: Sequence[str], key_descriptions: Mapping[st [dict(key_name=key, description=key_descriptions.get(key, '')) for key in keys] ) self.connection.commit() - + # invalidate key cache # TODO: Is that actually necessary? self._db_keys = None @@ -182,42 +210,53 @@ def get_keys(self) -> OrderedDict: return OrderedDict(result.all()) @property - def key_names(self) -> Tuple[str]: + def key_names(self) -> Tuple[str, ...]: """Names of all keys defined by the database""" if self._db_keys is None: self._db_keys = self.get_keys() return tuple(self._db_keys.keys()) + @trace('get_datasets') @requires_connection - def get_datasets(self, where: Mapping[str, Union[str, List[str]]] = None, page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], str]: + def get_datasets( + self, + where: Mapping[str, Union[str, List[str]]] = None, + page: int = 0, + limit: int = None + ) -> Dict[Tuple[str, ...], str]: # Ensure standardized structure of where items if where is None: where = {} + else: + where = dict(where) for key, value in where.items(): if not isinstance(value, list): where[key] = [value] datasets_table = sqla.Table('datasets', self.sqla_metadata, autoload_with=self.sqla_engine) - stmt = datasets_table \ - .select() \ - .where(* - [ - sqla.or_(*[datasets_table.c.get(column) == value for value in values]) - for column, values in where.items() - ] - ) \ - .order_by(*datasets_table.c.values()) \ - .limit(limit) \ - .offset(page * limit if limit is not None else None) + stmt = ( + datasets_table + .select() + .where( + *[ + sqla.or_(*[datasets_table.c.get(column) == value for value in values]) + for column, values in where.items() + ] + ) + .order_by(*datasets_table.c.values()) + .limit(limit) + .offset(page * limit if limit is not None else None) + ) result = self.connection.execute(stmt) - + def keytuple(row: Dict[str, Any]) -> Tuple[str, ...]: return 
tuple(row[key] for key in self.key_names) - + datasets = {keytuple(row): row['filepath'] for row in result} return datasets - + + @trace('get_metadata') @requires_connection def get_metadata(self, keys: Union[Sequence[str], Mapping[str, str]]) -> Dict[str, Any]: keys = tuple(self._key_dict_to_sequence(keys)) @@ -225,19 +264,21 @@ def get_metadata(self, keys: Union[Sequence[str], Mapping[str, str]]) -> Dict[st raise exceptions.InvalidKeyError( f'Got wrong number of keys (available keys: {self.key_names})' ) - + metadata_table = sqla.Table('metadata', self.sqla_metadata, autoload_with=self.sqla_engine) - stmt = metadata_table \ - .select() \ - .where(* - [ - metadata_table.c.get(key) == value - for key, value in zip(self.key_names, keys) - ] - ) - + stmt = ( + metadata_table + .select() + .where( + *[ + metadata_table.c.get(key) == value + for key, value in zip(self.key_names, keys) + ] + ) + ) + row = self.connection.execute(stmt).first() - + if not row: # support lazy loading filepath = self.get_datasets(dict(zip(self.key_names, keys))) if not filepath: @@ -247,13 +288,14 @@ def get_metadata(self, keys: Union[Sequence[str], Mapping[str, str]]) -> Dict[st # compute metadata and try again self.insert(keys, filepath[keys], skip_metadata=False) row = self.connection.execute(stmt).first() - + assert row data_columns, _ = zip(*self._METADATA_COLUMNS) encoded_data = {col: row[col] for col in self.key_names + data_columns} return self._decode_data(encoded_data) + @trace('insert') @requires_connection def insert( self, @@ -267,47 +309,70 @@ def insert( raise exceptions.InvalidKeyError( f'Got wrong number of keys (available keys: {self.key_names})' ) - + if override_path is None: override_path = filepath - + keys = self._key_dict_to_sequence(keys) key_dict = dict(zip(self.key_names, keys)) datasets_table = sqla.Table('datasets', self.sqla_metadata, autoload_with=self.sqla_engine) metadata_table = sqla.Table('metadata', self.sqla_metadata, autoload_with=self.sqla_engine) - 
self.connection.execute(datasets_table.delete().where(*[datasets_table.c.get(column) == value for column, value in key_dict.items()])) - self.connection.execute(datasets_table.insert().values(**key_dict, filepath=override_path)) + self.connection.execute( + datasets_table + .delete() + .where(*[datasets_table.c.get(column) == value for column, value in key_dict.items()]) + ) + self.connection.execute( + datasets_table.insert().values(**key_dict, filepath=override_path) + ) if metadata is None and not skip_metadata: metadata = self.compute_metadata(filepath) if metadata is not None: encoded_data = self._encode_data(metadata) - self.connection.execute(metadata_table.delete().where(*[metadata_table.c.get(column) == value for column, value in key_dict.items()])) - self.connection.execute(metadata_table.insert().values(**key_dict, **encoded_data)) - + self.connection.execute( + metadata_table + .delete() + .where( + *[metadata_table.c.get(column) == value for column, value in key_dict.items()] + ) + ) + self.connection.execute( + metadata_table.insert().values(**key_dict, **encoded_data) + ) + self.connection.commit() + @trace('delete') @requires_connection - def delete(self, keys: Union[Sequence[str], Mapping[str, str]], silent=False) -> None: + def delete(self, keys: Union[Sequence[str], Mapping[str, str]]) -> None: if len(keys) != len(self.key_names): raise exceptions.InvalidKeyError( f'Got wrong number of keys (available keys: {self.key_names})' ) - + keys = self._key_dict_to_sequence(keys) key_dict = dict(zip(self.key_names, keys)) if not self.get_datasets(key_dict): raise exceptions.DatasetNotFoundError(f'No dataset found with keys {keys}') - + datasets_table = sqla.Table('datasets', self.sqla_metadata, autoload_with=self.sqla_engine) metadata_table = sqla.Table('metadata', self.sqla_metadata, autoload_with=self.sqla_engine) - self.connection.execute(datasets_table.delete().where(*[datasets_table.c.get(column) == value for column, value in key_dict.items()])) - 
self.connection.execute(metadata_table.delete().where(*[metadata_table.c.get(column) == value for column, value in key_dict.items()])) + self.connection.execute( + datasets_table + .delete() + .where(*[datasets_table.c.get(column) == value for column, value in key_dict.items()]) + ) + self.connection.execute( + metadata_table + .delete() + .where(*[metadata_table.c.get(column) == value for column, value in key_dict.items()]) + ) self.connection.commit() @staticmethod From 50ed8e416440da718501741aba4533422e89fa74 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Thu, 2 Dec 2021 14:28:59 +0100 Subject: [PATCH 005/107] Use RelationalDriver --- terracotta/drivers/mysql.py | 440 ++---------------------------------- 1 file changed, 20 insertions(+), 420 deletions(-) diff --git a/terracotta/drivers/mysql.py b/terracotta/drivers/mysql.py index 18ff9ecd..f0c09a38 100644 --- a/terracotta/drivers/mysql.py +++ b/terracotta/drivers/mysql.py @@ -4,78 +4,31 @@ to be present on disk. """ -from typing import (List, Tuple, Dict, Iterator, Sequence, Union, - Mapping, Any, Optional, cast, TypeVar) -from collections import OrderedDict import contextlib -from contextlib import AbstractContextManager -import re -import json import urllib.parse as urlparse +from typing import Iterator, TypeVar from urllib.parse import ParseResult -import numpy as np -import pymysql -from pymysql.connections import Connection -from pymysql.cursors import DictCursor - -from terracotta import get_settings, __version__ -from terracotta.drivers.raster_base import RasterDriver -from terracotta.drivers.base import requires_connection +import sqlalchemy as sqla from terracotta import exceptions -from terracotta.profile import trace - +from terracotta.drivers.common import RelationalDriver T = TypeVar('T') -_ERROR_ON_CONNECT = ( - 'Could not connect to database. Make sure that the given path points ' - 'to a valid Terracotta database, and that you ran driver.create().' 
-) - DEFAULT_PORT = 3306 @contextlib.contextmanager def convert_exceptions(msg: str) -> Iterator: """Convert internal mysql exceptions to our InvalidDatabaseError""" - from pymysql import OperationalError, InternalError, ProgrammingError + from pymysql import InternalError, OperationalError, ProgrammingError try: yield except (OperationalError, InternalError, ProgrammingError) as exc: raise exceptions.InvalidDatabaseError(msg) from exc -class MySQLCredentials: - __slots__ = ('host', 'port', 'db', '_user', '_password') - - def __init__(self, - host: str, - port: int, - db: str, - user: Optional[str] = None, - password: Optional[str] = None): - self.host = host - self.port = port - self.db = db - self._user = user - self._password = password - - @property - def user(self) -> Optional[str]: - return self._user or get_settings().MYSQL_USER - - @property - def password(self) -> str: - pw = self._password or get_settings().MYSQL_PASSWORD - - if pw is None: - pw = '' - - return pw - - -class MySQLDriver(RasterDriver): +class MySQLDriver(RelationalDriver): """A MySQL-backed raster driver. Assumes raster data to be present in separate GDAL-readable files on disk or remotely. @@ -92,21 +45,9 @@ class MySQLDriver(RasterDriver): This driver caches raster data and key names, but not metadata. """ - _MAX_PRIMARY_KEY_LENGTH = 767 // 4 # Max key length for MySQL is at least 767B - _METADATA_COLUMNS: Tuple[Tuple[str, ...], ...] 
= ( - ('bounds_north', 'REAL'), - ('bounds_east', 'REAL'), - ('bounds_south', 'REAL'), - ('bounds_west', 'REAL'), - ('convex_hull', 'LONGTEXT'), - ('valid_percentage', 'REAL'), - ('min', 'REAL'), - ('max', 'REAL'), - ('mean', 'REAL'), - ('stdev', 'REAL'), - ('percentiles', 'BLOB'), - ('metadata', 'LONGTEXT') - ) + SQL_DATABASE_SCHEME = 'mysql' + SQL_DRIVER_TYPE = 'pymysql' + SQL_KEY_SIZE = 50 _CHARSET: str = 'utf8mb4' def __init__(self, mysql_path: str) -> None: @@ -120,38 +61,8 @@ def __init__(self, mysql_path: str) -> None: ``mysql://username:password@hostname/database`` """ - settings = get_settings() - - self.DB_CONNECTION_TIMEOUT: int = settings.DB_CONNECTION_TIMEOUT - - con_params = urlparse.urlparse(mysql_path) - - if not con_params.hostname: - con_params = urlparse.urlparse(f'mysql://{mysql_path}') - - assert con_params.hostname is not None - - if con_params.scheme != 'mysql': - raise ValueError(f'unsupported URL scheme "{con_params.scheme}"') - - self._db_args = MySQLCredentials( - host=con_params.hostname, - user=con_params.username, - password=con_params.password, - port=con_params.port or DEFAULT_PORT, - db=self._parse_db_name(con_params) - ) - - self._connection: Connection - self._cursor: DictCursor - self._connected = False - - self._version_checked: bool = False - self._db_keys: Optional[OrderedDict] = None - - # use normalized path to make sure username and password don't leak into __repr__ - qualified_path = self._normalize_path(mysql_path) - super().__init__(qualified_path) + super().__init__(mysql_path) + self._parse_db_name(self._CONNECTION_PARAMETERS) # To enforce path is parsable @classmethod def _normalize_path(cls, path: str) -> str: @@ -175,325 +86,14 @@ def _parse_db_name(con_params: ParseResult) -> str: return path - @requires_connection - @convert_exceptions(_ERROR_ON_CONNECT) - def _get_db_version(self) -> str: - """Terracotta version used to create the database""" - cursor = self._cursor - cursor.execute('SELECT version from 
terracotta') - db_row = cast(Dict[str, str], cursor.fetchone()) - return db_row['version'] - - db_version = cast(str, property(_get_db_version)) - - def _connection_callback(self) -> None: - if not self._version_checked: - # check for version compatibility - def versiontuple(version_string: str) -> Sequence[str]: - return version_string.split('.') - - db_version = self.db_version - current_version = __version__ - - if versiontuple(db_version)[:2] != versiontuple(current_version)[:2]: - raise exceptions.InvalidDatabaseError( - f'Version conflict: database was created in v{db_version}, ' - f'but this is v{current_version}' - ) - self._version_checked = True - - def _get_key_names(self) -> Tuple[str, ...]: - """Names of all keys defined by the database""" - return tuple(self.get_keys().keys()) - - key_names = cast(Tuple[str], property(_get_key_names)) - - def connect(self) -> AbstractContextManager: - return self._connect(check=True) - - @contextlib.contextmanager - def _connect(self, check: bool = True) -> Iterator: - close = False - try: - if not self._connected: - with convert_exceptions(_ERROR_ON_CONNECT): - self._connection = pymysql.connect( - host=self._db_args.host, user=self._db_args.user, db=self._db_args.db, - password=self._db_args.password, port=self._db_args.port, - read_timeout=self.DB_CONNECTION_TIMEOUT, - write_timeout=self.DB_CONNECTION_TIMEOUT, - binary_prefix=True, charset='utf8mb4' - ) - self._cursor = self._connection.cursor(DictCursor) - self._connected = close = True - - if check: - self._connection_callback() - - try: - yield - except Exception: - self._connection.rollback() - raise - - finally: - if close: - self._connected = False - self._cursor.close() - self._connection.commit() - self._connection.close() - - @convert_exceptions('Could not create database') - def create(self, keys: Sequence[str], key_descriptions: Mapping[str, str] = None) -> None: - """Create and initialize database with empty tables. 
- - This must be called before opening the first connection. The MySQL database must not - exist already. - - Arguments: - - keys: Key names to use throughout the Terracotta database. - key_descriptions: Optional (but recommended) full-text description for some keys, - in the form of ``{key_name: description}``. - - """ - if key_descriptions is None: - key_descriptions = {} - else: - key_descriptions = dict(key_descriptions) - - if not all(k in keys for k in key_descriptions.keys()): - raise exceptions.InvalidKeyError('key description dict contains unknown keys') - - if not all(re.match(r'^\w+$', key) for key in keys): - raise exceptions.InvalidKeyError('key names must be alphanumeric') - - if any(key in self._RESERVED_KEYS for key in keys): - raise exceptions.InvalidKeyError(f'key names cannot be one of {self._RESERVED_KEYS!s}') - - for key in keys: - if key not in key_descriptions: - key_descriptions[key] = '' - - # total primary key length has an upper limit in MySQL - key_size = self._MAX_PRIMARY_KEY_LENGTH // len(keys) - key_type = f'VARCHAR({key_size})' - - connection = pymysql.connect( - host=self._db_args.host, user=self._db_args.user, - password=self._db_args.password, port=self._db_args.port, - read_timeout=self.DB_CONNECTION_TIMEOUT, - write_timeout=self.DB_CONNECTION_TIMEOUT, - binary_prefix=True, charset='utf8mb4' + def _create_database(self) -> None: + engine = sqla.create_engine( + f'{self._CONNECTION_PARAMETERS.scheme}+{self.SQL_DRIVER_TYPE}://' + f'{self._CONNECTION_PARAMETERS.netloc}', + echo=True, + future=True ) - - with connection, connection.cursor() as cursor: # type: ignore - cursor.execute(f'CREATE DATABASE {self._db_args.db}') - - with self._connect(check=False): - cursor = self._cursor - cursor.execute(f'CREATE TABLE terracotta (version VARCHAR(255)) ' - f'CHARACTER SET {self._CHARSET}') - cursor.execute('INSERT INTO terracotta VALUES (%s)', [str(__version__)]) - - cursor.execute(f'CREATE TABLE key_names (key_name {key_type}, ' - 
f'description VARCHAR(8000)) CHARACTER SET {self._CHARSET}') - key_rows = [(key, key_descriptions[key]) for key in keys] - cursor.executemany('INSERT INTO key_names VALUES (%s, %s)', key_rows) - - key_string = ', '.join([f'{key} {key_type}' for key in keys]) - cursor.execute(f'CREATE TABLE datasets ({key_string}, filepath VARCHAR(8000), ' - f'PRIMARY KEY({", ".join(keys)})) CHARACTER SET {self._CHARSET}') - - column_string = ', '.join(f'{col} {col_type}' for col, col_type - in self._METADATA_COLUMNS) - cursor.execute(f'CREATE TABLE metadata ({key_string}, {column_string}, ' - f'PRIMARY KEY ({", ".join(keys)})) CHARACTER SET {self._CHARSET}') - - # invalidate key cache - self._db_keys = None - - def get_keys(self) -> OrderedDict: - if self._db_keys is None: - self._db_keys = self._get_keys() - return self._db_keys - - @requires_connection - @convert_exceptions('Could not retrieve keys from database') - def _get_keys(self) -> OrderedDict: - out: OrderedDict = OrderedDict() - - cursor = self._cursor - cursor.execute('SELECT * FROM key_names') - key_rows = cursor.fetchall() or () - - for row in key_rows: - out[row['key_name']] = row['description'] - - return out - - @trace('get_datasets') - @requires_connection - @convert_exceptions('Could not retrieve datasets') - def get_datasets(self, where: Mapping[str, Union[str, List[str]]] = None, - page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], str]: - cursor = self._cursor - - if limit is not None: - # explicitly cast to int to prevent SQL injection - page_fragment = f'LIMIT {int(limit)} OFFSET {int(page) * int(limit)}' - else: - page_fragment = '' - - # sort by keys to ensure deterministic results - order_fragment = f'ORDER BY {", ".join(self.key_names)}' - - if where is None: - cursor.execute(f'SELECT * FROM datasets {order_fragment} {page_fragment}') - else: - if not all(key in self.key_names for key in where.keys()): - raise exceptions.InvalidKeyError('Encountered unrecognized keys in ' - 'where clause') - 
conditions = [] - values = [] - for key, value in where.items(): - if isinstance(value, str): - value = [value] - values.extend(value) - conditions.append(' OR '.join([f'{key}=%s'] * len(value))) - where_fragment = ' AND '.join([f'({condition})' for condition in conditions]) - cursor.execute( - f'SELECT * FROM datasets WHERE {where_fragment} {order_fragment} {page_fragment}', - values - ) - - def keytuple(row: Dict[str, Any]) -> Tuple[str, ...]: - return tuple(row[key] for key in self.key_names) - - datasets = {} - for row in cursor: - datasets[keytuple(row)] = row['filepath'] - - return datasets - - @staticmethod - def _encode_data(decoded: Mapping[str, Any]) -> Dict[str, Any]: - """Transform from internal format to database representation""" - encoded = { - 'bounds_north': decoded['bounds'][0], - 'bounds_east': decoded['bounds'][1], - 'bounds_south': decoded['bounds'][2], - 'bounds_west': decoded['bounds'][3], - 'convex_hull': json.dumps(decoded['convex_hull']), - 'valid_percentage': decoded['valid_percentage'], - 'min': decoded['range'][0], - 'max': decoded['range'][1], - 'mean': decoded['mean'], - 'stdev': decoded['stdev'], - 'percentiles': np.array(decoded['percentiles'], dtype='float32').tobytes(), - 'metadata': json.dumps(decoded['metadata']) - } - return encoded - - @staticmethod - def _decode_data(encoded: Mapping[str, Any]) -> Dict[str, Any]: - """Transform from database format to internal representation""" - decoded = { - 'bounds': tuple([encoded[f'bounds_{d}'] for d in ('north', 'east', 'south', 'west')]), - 'convex_hull': json.loads(encoded['convex_hull']), - 'valid_percentage': encoded['valid_percentage'], - 'range': (encoded['min'], encoded['max']), - 'mean': encoded['mean'], - 'stdev': encoded['stdev'], - 'percentiles': np.frombuffer(encoded['percentiles'], dtype='float32').tolist(), - 'metadata': json.loads(encoded['metadata']) - } - return decoded - - @trace('get_metadata') - @requires_connection - @convert_exceptions('Could not retrieve 
metadata') - def get_metadata(self, keys: Union[Sequence[str], Mapping[str, str]]) -> Dict[str, Any]: - keys = tuple(self._key_dict_to_sequence(keys)) - - if len(keys) != len(self.key_names): - raise exceptions.InvalidKeyError('Got wrong number of keys') - - cursor = self._cursor - - where_string = ' AND '.join([f'{key}=%s' for key in self.key_names]) - cursor.execute(f'SELECT * FROM metadata WHERE {where_string}', keys) - row = cursor.fetchone() - - if not row: # support lazy loading - filepath = self.get_datasets(dict(zip(self.key_names, keys))) - if not filepath: - raise exceptions.DatasetNotFoundError(f'No dataset found for given keys {keys}') - assert len(filepath) == 1 - - # compute metadata and try again - self.insert(keys, filepath[keys], skip_metadata=False) - cursor.execute(f'SELECT * FROM metadata WHERE {where_string}', keys) - row = cursor.fetchone() - - assert row - - data_columns, _ = zip(*self._METADATA_COLUMNS) - encoded_data = {col: row[col] for col in self.key_names + data_columns} - return self._decode_data(encoded_data) - - @trace('insert') - @requires_connection - @convert_exceptions('Could not write to database') - def insert(self, - keys: Union[Sequence[str], Mapping[str, str]], - filepath: str, *, - metadata: Mapping[str, Any] = None, - skip_metadata: bool = False, - override_path: str = None) -> None: - cursor = self._cursor - - if len(keys) != len(self.key_names): - raise exceptions.InvalidKeyError( - f'Got wrong number of keys (available keys: {self.key_names})' - ) - - if override_path is None: - override_path = filepath - - keys = self._key_dict_to_sequence(keys) - template_string = ', '.join(['%s'] * (len(keys) + 1)) - cursor.execute(f'REPLACE INTO datasets VALUES ({template_string})', - [*keys, override_path]) - - if metadata is None and not skip_metadata: - metadata = self.compute_metadata(filepath) - - if metadata is not None: - encoded_data = self._encode_data(metadata) - row_keys, row_values = zip(*encoded_data.items()) - 
template_string = ', '.join(['%s'] * (len(keys) + len(row_values))) - cursor.execute(f'REPLACE INTO metadata ({", ".join(self.key_names)}, ' - f'{", ".join(row_keys)}) VALUES ({template_string})', - [*keys, *row_values]) - - @trace('delete') - @requires_connection - @convert_exceptions('Could not write to database') - def delete(self, keys: Union[Sequence[str], Mapping[str, str]]) -> None: - cursor = self._cursor - - if len(keys) != len(self.key_names): - raise exceptions.InvalidKeyError( - f'Got wrong number of keys (available keys: {self.key_names})' - ) - - keys = self._key_dict_to_sequence(keys) - key_dict = dict(zip(self.key_names, keys)) - - if not self.get_datasets(key_dict): - raise exceptions.DatasetNotFoundError(f'No dataset found with keys {keys}') - - where_string = ' AND '.join([f'{key}=%s' for key in self.key_names]) - cursor.execute(f'DELETE FROM datasets WHERE {where_string}', keys) - cursor.execute(f'DELETE FROM metadata WHERE {where_string}', keys) + with engine.connect() as connection: + db_name = self._parse_db_name(self._CONNECTION_PARAMETERS) + connection.execute(sqla.text(f'CREATE DATABASE {db_name}')) + connection.commit() From 70377935417117090c75ed79bb1b716a1989a4a1 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Thu, 2 Dec 2021 14:29:49 +0100 Subject: [PATCH 006/107] Use default ParseResult structure instead of specific MySQLCredentials object --- tests/drivers/test_mysql.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/drivers/test_mysql.py b/tests/drivers/test_mysql.py index 8ddc546f..a8a6fe38 100644 --- a/tests/drivers/test_mysql.py +++ b/tests/drivers/test_mysql.py @@ -2,19 +2,19 @@ TEST_CASES = { 'mysql://root@localhost:5000/test': dict( - user='root', password='', host='localhost', port=5000, db='test' + username='root', password=None, hostname='localhost', port=5000, path='/test' ), 'root@localhost:5000/test': dict( - user='root', password='', host='localhost', port=5000, db='test' + 
username='root', password=None, hostname='localhost', port=5000, path='/test' ), 'mysql://root:foo@localhost/test': dict( - user='root', password='foo', host='localhost', port=3306, db='test' + username='root', password='foo', hostname='localhost', port=None, path='/test' ), 'mysql://localhost/test': dict( - password='', host='localhost', port=3306, db='test' + password=None, hostname='localhost', port=None, path='/test' ), 'localhost/test': dict( - password='', host='localhost', port=3306, db='test' + password=None, hostname='localhost', port=None, path='/test' ) } @@ -32,8 +32,10 @@ def test_path_parsing(case): drivers._DRIVER_CACHE = {} db = drivers.get_driver(case, provider='mysql') - db_args = db._db_args - for attr in ('user', 'password', 'host', 'port', 'db'): + db_args = db._CONNECTION_PARAMETERS + print(db_args) + for attr in ('username', 'password', 'hostname', 'port', 'path'): + print(attr) assert getattr(db_args, attr) == TEST_CASES[case].get(attr, None) From b657bd3fba3103d80605c036f84a0ee12c7e234b Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Thu, 2 Dec 2021 14:31:15 +0100 Subject: [PATCH 007/107] Remove leftover print statements --- tests/drivers/test_mysql.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/drivers/test_mysql.py b/tests/drivers/test_mysql.py index a8a6fe38..38bd0480 100644 --- a/tests/drivers/test_mysql.py +++ b/tests/drivers/test_mysql.py @@ -33,9 +33,7 @@ def test_path_parsing(case): db = drivers.get_driver(case, provider='mysql') db_args = db._CONNECTION_PARAMETERS - print(db_args) for attr in ('username', 'password', 'hostname', 'port', 'path'): - print(attr) assert getattr(db_args, attr) == TEST_CASES[case].get(attr, None) From 35a5ce87f56974388dc00f118dd30d14b4927d10 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Thu, 2 Dec 2021 16:39:00 +0100 Subject: [PATCH 008/107] Let functions optionally use unverified databases when connecting --- terracotta/drivers/base.py | 7 +++++-- 1 file changed, 5 
insertions(+), 2 deletions(-) diff --git a/terracotta/drivers/base.py b/terracotta/drivers/base.py index 58b67084..6e5afd19 100644 --- a/terracotta/drivers/base.py +++ b/terracotta/drivers/base.py @@ -13,10 +13,13 @@ T = TypeVar('T') -def requires_connection(fun: Callable[..., T]) -> Callable[..., T]: +def requires_connection(fun: Callable[..., T] = None, *, verify: bool = True) -> Callable[..., T]: + if fun is None: + return functools.partial(requires_connection, verify=verify) + @functools.wraps(fun) def inner(self: Driver, *args: Any, **kwargs: Any) -> T: - with self.connect(): + with self.connect(verify=verify): return fun(self, *args, **kwargs) return inner From fc5504c7a82dfec23a5c5ebe535c098f64624e23 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Thu, 2 Dec 2021 16:40:13 +0100 Subject: [PATCH 009/107] Fix column creation error and create tables on unverified database connection --- terracotta/drivers/common.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/terracotta/drivers/common.py b/terracotta/drivers/common.py index 58abff41..da80cba9 100644 --- a/terracotta/drivers/common.py +++ b/terracotta/drivers/common.py @@ -29,18 +29,18 @@ class RelationalDriver(RasterDriver, ABC): SQLA_TEXT = sqla.types.Text SQLA_BLOB = sqla.types.LargeBinary _METADATA_COLUMNS: Tuple[Tuple[str, sqla.types.TypeEngine], ...] 
= ( - ('bounds_north', SQLA_REAL()), - ('bounds_east', SQLA_REAL()), - ('bounds_south', SQLA_REAL()), - ('bounds_west', SQLA_REAL()), - ('convex_hull', SQLA_TEXT()), - ('valid_percentage', SQLA_REAL()), - ('min', SQLA_REAL()), - ('max', SQLA_REAL()), - ('mean', SQLA_REAL()), - ('stdev', SQLA_REAL()), - ('percentiles', SQLA_BLOB()), - ('metadata', SQLA_TEXT()) + ('bounds_north', SQLA_REAL), + ('bounds_east', SQLA_REAL), + ('bounds_south', SQLA_REAL), + ('bounds_west', SQLA_REAL), + ('convex_hull', SQLA_TEXT), + ('valid_percentage', SQLA_REAL), + ('min', SQLA_REAL), + ('max', SQLA_REAL), + ('mean', SQLA_REAL), + ('stdev', SQLA_REAL), + ('percentiles', SQLA_BLOB), + ('metadata', SQLA_TEXT) ) def __init__(self, path: str) -> None: @@ -56,7 +56,7 @@ def __init__(self, path: str) -> None: connection_string, echo=True, future=True, - connect_args={'timeout': db_connection_timeout} + #connect_args={'timeout': db_connection_timeout} ) self.sqla_metadata = sqla.MetaData() @@ -85,12 +85,13 @@ def _parse_connection_string(cls, connection_string: str) -> urlparse.ParseResul return con_params @contextlib.contextmanager - def connect(self) -> Iterator: + def connect(self, verify: bool = True) -> Iterator: if not self.connected: with self.sqla_engine.connect() as connection: self.connection = connection self.connected = True - self._verify_db_version() + if verify: + self._verify_db_version() yield self.connected = False self.connection = None @@ -149,7 +150,7 @@ def _create_database(self) -> None: # it may already exist for some vendors pass - @requires_connection + @requires_connection(verify=False) def _initialize_database( self, keys: Sequence[str], @@ -187,9 +188,10 @@ def _initialize_database( _ = sqla.Table( 'metadata', self.sqla_metadata, *[sqla.Column(key, sqla.types.String(self.SQL_KEY_SIZE), primary_key=True) for key in keys], # noqa: E501 - *self._METADATA_COLUMNS + *[sqla.Column(name, column_type()) for name, column_type in self._METADATA_COLUMNS] ) 
self.sqla_metadata.create_all(self.sqla_engine) + self.connection.commit() self.connection.execute( terracotta_table.insert().values(version=terracotta.__version__) From cba5f5bfc94963189e3d5284b2be47eaeaad0a9a Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 3 Dec 2021 08:24:36 +0100 Subject: [PATCH 010/107] Use indices on the keys in the DB to recreate correct order --- terracotta/drivers/common.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/terracotta/drivers/common.py b/terracotta/drivers/common.py index da80cba9..d7ec079c 100644 --- a/terracotta/drivers/common.py +++ b/terracotta/drivers/common.py @@ -177,7 +177,8 @@ def _initialize_database( key_names_table = sqla.Table( 'key_names', self.sqla_metadata, sqla.Column('key_name', sqla.types.String(self.SQL_KEY_SIZE), primary_key=True), - sqla.Column('description', sqla.types.String(8000)) + sqla.Column('description', sqla.types.String(8000)), + sqla.Column('index', sqla.types.Integer, unique=True) ) _ = sqla.Table( 'datasets', self.sqla_metadata, @@ -198,7 +199,7 @@ def _initialize_database( ) self.connection.execute( key_names_table.insert(), - [dict(key_name=key, description=key_descriptions.get(key, '')) for key in keys] + [dict(key_name=key, description=key_descriptions.get(key, ''), index=i) for i, key in enumerate(keys)] ) self.connection.commit() @@ -208,7 +209,12 @@ def _initialize_database( @requires_connection def get_keys(self) -> OrderedDict: keys_table = sqla.Table('key_names', self.sqla_metadata, autoload_with=self.sqla_engine) - result = self.connection.execute(keys_table.select()) + result = self.connection.execute( + sqla.select( + keys_table.c.get('key_name'), + keys_table.c.get('description') + ) + .order_by(keys_table.c.get('index'))) return OrderedDict(result.all()) @property From 50608a569452db77fdd0cd4cc670b945a0ad3328 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 7 Dec 2021 14:33:57 +0100 Subject: [PATCH 011/107] index on sqlalchemy: 
cba5f5b Use indices on the keys in the DB to recreate correct order From b0116cba6f70191cbc880477ff14332d05a1d394 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 7 Dec 2021 15:29:04 +0100 Subject: [PATCH 012/107] Align handling of errors with previous version --- terracotta/drivers/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terracotta/drivers/common.py b/terracotta/drivers/common.py index 06df577d..bebae5d9 100644 --- a/terracotta/drivers/common.py +++ b/terracotta/drivers/common.py @@ -403,7 +403,7 @@ def insert( metadata_table.insert().values(**key_dict, **encoded_data) ) - self.connection.commit() + # self.connection.commit() @trace('delete') @requires_connection From 627bc306331f8bac90f701f86df02532477b3143 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 14 Dec 2021 16:07:47 +0100 Subject: [PATCH 013/107] Check for missing scheme instead of missing hostname --- terracotta/drivers/common.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/terracotta/drivers/common.py b/terracotta/drivers/common.py index bebae5d9..ee91180c 100644 --- a/terracotta/drivers/common.py +++ b/terracotta/drivers/common.py @@ -104,11 +104,9 @@ def __init__(self, path: str) -> None: def _parse_connection_string(cls, connection_string: str) -> urlparse.ParseResult: con_params = urlparse.urlparse(connection_string) - if not con_params.hostname: + if not con_params.scheme: con_params = urlparse.urlparse(f'{cls.SQL_DATABASE_SCHEME}://{connection_string}') - assert con_params.hostname is not None - if con_params.scheme != cls.SQL_DATABASE_SCHEME: raise ValueError(f'unsupported URL scheme "{con_params.scheme}"') From ac1cbbda5b897bb61b92ab7fa4a825060b9cd41b Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 14 Dec 2021 16:48:24 +0100 Subject: [PATCH 014/107] Tidy mysql driver code up a bit --- terracotta/drivers/mysql.py | 31 +++++-------------------------- 1 file changed, 5 insertions(+), 26 deletions(-) diff --git 
a/terracotta/drivers/mysql.py b/terracotta/drivers/mysql.py index 3238f654..aac63998 100644 --- a/terracotta/drivers/mysql.py +++ b/terracotta/drivers/mysql.py @@ -1,33 +1,14 @@ -"""drivers/sqlite.py +"""drivers/mysql.py MySQL-backed raster driver. Metadata is stored in a MySQL database, raster data is assumed to be present on disk. """ -import contextlib -import urllib.parse as urlparse -from typing import Iterator, TypeVar from urllib.parse import ParseResult -import pymysql import sqlalchemy as sqla -from terracotta import exceptions from terracotta.drivers.common import RelationalDriver -T = TypeVar('T') - -DEFAULT_PORT = 3306 - - -@contextlib.contextmanager -def convert_exceptions(msg: str) -> Iterator: - """Convert internal mysql exceptions to our InvalidDatabaseError""" - from pymysql import InternalError, OperationalError, ProgrammingError - try: - yield - except (OperationalError, InternalError, ProgrammingError) as exc: - raise exceptions.InvalidDatabaseError(msg) from exc - class MySQLDriver(RelationalDriver): """A MySQL-backed raster driver. @@ -49,7 +30,8 @@ class MySQLDriver(RelationalDriver): SQL_DATABASE_SCHEME = 'mysql' SQL_DRIVER_TYPE = 'pymysql' SQL_KEY_SIZE = 50 - _CHARSET: str = 'utf8mb4' + + DEFAULT_PORT = 3306 def __init__(self, mysql_path: str) -> None: """Initialize the MySQLDriver. 
@@ -67,12 +49,9 @@ def __init__(self, mysql_path: str) -> None: @classmethod def _normalize_path(cls, path: str) -> str: - parts = urlparse.urlparse(path) - - if not parts.hostname: - parts = urlparse.urlparse(f'mysql://{path}') + parts = cls._parse_connection_string(path) - path = f'{parts.scheme}://{parts.hostname}:{parts.port or DEFAULT_PORT}{parts.path}' + path = f'{parts.scheme}://{parts.hostname}:{parts.port or cls.DEFAULT_PORT}{parts.path}' path = path.rstrip('/') return path From 658d7cfb1b3baea7e1f7294650005c0f644c4f1f Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 14 Dec 2021 16:56:01 +0100 Subject: [PATCH 015/107] Remove leftover debugging stuff --- tests/drivers/test_drivers.py | 3 +-- tests/drivers/test_raster_drivers.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/drivers/test_drivers.py b/tests/drivers/test_drivers.py index ce64deb3..29918702 100644 --- a/tests/drivers/test_drivers.py +++ b/tests/drivers/test_drivers.py @@ -140,7 +140,6 @@ def test_connect_before_create(driver_path, provider): with pytest.raises(exceptions.InvalidDatabaseError) as exc: with db.connect(): - import ipdb; ipdb.set_trace() pass assert 'ran driver.create()' in str(exc.value) @@ -208,7 +207,7 @@ def test_version_conflict(driver_path, provider, raster_file, monkeypatch): with monkeypatch.context() as m: fake_version = '0.0.0' - m.setattr(f'terracotta.__version__', fake_version) + m.setattr('terracotta.__version__', fake_version) db.db_version_verified = False with pytest.raises(exceptions.InvalidDatabaseError) as exc: diff --git a/tests/drivers/test_raster_drivers.py b/tests/drivers/test_raster_drivers.py index 61475980..c3abaac2 100644 --- a/tests/drivers/test_raster_drivers.py +++ b/tests/drivers/test_raster_drivers.py @@ -275,7 +275,6 @@ def test_multiprocess_insertion(driver_path, provider, raster_file): pass datasets = db.get_datasets() - print(sorted((k[0] for k in datasets.keys()))) assert all((key,) in datasets for key in 
key_vals) data1 = db.get_metadata(['77']) From 0d712263bdab4241e8c8389046d03d184bc832d5 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 14 Dec 2021 16:58:46 +0100 Subject: [PATCH 016/107] Add sqla InvalidRequestError to list of exceptions to convert (and tidy up a bit) --- terracotta/drivers/common.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/terracotta/drivers/common.py b/terracotta/drivers/common.py index ee91180c..ae3876ed 100644 --- a/terracotta/drivers/common.py +++ b/terracotta/drivers/common.py @@ -5,7 +5,7 @@ import urllib.parse as urlparse from abc import ABC, abstractmethod from collections import OrderedDict -from typing import (Any, Dict, Iterable, Iterator, List, Mapping, Optional, Sequence, +from typing import (Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union) import numpy as np @@ -17,7 +17,6 @@ from terracotta.drivers.raster_base import RasterDriver from terracotta.profile import trace - _ERROR_ON_CONNECT = ( 'Could not connect to database. Make sure that the given path points ' 'to a valid Terracotta database, and that you ran driver.create().' 
@@ -43,8 +42,6 @@ def convert_exceptions_context(error_message, exceptions_to_convert): class RelationalDriver(RasterDriver, ABC): - # NOTE: `convert_exceptions` decorators are NOT added yet - SQL_DATABASE_SCHEME: str # The database flavour, eg mysql, sqlite, etc SQL_DRIVER_TYPE: str # The actual database driver, eg pymysql, sqlite3, etc SQL_KEY_SIZE: int @@ -52,7 +49,8 @@ class RelationalDriver(RasterDriver, ABC): DATABASE_DRIVER_EXCEPTIONS_TO_CONVERT: Tuple[Type[Exception]] = ( sqla.exc.OperationalError, sqla.exc.InternalError, - sqla.exc.ProgrammingError + sqla.exc.ProgrammingError, + sqla.exc.InvalidRequestError, ) SQLA_REAL = functools.partial(sqla.types.Float, precision=8) From 63b2c70c38f91e5ef2c544532255eea10a270107 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 14 Dec 2021 17:04:16 +0100 Subject: [PATCH 017/107] Refactor sqlite driver to use the common RelationalDriver --- terracotta/drivers/sqlite.py | 349 ++--------------------------------- 1 file changed, 14 insertions(+), 335 deletions(-) diff --git a/terracotta/drivers/sqlite.py b/terracotta/drivers/sqlite.py index f144dc7b..39f1387c 100644 --- a/terracotta/drivers/sqlite.py +++ b/terracotta/drivers/sqlite.py @@ -4,40 +4,14 @@ to be present on disk. """ -from typing import Any, List, Sequence, Mapping, Tuple, Union, Iterator, Dict, cast import os -import contextlib -from contextlib import AbstractContextManager -import json -import re -import sqlite3 -from sqlite3 import Connection from pathlib import Path -from collections import OrderedDict +from typing import Union -import numpy as np +from terracotta.drivers.common import RelationalDriver -from terracotta import get_settings, exceptions, __version__ -from terracotta.profile import trace -from terracotta.drivers.base import requires_connection -from terracotta.drivers.raster_base import RasterDriver -_ERROR_ON_CONNECT = ( - 'Could not connect to database. 
Make sure that the given path points ' - 'to a valid Terracotta database, and that you ran driver.create().' -) - - -@contextlib.contextmanager -def convert_exceptions(msg: str) -> Iterator: - """Convert internal sqlite exceptions to our InvalidDatabaseError""" - try: - yield - except sqlite3.OperationalError as exc: - raise exceptions.InvalidDatabaseError(msg) from exc - - -class SQLiteDriver(RasterDriver): +class SQLiteDriver(RelationalDriver): """An SQLite-backed raster driver. Assumes raster data to be present in separate GDAL-readable files on disk or remotely. @@ -60,7 +34,7 @@ class SQLiteDriver(RasterDriver): - ``datasets``: Maps key values to physical raster path. - ``metadata``: Contains actual metadata as separate columns. Indexed via key values. - This driver caches raster data, but not metadata. + This driver caches raster data and key names, but not metadata. Warning: @@ -68,21 +42,9 @@ class SQLiteDriver(RasterDriver): outside the main thread. """ - _KEY_TYPE: str = 'VARCHAR[256]' - _METADATA_COLUMNS: Tuple[Tuple[str, ...], ...] = ( - ('bounds_north', 'REAL'), - ('bounds_east', 'REAL'), - ('bounds_south', 'REAL'), - ('bounds_west', 'REAL'), - ('convex_hull', 'VARCHAR[max]'), - ('valid_percentage', 'REAL'), - ('min', 'REAL'), - ('max', 'REAL'), - ('mean', 'REAL'), - ('stdev', 'REAL'), - ('percentiles', 'BLOB'), - ('metadata', 'VARCHAR[max]') - ) + SQL_DATABASE_SCHEME = 'sqlite' + SQL_DRIVER_TYPE = 'pysqlite' + SQL_KEY_SIZE = 256 def __init__(self, path: Union[str, Path]) -> None: """Initialize the SQLiteDriver. 
@@ -95,298 +57,15 @@ def __init__(self, path: Union[str, Path]) -> None: """ path = str(path) - - settings = get_settings() - self.DB_CONNECTION_TIMEOUT: int = settings.DB_CONNECTION_TIMEOUT - self.LAZY_LOADING_MAX_SHAPE: Tuple[int, int] = settings.LAZY_LOADING_MAX_SHAPE - - self._connection: Connection - self._connected = False - - super().__init__(os.path.realpath(path)) + super().__init__(f'sqlite:///{os.path.realpath(path)}') @classmethod def _normalize_path(cls, path: str) -> str: - return os.path.normpath(os.path.realpath(path)) - - def connect(self, verify: bool = True) -> AbstractContextManager: - return self._connect(check=True) - - @contextlib.contextmanager - def _connect(self, check: bool = True) -> Iterator: - try: - close = False - if not self._connected: - with convert_exceptions(_ERROR_ON_CONNECT): - self._connection = sqlite3.connect( - self.path, timeout=self.DB_CONNECTION_TIMEOUT - ) - self._connection.row_factory = sqlite3.Row - self._connected = close = True - - if check: - self._connection_callback() - - try: - yield - except Exception: - self._connection.rollback() - raise - - finally: - if close: - self._connected = False - self._connection.commit() - self._connection.close() - - @requires_connection - @convert_exceptions(_ERROR_ON_CONNECT) - def _get_db_version(self) -> str: - """Terracotta version used to create the database""" - conn = self._connection - db_row = conn.execute('SELECT version from terracotta').fetchone() - return db_row['version'] - - db_version = cast(str, property(_get_db_version)) - - def _connection_callback(self) -> None: - """Called after opening a new connection""" - # check for version compatibility - def versiontuple(version_string: str) -> Sequence[str]: - return version_string.split('.') - - db_version = self.db_version - current_version = __version__ - - if versiontuple(db_version)[:2] != versiontuple(current_version)[:2]: - raise exceptions.InvalidDatabaseError( - f'Version conflict: database was created in 
v{db_version}, ' - f'but this is v{current_version}' - ) - - def _get_key_names(self) -> Tuple[str, ...]: - """Names of all keys defined by the database""" - return tuple(self.get_keys().keys()) - - key_names = cast(Tuple[str], property(_get_key_names)) - - @convert_exceptions('Could not create database') - def create(self, keys: Sequence[str], key_descriptions: Mapping[str, str] = None) -> None: - """Create and initialize database with empty tables. - - This must be called before opening the first connection. Tables must not exist already. - - Arguments: - - keys: Key names to use throughout the Terracotta database. - key_descriptions: Optional (but recommended) full-text description for some keys, - in the form of ``{key_name: description}``. + con_params = cls._parse_connection_string(path) + return os.path.normpath(os.path.realpath(con_params.path)) + def _create_database(self) -> None: + """The database is automatically created by the sqlite driver on connection, + so no need to do anything here """ - if key_descriptions is None: - key_descriptions = {} - else: - key_descriptions = dict(key_descriptions) - - if not all(k in keys for k in key_descriptions.keys()): - raise exceptions.InvalidKeyError('key description dict contains unknown keys') - - if not all(re.match(r'^\w+$', key) for key in keys): - raise exceptions.InvalidKeyError('key names must be alphanumeric') - - if any(key in self._RESERVED_KEYS for key in keys): - raise exceptions.InvalidKeyError(f'key names cannot be one of {self._RESERVED_KEYS!s}') - - for key in keys: - if key not in key_descriptions: - key_descriptions[key] = '' - - with self._connect(check=False): - conn = self._connection - conn.execute('CREATE TABLE terracotta (version VARCHAR[255])') - conn.execute('INSERT INTO terracotta VALUES (?)', [str(__version__)]) - - conn.execute(f'CREATE TABLE keys (key {self._KEY_TYPE}, description VARCHAR[max])') - key_rows = [(key, key_descriptions[key]) for key in keys] - conn.executemany('INSERT 
INTO keys VALUES (?, ?)', key_rows) - - key_string = ', '.join([f'{key} {self._KEY_TYPE}' for key in keys]) - conn.execute(f'CREATE TABLE datasets ({key_string}, filepath VARCHAR[8000], ' - f'PRIMARY KEY({", ".join(keys)}))') - - column_string = ', '.join(f'{col} {col_type}' for col, col_type - in self._METADATA_COLUMNS) - conn.execute(f'CREATE TABLE metadata ({key_string}, {column_string}, ' - f'PRIMARY KEY ({", ".join(keys)}))') - - @requires_connection - @convert_exceptions('Could not retrieve keys from database') - def get_keys(self) -> OrderedDict: - conn = self._connection - key_rows = conn.execute('SELECT * FROM keys') - - out: OrderedDict = OrderedDict() - for row in key_rows: - out[row['key']] = row['description'] - return out - - @trace('get_datasets') - @requires_connection - @convert_exceptions('Could not retrieve datasets') - def get_datasets(self, where: Mapping[str, Union[str, List[str]]] = None, - page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], str]: - conn = self._connection - - if limit is not None: - # explicitly cast to int to prevent SQL injection - page_fragment = f'LIMIT {int(limit)} OFFSET {int(page) * int(limit)}' - else: - page_fragment = '' - - # sort by keys to ensure deterministic results - order_fragment = f'ORDER BY {", ".join(self.key_names)}' - - if where is None: - rows = conn.execute(f'SELECT * FROM datasets {order_fragment} {page_fragment}') - else: - if not all(key in self.key_names for key in where.keys()): - raise exceptions.InvalidKeyError('Encountered unrecognized keys in ' - 'where clause') - conditions = [] - values = [] - for key, value in where.items(): - if isinstance(value, str): - value = [value] - values.extend(value) - conditions.append(' OR '.join([f'{key}=?'] * len(value))) - where_fragment = ' AND '.join([f'({condition})' for condition in conditions]) - rows = conn.execute( - f'SELECT * FROM datasets WHERE {where_fragment} {order_fragment} {page_fragment}', - values - ) - - def keytuple(row: Dict[str, 
Any]) -> Tuple[str, ...]: - return tuple(row[key] for key in self.key_names) - - return {keytuple(row): row['filepath'] for row in rows} - - @staticmethod - def _encode_data(decoded: Mapping[str, Any]) -> Dict[str, Any]: - """Transform from internal format to database representation""" - encoded = { - 'bounds_north': decoded['bounds'][0], - 'bounds_east': decoded['bounds'][1], - 'bounds_south': decoded['bounds'][2], - 'bounds_west': decoded['bounds'][3], - 'convex_hull': json.dumps(decoded['convex_hull']), - 'valid_percentage': decoded['valid_percentage'], - 'min': decoded['range'][0], - 'max': decoded['range'][1], - 'mean': decoded['mean'], - 'stdev': decoded['stdev'], - 'percentiles': np.array(decoded['percentiles'], dtype='float32').tobytes(), - 'metadata': json.dumps(decoded['metadata']) - } - return encoded - - @staticmethod - def _decode_data(encoded: Mapping[str, Any]) -> Dict[str, Any]: - """Transform from database format to internal representation""" - decoded = { - 'bounds': tuple([encoded[f'bounds_{d}'] for d in ('north', 'east', 'south', 'west')]), - 'convex_hull': json.loads(encoded['convex_hull']), - 'valid_percentage': encoded['valid_percentage'], - 'range': (encoded['min'], encoded['max']), - 'mean': encoded['mean'], - 'stdev': encoded['stdev'], - 'percentiles': np.frombuffer(encoded['percentiles'], dtype='float32').tolist(), - 'metadata': json.loads(encoded['metadata']) - } - return decoded - - @trace('get_metadata') - @requires_connection - @convert_exceptions('Could not retrieve metadata') - def get_metadata(self, keys: Union[Sequence[str], Mapping[str, str]]) -> Dict[str, Any]: - keys = tuple(self._key_dict_to_sequence(keys)) - - if len(keys) != len(self.key_names): - raise exceptions.InvalidKeyError( - f'Got wrong number of keys (available keys: {self.key_names})' - ) - - conn = self._connection - - where_string = ' AND '.join([f'{key}=?' 
for key in self.key_names]) - row = conn.execute(f'SELECT * FROM metadata WHERE {where_string}', keys).fetchone() - - if not row: # support lazy loading - filepath = self.get_datasets(dict(zip(self.key_names, keys)), page=0, limit=1) - if not filepath: - raise exceptions.DatasetNotFoundError(f'No dataset found for given keys {keys}') - - # compute metadata and try again - metadata = self.compute_metadata(filepath[keys], max_shape=self.LAZY_LOADING_MAX_SHAPE) - self.insert(keys, filepath[keys], metadata=metadata) - row = conn.execute(f'SELECT * FROM metadata WHERE {where_string}', keys).fetchone() - - assert row - - data_columns, _ = zip(*self._METADATA_COLUMNS) - encoded_data = {col: row[col] for col in self.key_names + data_columns} - return self._decode_data(encoded_data) - - @trace('insert') - @requires_connection - @convert_exceptions('Could not write to database') - def insert(self, - keys: Union[Sequence[str], Mapping[str, str]], - filepath: str, *, - metadata: Mapping[str, Any] = None, - skip_metadata: bool = False, - override_path: str = None) -> None: - conn = self._connection - - if len(keys) != len(self.key_names): - raise exceptions.InvalidKeyError( - f'Got wrong number of keys (available keys: {self.key_names})' - ) - - if override_path is None: - override_path = filepath - - keys = self._key_dict_to_sequence(keys) - template_string = ', '.join(['?'] * (len(keys) + 1)) - conn.execute(f'INSERT OR REPLACE INTO datasets VALUES ({template_string})', - [*keys, override_path]) - - if metadata is None and not skip_metadata: - metadata = self.compute_metadata(filepath) - - if metadata is not None: - encoded_data = self._encode_data(metadata) - row_keys, row_values = zip(*encoded_data.items()) - template_string = ', '.join(['?'] * (len(keys) + len(row_values))) - conn.execute(f'INSERT OR REPLACE INTO metadata ({", ".join(self.key_names)}, ' - f'{", ".join(row_keys)}) VALUES ({template_string})', [*keys, *row_values]) - - @trace('delete') - @requires_connection 
- @convert_exceptions('Could not write to database') - def delete(self, keys: Union[Sequence[str], Mapping[str, str]]) -> None: - conn = self._connection - - if len(keys) != len(self.key_names): - raise exceptions.InvalidKeyError( - f'Got wrong number of keys (available keys: {self.key_names})' - ) - - keys = self._key_dict_to_sequence(keys) - key_dict = dict(zip(self.key_names, keys)) - - if not self.get_datasets(key_dict): - raise exceptions.DatasetNotFoundError(f'No dataset found with keys {keys}') - - where_string = ' AND '.join([f'{key}=?' for key in self.key_names]) - conn.execute(f'DELETE FROM datasets WHERE {where_string}', keys) - conn.execute(f'DELETE FROM metadata WHERE {where_string}', keys) + pass From a5d9a1b82ef1a12ed1c59864273134f7da9c060a Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 17 Dec 2021 13:04:03 +0100 Subject: [PATCH 018/107] Ugly hack to make sqlite remote driver work with new and old structures --- terracotta/drivers/sqlite_remote.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/terracotta/drivers/sqlite_remote.py b/terracotta/drivers/sqlite_remote.py index b4f5de37..924bd85e 100644 --- a/terracotta/drivers/sqlite_remote.py +++ b/terracotta/drivers/sqlite_remote.py @@ -1,4 +1,4 @@ -"""drivers/sqlite.py +"""drivers/sqlite_remote.py SQLite-backed raster driver. Metadata is stored in an SQLite database, raster data is assumed to be present on disk. 
@@ -103,6 +103,7 @@ def __init__(self, remote_path: str) -> None: self._last_updated = -float('inf') super().__init__(local_db_file.name) + self.path = local_db_file.name @classmethod def _normalize_path(cls, path: str) -> str: @@ -129,9 +130,9 @@ def _update_db(self, remote_path: str, local_path: str) -> None: _update_from_s3(remote_path, local_path) self._last_updated = time.time() - def _connection_callback(self) -> None: + def _verify_db_version(self) -> None: self._update_db(self._remote_path, self.path) - super()._connection_callback() + super()._verify_db_version() def create(self, *args: Any, **kwargs: Any) -> None: raise NotImplementedError('Remote SQLite databases are read-only') From e21fdeef468564594c86ce7c1480a785ee5ebd57 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 17 Dec 2021 13:09:38 +0100 Subject: [PATCH 019/107] Require SQLAlchemy --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 2ad995f6..cb0b02ef 100644 --- a/setup.py +++ b/setup.py @@ -71,6 +71,7 @@ 'shapely', 'rasterio>=1.0,<=1.1.8', # TODO: unpin when performance issues with GDAL3 are fixed 'shapely', + 'sqlalchemy', 'toml', 'tqdm' ], From e239d094a051087e3a9ca0cd194fef17f9423443 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 21 Dec 2021 09:51:00 +0100 Subject: [PATCH 020/107] Cleanup driver code and adhere flake8, mypy --- terracotta/drivers/base.py | 17 +++++++++------ terracotta/drivers/common.py | 32 +++++++++++++++++++---------- terracotta/drivers/mysql.py | 1 + terracotta/drivers/sqlite.py | 1 + terracotta/drivers/sqlite_remote.py | 3 +-- 5 files changed, 35 insertions(+), 19 deletions(-) diff --git a/terracotta/drivers/base.py b/terracotta/drivers/base.py index 6e5afd19..2eaf5ca3 100644 --- a/terracotta/drivers/base.py +++ b/terracotta/drivers/base.py @@ -3,24 +3,29 @@ Base class for drivers. 
""" -from typing import Callable, List, Mapping, Any, Tuple, Sequence, Dict, Union, TypeVar +import contextlib +import functools from abc import ABC, abstractmethod from collections import OrderedDict -import functools -import contextlib +from typing import (Any, Callable, Dict, List, Mapping, Sequence, Tuple, + TypeVar, Union) Number = TypeVar('Number', int, float) T = TypeVar('T') -def requires_connection(fun: Callable[..., T] = None, *, verify: bool = True) -> Callable[..., T]: +def requires_connection( + fun: Callable[..., T] = None, *, + verify: bool = True +) -> Union[Callable[..., T], functools.partial]: if fun is None: return functools.partial(requires_connection, verify=verify) @functools.wraps(fun) def inner(self: Driver, *args: Any, **kwargs: Any) -> T: with self.connect(verify=verify): - return fun(self, *args, **kwargs) + # Apparently mypy thinks fun might still be None, hence the ignore: + return fun(self, *args, **kwargs) # type: ignore return inner @@ -57,7 +62,7 @@ def create(self, keys: Sequence[str], *, pass @abstractmethod - def connect(self) -> contextlib.AbstractContextManager: + def connect(self, verify: bool = True) -> contextlib.AbstractContextManager: """Context manager to connect to a given database and clean up on exit. 
This allows you to pool interactions with the database to prevent possibly diff --git a/terracotta/drivers/common.py b/terracotta/drivers/common.py index ae3876ed..f635889a 100644 --- a/terracotta/drivers/common.py +++ b/terracotta/drivers/common.py @@ -5,8 +5,8 @@ import urllib.parse as urlparse from abc import ABC, abstractmethod from collections import OrderedDict -from typing import (Any, Dict, Iterator, List, Mapping, Optional, Sequence, - Tuple, Type, Union) +from typing import (Any, Callable, Dict, Iterator, List, Mapping, Optional, + Sequence, Tuple, Type, Union) import numpy as np import sqlalchemy as sqla @@ -23,18 +23,24 @@ ) -def convert_exceptions(error_message: str): - def decorator(fun): +def convert_exceptions(error_message: str) -> Callable[..., Any]: + def decorator(fun: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(fun) - def inner(self: 'RelationalDriver', *args, **kwargs): - with convert_exceptions_context(error_message, self.DATABASE_DRIVER_EXCEPTIONS_TO_CONVERT): + def inner(self: 'RelationalDriver', *args: Any, **kwargs: Any) -> Any: + with convert_exceptions_context( + error_message, + self.DATABASE_DRIVER_EXCEPTIONS_TO_CONVERT + ): return fun(self, *args, **kwargs) return inner return decorator @contextlib.contextmanager -def convert_exceptions_context(error_message, exceptions_to_convert): +def convert_exceptions_context( + error_message: str, + exceptions_to_convert: Union[Type[Exception], Tuple[Type[Exception], ...]] +) -> Iterator: try: yield except exceptions_to_convert as exception: @@ -45,8 +51,9 @@ class RelationalDriver(RasterDriver, ABC): SQL_DATABASE_SCHEME: str # The database flavour, eg mysql, sqlite, etc SQL_DRIVER_TYPE: str # The actual database driver, eg pymysql, sqlite3, etc SQL_KEY_SIZE: int + SQL_TIMEOUT_KEY: str - DATABASE_DRIVER_EXCEPTIONS_TO_CONVERT: Tuple[Type[Exception]] = ( + DATABASE_DRIVER_EXCEPTIONS_TO_CONVERT: Tuple[Type[Exception], ...] 
= ( sqla.exc.OperationalError, sqla.exc.InternalError, sqla.exc.ProgrammingError, @@ -84,7 +91,7 @@ def __init__(self, path: str) -> None: connection_string, echo=True, future=True, - #connect_args={'timeout': db_connection_timeout} + connect_args={self.SQL_TIMEOUT_KEY: db_connection_timeout} ) self.sqla_metadata = sqla.MetaData() @@ -113,7 +120,7 @@ def _parse_connection_string(cls, connection_string: str) -> urlparse.ParseResul @contextlib.contextmanager def connect(self, verify: bool = True) -> Iterator: if not self.connected: - def _connect_with_exceptions_converted(): + def _connect_with_exceptions_converted() -> Connection: with convert_exceptions_context(_ERROR_ON_CONNECT, sqla.exc.OperationalError): connection = self.sqla_engine.connect() return connection @@ -238,7 +245,10 @@ def _initialize_database( ) self.connection.execute( key_names_table.insert(), - [dict(key_name=key, description=key_descriptions.get(key, ''), index=i) for i, key in enumerate(keys)] + [ + dict(key_name=key, description=key_descriptions.get(key, ''), index=i) + for i, key in enumerate(keys) + ] ) # self.connection.commit() diff --git a/terracotta/drivers/mysql.py b/terracotta/drivers/mysql.py index aac63998..f9ea3721 100644 --- a/terracotta/drivers/mysql.py +++ b/terracotta/drivers/mysql.py @@ -30,6 +30,7 @@ class MySQLDriver(RelationalDriver): SQL_DATABASE_SCHEME = 'mysql' SQL_DRIVER_TYPE = 'pymysql' SQL_KEY_SIZE = 50 + SQL_TIMEOUT_KEY = 'connect_timeout' DEFAULT_PORT = 3306 diff --git a/terracotta/drivers/sqlite.py b/terracotta/drivers/sqlite.py index 39f1387c..8f919ec4 100644 --- a/terracotta/drivers/sqlite.py +++ b/terracotta/drivers/sqlite.py @@ -45,6 +45,7 @@ class SQLiteDriver(RelationalDriver): SQL_DATABASE_SCHEME = 'sqlite' SQL_DRIVER_TYPE = 'pysqlite' SQL_KEY_SIZE = 256 + SQL_TIMEOUT_KEY = 'timeout' def __init__(self, path: Union[str, Path]) -> None: """Initialize the SQLiteDriver. 
diff --git a/terracotta/drivers/sqlite_remote.py b/terracotta/drivers/sqlite_remote.py index 924bd85e..5540f6e2 100644 --- a/terracotta/drivers/sqlite_remote.py +++ b/terracotta/drivers/sqlite_remote.py @@ -23,11 +23,10 @@ @contextlib.contextmanager def convert_exceptions(msg: str) -> Iterator: """Convert internal sqlite and boto exceptions to our InvalidDatabaseError""" - import sqlite3 import botocore.exceptions try: yield - except (sqlite3.OperationalError, botocore.exceptions.ClientError) as exc: + except botocore.exceptions.ClientError as exc: raise exceptions.InvalidDatabaseError(msg) from exc From d26ebc7cbb266adefba899c531884375d3ac7991 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 21 Dec 2021 09:55:05 +0100 Subject: [PATCH 021/107] Rename common to relational_base --- terracotta/drivers/mysql.py | 2 +- terracotta/drivers/{common.py => relational_base.py} | 0 terracotta/drivers/sqlite.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename terracotta/drivers/{common.py => relational_base.py} (100%) diff --git a/terracotta/drivers/mysql.py b/terracotta/drivers/mysql.py index f9ea3721..3f20c03c 100644 --- a/terracotta/drivers/mysql.py +++ b/terracotta/drivers/mysql.py @@ -7,7 +7,7 @@ from urllib.parse import ParseResult import sqlalchemy as sqla -from terracotta.drivers.common import RelationalDriver +from terracotta.drivers.relational_base import RelationalDriver class MySQLDriver(RelationalDriver): diff --git a/terracotta/drivers/common.py b/terracotta/drivers/relational_base.py similarity index 100% rename from terracotta/drivers/common.py rename to terracotta/drivers/relational_base.py diff --git a/terracotta/drivers/sqlite.py b/terracotta/drivers/sqlite.py index 8f919ec4..637a91b3 100644 --- a/terracotta/drivers/sqlite.py +++ b/terracotta/drivers/sqlite.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import Union -from terracotta.drivers.common import RelationalDriver +from terracotta.drivers.relational_base import 
RelationalDriver class SQLiteDriver(RelationalDriver): From f51175bb920424f5d861e8c538a5c9fd98c67eaf Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 21 Dec 2021 10:31:24 +0100 Subject: [PATCH 022/107] Add and use ._local_path field on RemoteSQLiteDriver instead of inherited .path --- terracotta/drivers/sqlite_remote.py | 6 +++--- tests/drivers/test_sqlite_remote.py | 20 +++++++++++--------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/terracotta/drivers/sqlite_remote.py b/terracotta/drivers/sqlite_remote.py index 5540f6e2..d9c5816a 100644 --- a/terracotta/drivers/sqlite_remote.py +++ b/terracotta/drivers/sqlite_remote.py @@ -98,11 +98,11 @@ def __init__(self, remote_path: str) -> None: ) local_db_file.close() + self._local_path = local_db_file.name self._remote_path = str(remote_path) self._last_updated = -float('inf') super().__init__(local_db_file.name) - self.path = local_db_file.name @classmethod def _normalize_path(cls, path: str) -> str: @@ -130,7 +130,7 @@ def _update_db(self, remote_path: str, local_path: str) -> None: self._last_updated = time.time() def _verify_db_version(self) -> None: - self._update_db(self._remote_path, self.path) + self._update_db(self._remote_path, self._local_path) super()._verify_db_version() def create(self, *args: Any, **kwargs: Any) -> None: @@ -144,4 +144,4 @@ def delete(self, *args: Any, **kwargs: Any) -> None: def __del__(self) -> None: """Clean up temporary database upon exit""" - self.__rm(self.path) + self.__rm(self._local_path) diff --git a/tests/drivers/test_sqlite_remote.py b/tests/drivers/test_sqlite_remote.py index 9750bca8..8f6857ae 100644 --- a/tests/drivers/test_sqlite_remote.py +++ b/tests/drivers/test_sqlite_remote.py @@ -11,6 +11,8 @@ import pytest +from terracotta.drivers.sqlite_remote import RemoteSQLiteDriver + boto3 = pytest.importorskip('boto3') moto = pytest.importorskip('moto') @@ -104,33 +106,33 @@ def test_remote_database_cache(s3_db_factory, raster_file, monkeypatch): from 
terracotta import get_driver - driver = get_driver(dbpath) + driver: RemoteSQLiteDriver = get_driver(dbpath) driver._last_updated = -float('inf') with driver.connect(): assert driver.key_names == keys assert driver.get_datasets() == {} - modification_date = os.path.getmtime(driver.path) + modification_date = os.path.getmtime(driver._local_path) s3_db_factory(keys, datasets={('some', 'value'): str(raster_file)}) # no change yet assert driver.get_datasets() == {} - assert os.path.getmtime(driver.path) == modification_date + assert os.path.getmtime(driver._local_path) == modification_date # check if remote db is cached correctly driver._last_updated = time.time() with driver.connect(): # db connection is cached; so still no change assert driver.get_datasets() == {} - assert os.path.getmtime(driver.path) == modification_date + assert os.path.getmtime(driver._local_path) == modification_date # invalidate cache driver._last_updated = -float('inf') with driver.connect(): # now db is updated on reconnect assert list(driver.get_datasets().keys()) == [('some', 'value')] - assert os.path.getmtime(driver.path) != modification_date + assert os.path.getmtime(driver._local_path) != modification_date @moto.mock_s3 @@ -159,15 +161,15 @@ def test_destructor(s3_db_factory, raster_file, capsys): from terracotta import get_driver - driver = get_driver(dbpath) - assert os.path.isfile(driver.path) + driver: RemoteSQLiteDriver = get_driver(dbpath) + assert os.path.isfile(driver._local_path) driver.__del__() - assert not os.path.isfile(driver.path) + assert not os.path.isfile(driver._local_path) captured = capsys.readouterr() assert 'Exception ignored' not in captured.err # re-create file to prevent actual destructor from failing - with open(driver.path, 'w'): + with open(driver._local_path, 'w'): pass From 6d435bc598217d0505e99074726b6679d423e9d2 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 21 Dec 2021 13:16:39 +0100 Subject: [PATCH 023/107] Allow Path input and clean code 
--- terracotta/drivers/sqlite_remote.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/terracotta/drivers/sqlite_remote.py b/terracotta/drivers/sqlite_remote.py index d9c5816a..c6480031 100644 --- a/terracotta/drivers/sqlite_remote.py +++ b/terracotta/drivers/sqlite_remote.py @@ -4,16 +4,17 @@ to be present on disk. """ -from typing import Any, Iterator +import contextlib +import logging import os -import time -import tempfile import shutil -import logging -import contextlib +import tempfile +import time import urllib.parse as urlparse +from pathlib import Path +from typing import Any, Iterator, Union -from terracotta import get_settings, exceptions +from terracotta import exceptions, get_settings from terracotta.drivers.sqlite import SQLiteDriver from terracotta.profile import trace @@ -74,7 +75,7 @@ class RemoteSQLiteDriver(SQLiteDriver): """ - def __init__(self, remote_path: str) -> None: + def __init__(self, remote_path: Union[str, Path]) -> None: """Initialize the RemoteSQLiteDriver. This should not be called directly, use :func:`~terracotta.get_driver` instead. 
From 2b14c8356d14afeff8f60a2f3c2c18efa58dc810 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 21 Dec 2021 13:18:54 +0100 Subject: [PATCH 024/107] Let drivers handle path resolving internally --- terracotta/drivers/__init__.py | 4 ++-- terracotta/drivers/relational_base.py | 13 ++++++++++++- terracotta/drivers/sqlite.py | 11 +++++++---- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/terracotta/drivers/__init__.py b/terracotta/drivers/__init__.py index c8fd7415..9dd66f92 100644 --- a/terracotta/drivers/__init__.py +++ b/terracotta/drivers/__init__.py @@ -76,8 +76,8 @@ def get_driver(url_or_path: URLOrPathType, provider: str = None) -> Driver: if provider is None: # try and auto-detect provider = auto_detect_provider(url_or_path) - if isinstance(url_or_path, Path) or provider == 'sqlite': - url_or_path = str(Path(url_or_path).resolve()) + if isinstance(url_or_path, Path): + url_or_path = str(url_or_path) DriverClass = load_driver(provider) normalized_path = DriverClass._normalize_path(url_or_path) diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index f635889a..87b787f8 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -1,3 +1,8 @@ +"""drivers/relational_base.py + +Base class for relational database drivers, using SQLAlchemy +""" + import contextlib import functools import json @@ -85,7 +90,8 @@ def __init__(self, path: str) -> None: assert self.SQL_DRIVER_TYPE is not None self._CONNECTION_PARAMETERS = self._parse_connection_string(path) cp = self._CONNECTION_PARAMETERS - connection_string = f'{cp.scheme}+{self.SQL_DRIVER_TYPE}://{cp.netloc}{cp.path}' + resolved_path = self._resolve_path(cp.path) + connection_string = f'{cp.scheme}+{self.SQL_DRIVER_TYPE}://{cp.netloc}{resolved_path}' self.sqla_engine = sqla.create_engine( connection_string, @@ -117,6 +123,11 @@ def _parse_connection_string(cls, connection_string: str) -> urlparse.ParseResul return 
con_params + @classmethod + def _resolve_path(cls, path: str) -> str: + # Default to do nothing; may be overriden to actually handle file paths according to OS + return path + @contextlib.contextmanager def connect(self, verify: bool = True) -> Iterator: if not self.connected: diff --git a/terracotta/drivers/sqlite.py b/terracotta/drivers/sqlite.py index 637a91b3..4dfba07c 100644 --- a/terracotta/drivers/sqlite.py +++ b/terracotta/drivers/sqlite.py @@ -57,13 +57,16 @@ def __init__(self, path: Union[str, Path]) -> None: path: File path to target SQLite database (may or may not exist yet) """ - path = str(path) - super().__init__(f'sqlite:///{os.path.realpath(path)}') + super().__init__(str(path)) + + @classmethod + def _resolve_path(cls, path: str) -> str: + full_path = os.path.realpath(Path(path).resolve()) + return f'/{full_path}' @classmethod def _normalize_path(cls, path: str) -> str: - con_params = cls._parse_connection_string(path) - return os.path.normpath(os.path.realpath(con_params.path)) + return os.path.normpath(os.path.realpath(path)) def _create_database(self) -> None: """The database is automatically created by the sqlite driver on connection, From 48f7d2592296c055454352651d38541d8962eb98 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 21 Dec 2021 13:19:16 +0100 Subject: [PATCH 025/107] Add test for parsing of invalid schemes --- tests/drivers/test_drivers.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/drivers/test_drivers.py b/tests/drivers/test_drivers.py index 29918702..a333066e 100644 --- a/tests/drivers/test_drivers.py +++ b/tests/drivers/test_drivers.py @@ -67,6 +67,22 @@ def test_normalize_url(provider): assert driver._normalize_path(p) == first_path +@pytest.mark.parametrize('provider', TESTABLE_DRIVERS) +def test_parse_connection_string_with_invalid_schemes(provider): + from terracotta import drivers + + invalid_schemes = ( + 'fakescheme://test.example.com/foo', + 'fakescheme://test.example.com:80/foo', 
+ ) + + for invalid_scheme in invalid_schemes: + with pytest.raises(ValueError) as exc: + driver = drivers.get_driver(invalid_scheme, provider) + print(type(driver)) + assert 'unsupported URL scheme' in str(exc.value) + + def test_get_driver_invalid(): from terracotta import drivers with pytest.raises(ValueError) as exc: From 98ebb7d4ca25ba2e8625027f82701b052898e2e2 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 21 Dec 2021 14:13:15 +0100 Subject: [PATCH 026/107] Only test invalid scheme for mysql --- tests/drivers/test_drivers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/drivers/test_drivers.py b/tests/drivers/test_drivers.py index a333066e..0e71c154 100644 --- a/tests/drivers/test_drivers.py +++ b/tests/drivers/test_drivers.py @@ -67,7 +67,7 @@ def test_normalize_url(provider): assert driver._normalize_path(p) == first_path -@pytest.mark.parametrize('provider', TESTABLE_DRIVERS) +@pytest.mark.parametrize('provider', ['mysql']) def test_parse_connection_string_with_invalid_schemes(provider): from terracotta import drivers From f5bc5f9a17ca0e932d12435fc78203c336ff60b8 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 21 Dec 2021 14:14:03 +0100 Subject: [PATCH 027/107] Handle paths such that they hopefully work on Windows as well --- terracotta/drivers/relational_base.py | 10 ++++++++-- terracotta/drivers/sqlite.py | 7 ++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index 87b787f8..e9993299 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -58,6 +58,8 @@ class RelationalDriver(RasterDriver, ABC): SQL_KEY_SIZE: int SQL_TIMEOUT_KEY: str + FILE_BASED_DATABASE: bool = False + DATABASE_DRIVER_EXCEPTIONS_TO_CONVERT: Tuple[Type[Exception], ...] 
= ( sqla.exc.OperationalError, sqla.exc.InternalError, @@ -90,8 +92,8 @@ def __init__(self, path: str) -> None: assert self.SQL_DRIVER_TYPE is not None self._CONNECTION_PARAMETERS = self._parse_connection_string(path) cp = self._CONNECTION_PARAMETERS - resolved_path = self._resolve_path(cp.path) - connection_string = f'{cp.scheme}+{self.SQL_DRIVER_TYPE}://{cp.netloc}{resolved_path}' + resolved_path = self._resolve_path(cp.path[1:]) # Remove the leading '/' + connection_string = f'{cp.scheme}+{self.SQL_DRIVER_TYPE}://{cp.netloc}/{resolved_path}' self.sqla_engine = sqla.create_engine( connection_string, @@ -115,6 +117,10 @@ def __init__(self, path: str) -> None: def _parse_connection_string(cls, connection_string: str) -> urlparse.ParseResult: con_params = urlparse.urlparse(connection_string) + if con_params.scheme == 'file' and cls.FILE_BASED_DATABASE: + file_connection_string = connection_string[len('file://'):] + con_params = urlparse.urlparse(f'{cls.SQL_DATABASE_SCHEME}://{file_connection_string}') + if not con_params.scheme: con_params = urlparse.urlparse(f'{cls.SQL_DATABASE_SCHEME}://{connection_string}') diff --git a/terracotta/drivers/sqlite.py b/terracotta/drivers/sqlite.py index 4dfba07c..d0ff01cc 100644 --- a/terracotta/drivers/sqlite.py +++ b/terracotta/drivers/sqlite.py @@ -47,6 +47,8 @@ class SQLiteDriver(RelationalDriver): SQL_KEY_SIZE = 256 SQL_TIMEOUT_KEY = 'timeout' + FILE_BASED_DATABASE = True + def __init__(self, path: Union[str, Path]) -> None: """Initialize the SQLiteDriver. 
@@ -57,12 +59,11 @@ def __init__(self, path: Union[str, Path]) -> None: path: File path to target SQLite database (may or may not exist yet) """ - super().__init__(str(path)) + super().__init__(f'file:///{path}') @classmethod def _resolve_path(cls, path: str) -> str: - full_path = os.path.realpath(Path(path).resolve()) - return f'/{full_path}' + return os.path.realpath(Path(path).resolve()) @classmethod def _normalize_path(cls, path: str) -> str: From ddda1bc555c0ac1d4429eea99e00096a05b9cda8 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 4 Jan 2022 11:23:30 +0100 Subject: [PATCH 028/107] Cache drivers within each process only --- terracotta/drivers/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/terracotta/drivers/__init__.py b/terracotta/drivers/__init__.py index 9dd66f92..3470e40c 100644 --- a/terracotta/drivers/__init__.py +++ b/terracotta/drivers/__init__.py @@ -3,6 +3,7 @@ Define an interface to retrieve Terracotta drivers. """ +import os from typing import Union, Tuple, Dict, Type import urllib.parse as urlparse from pathlib import Path @@ -81,7 +82,7 @@ def get_driver(url_or_path: URLOrPathType, provider: str = None) -> Driver: DriverClass = load_driver(provider) normalized_path = DriverClass._normalize_path(url_or_path) - cache_key = (normalized_path, provider) + cache_key = (normalized_path, provider, os.getpid()) if cache_key not in _DRIVER_CACHE: _DRIVER_CACHE[cache_key] = DriverClass(url_or_path) From faf8f64942e29df6558b8fa8d71b21651d938974 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 4 Jan 2022 11:24:16 +0100 Subject: [PATCH 029/107] Set connection transaction isolation level to READ UNCOMMITTED to enable proper upserts --- terracotta/drivers/relational_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index e9993299..3d637879 100644 --- a/terracotta/drivers/relational_base.py +++ 
b/terracotta/drivers/relational_base.py @@ -139,7 +139,9 @@ def connect(self, verify: bool = True) -> Iterator: if not self.connected: def _connect_with_exceptions_converted() -> Connection: with convert_exceptions_context(_ERROR_ON_CONNECT, sqla.exc.OperationalError): - connection = self.sqla_engine.connect() + connection = self.sqla_engine.connect().execution_options( + isolation_level='READ UNCOMMITTED' + ) return connection try: with _connect_with_exceptions_converted() as connection: From 4a4657c519ceb7007d935c5647a1c61784a71eda Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 4 Jan 2022 11:33:39 +0100 Subject: [PATCH 030/107] Satisfy mypy --- terracotta/drivers/__init__.py | 2 +- terracotta/drivers/sqlite_remote.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/terracotta/drivers/__init__.py b/terracotta/drivers/__init__.py index 3470e40c..19631f28 100644 --- a/terracotta/drivers/__init__.py +++ b/terracotta/drivers/__init__.py @@ -42,7 +42,7 @@ def auto_detect_provider(url_or_path: Union[str, Path]) -> str: return 'sqlite' -_DRIVER_CACHE: Dict[Tuple[URLOrPathType, str], Driver] = {} +_DRIVER_CACHE: Dict[Tuple[URLOrPathType, str, int], Driver] = {} def get_driver(url_or_path: URLOrPathType, provider: str = None) -> Driver: diff --git a/terracotta/drivers/sqlite_remote.py b/terracotta/drivers/sqlite_remote.py index c6480031..04271cb7 100644 --- a/terracotta/drivers/sqlite_remote.py +++ b/terracotta/drivers/sqlite_remote.py @@ -137,10 +137,10 @@ def _verify_db_version(self) -> None: def create(self, *args: Any, **kwargs: Any) -> None: raise NotImplementedError('Remote SQLite databases are read-only') - def insert(self, *args: Any, **kwargs: Any) -> None: + def insert(self, *args: Any, **kwargs: Any) -> None: # type: ignore raise NotImplementedError('Remote SQLite databases are read-only') - def delete(self, *args: Any, **kwargs: Any) -> None: + def delete(self, *args: Any, **kwargs: Any) -> None: # type: ignore raise 
NotImplementedError('Remote SQLite databases are read-only') def __del__(self) -> None: From 2a2c57f0ad7909855fbcd029df41e63c6c839005 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 4 Jan 2022 12:34:03 +0100 Subject: [PATCH 031/107] Cleanup code --- terracotta/drivers/relational_base.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index 3d637879..c13d7658 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -211,9 +211,8 @@ def create(self, keys: Sequence[str], key_descriptions: Mapping[str, str] = None @abstractmethod def _create_database(self) -> None: - # This might be made abstract, for each subclass to implement specifically # Note that some subclasses may not actually create any database here, as - # it may already exist for some vendors + # it may be created automatically on connection for some database vendors pass @requires_connection(verify=False) @@ -257,7 +256,6 @@ def _initialize_database( *[sqla.Column(name, column_type()) for name, column_type in self._METADATA_COLUMNS] ) self.sqla_metadata.create_all(self.sqla_engine) - # self.connection.commit() self.connection.execute( terracotta_table.insert().values(version=terracotta.__version__) @@ -269,10 +267,6 @@ def _initialize_database( for i, key in enumerate(keys) ] ) - # self.connection.commit() - - # invalidate key cache # TODO: Is that actually necessary? 
- self._db_keys = None @requires_connection @convert_exceptions('Could not retrieve keys from database') @@ -428,8 +422,6 @@ def insert( metadata_table.insert().values(**key_dict, **encoded_data) ) - # self.connection.commit() - @trace('delete') @requires_connection @convert_exceptions('Could not write to database') @@ -458,7 +450,6 @@ def delete(self, keys: Union[Sequence[str], Mapping[str, str]]) -> None: .delete() .where(*[metadata_table.c.get(column) == value for column, value in key_dict.items()]) ) - # self.connection.commit() @staticmethod def _encode_data(decoded: Mapping[str, Any]) -> Dict[str, Any]: From 0aa63c4ff23cd471a35551fbf0b87177acb8f654 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 7 Jan 2022 14:52:42 +0100 Subject: [PATCH 032/107] Fix failing url_parse test --- terracotta/drivers/relational_base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index c13d7658..0abc1eb3 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -115,6 +115,9 @@ def __init__(self, path: str) -> None: @classmethod def _parse_connection_string(cls, connection_string: str) -> urlparse.ParseResult: + if "//" not in connection_string: + connection_string = f"//{connection_string}" + con_params = urlparse.urlparse(connection_string) if con_params.scheme == 'file' and cls.FILE_BASED_DATABASE: @@ -122,7 +125,7 @@ def _parse_connection_string(cls, connection_string: str) -> urlparse.ParseResul con_params = urlparse.urlparse(f'{cls.SQL_DATABASE_SCHEME}://{file_connection_string}') if not con_params.scheme: - con_params = urlparse.urlparse(f'{cls.SQL_DATABASE_SCHEME}://{connection_string}') + con_params = urlparse.urlparse(f'{cls.SQL_DATABASE_SCHEME}:{connection_string}') if con_params.scheme != cls.SQL_DATABASE_SCHEME: raise ValueError(f'unsupported URL scheme "{con_params.scheme}"') From 
3a8144dbaa42ead226b22f7eb9eb1d391fa234aa Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 7 Jan 2022 14:54:18 +0100 Subject: [PATCH 033/107] Use docstring instead of comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Dion Häfner --- terracotta/drivers/base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/terracotta/drivers/base.py b/terracotta/drivers/base.py index 2eaf5ca3..5756e525 100644 --- a/terracotta/drivers/base.py +++ b/terracotta/drivers/base.py @@ -39,12 +39,14 @@ class Driver(ABC): @property @abstractmethod def db_version(self) -> str: - ... # Terracotta version used to create the database + """Terracotta version used to create the database.""" + pass @property @abstractmethod def key_names(self) -> Tuple[str, ...]: - ... # Names of all keys defined by the database + """Names of all keys defined by the database.""" + pass @abstractmethod def __init__(self, url_or_path: str) -> None: From d71d2d609e84b924040d36b0dfe1734663ed72a8 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 7 Jan 2022 15:00:08 +0100 Subject: [PATCH 034/107] Describe the verify argument in connect method --- terracotta/drivers/base.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/terracotta/drivers/base.py b/terracotta/drivers/base.py index 5756e525..97eeef6a 100644 --- a/terracotta/drivers/base.py +++ b/terracotta/drivers/base.py @@ -70,6 +70,12 @@ def connect(self, verify: bool = True) -> contextlib.AbstractContextManager: This allows you to pool interactions with the database to prevent possibly expensive reconnects, or to roll back several interactions if one of them fails. + Arguments: + + verify: Whether to verify the database (primarily its version) when connecting. + Should be `true` unless absolutely necessary, such as when instantiating the + database during creation of it. + Note: Make sure to call :meth:`create` on a fresh database before using this method. 
From 8b95b2090db2f8723dd9253c46f592fe3101cc61 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 7 Jan 2022 15:14:11 +0100 Subject: [PATCH 035/107] Rename SQL_DATABASE_SCHEME to SQL_URL_SCHEME --- terracotta/drivers/mysql.py | 2 +- terracotta/drivers/relational_base.py | 8 ++++---- terracotta/drivers/sqlite.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/terracotta/drivers/mysql.py b/terracotta/drivers/mysql.py index 3f20c03c..783d7bd0 100644 --- a/terracotta/drivers/mysql.py +++ b/terracotta/drivers/mysql.py @@ -27,7 +27,7 @@ class MySQLDriver(RelationalDriver): This driver caches raster data and key names, but not metadata. """ - SQL_DATABASE_SCHEME = 'mysql' + SQL_URL_SCHEME = 'mysql' SQL_DRIVER_TYPE = 'pymysql' SQL_KEY_SIZE = 50 SQL_TIMEOUT_KEY = 'connect_timeout' diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index 0abc1eb3..63dedf2b 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -53,7 +53,7 @@ def convert_exceptions_context( class RelationalDriver(RasterDriver, ABC): - SQL_DATABASE_SCHEME: str # The database flavour, eg mysql, sqlite, etc + SQL_URL_SCHEME: str # The database flavour, eg mysql, sqlite, etc SQL_DRIVER_TYPE: str # The actual database driver, eg pymysql, sqlite3, etc SQL_KEY_SIZE: int SQL_TIMEOUT_KEY: str @@ -122,12 +122,12 @@ def _parse_connection_string(cls, connection_string: str) -> urlparse.ParseResul if con_params.scheme == 'file' and cls.FILE_BASED_DATABASE: file_connection_string = connection_string[len('file://'):] - con_params = urlparse.urlparse(f'{cls.SQL_DATABASE_SCHEME}://{file_connection_string}') + con_params = urlparse.urlparse(f'{cls.SQL_URL_SCHEME}://{file_connection_string}') if not con_params.scheme: - con_params = urlparse.urlparse(f'{cls.SQL_DATABASE_SCHEME}:{connection_string}') + con_params = urlparse.urlparse(f'{cls.SQL_URL_SCHEME}:{connection_string}') - if con_params.scheme != 
cls.SQL_DATABASE_SCHEME: + if con_params.scheme != cls.SQL_URL_SCHEME: raise ValueError(f'unsupported URL scheme "{con_params.scheme}"') return con_params diff --git a/terracotta/drivers/sqlite.py b/terracotta/drivers/sqlite.py index d0ff01cc..16f84402 100644 --- a/terracotta/drivers/sqlite.py +++ b/terracotta/drivers/sqlite.py @@ -42,7 +42,7 @@ class SQLiteDriver(RelationalDriver): outside the main thread. """ - SQL_DATABASE_SCHEME = 'sqlite' + SQL_URL_SCHEME = 'sqlite' SQL_DRIVER_TYPE = 'pysqlite' SQL_KEY_SIZE = 256 SQL_TIMEOUT_KEY = 'timeout' From 38d78802660c385000d8daa42c1d3bb3fbe6524c Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 7 Jan 2022 15:15:05 +0100 Subject: [PATCH 036/107] Don't echo anymore --- terracotta/drivers/relational_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index 63dedf2b..4b7d915b 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -97,7 +97,7 @@ def __init__(self, path: str) -> None: self.sqla_engine = sqla.create_engine( connection_string, - echo=True, + echo=False, future=True, connect_args={self.SQL_TIMEOUT_KEY: db_connection_timeout} ) From 5078bc36ae2b6825ed3a754394229fd8710473a5 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 7 Jan 2022 15:15:51 +0100 Subject: [PATCH 037/107] Describe assertion check better MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Dion Häfner --- terracotta/drivers/mysql.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/terracotta/drivers/mysql.py b/terracotta/drivers/mysql.py index 783d7bd0..00ec392e 100644 --- a/terracotta/drivers/mysql.py +++ b/terracotta/drivers/mysql.py @@ -46,7 +46,8 @@ def __init__(self, mysql_path: str) -> None: """ super().__init__(mysql_path) - self._parse_db_name(self._CONNECTION_PARAMETERS) # To enforce path is parsable + # raises 
an exception if path is invalid + self._parse_db_name(self._CONNECTION_PARAMETERS) @classmethod def _normalize_path(cls, path: str) -> str: From c4bb17c090a13d099ca7b7139e69af34a46ec4e1 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 7 Jan 2022 15:34:29 +0100 Subject: [PATCH 038/107] Restore mysql primary key size calculation --- terracotta/drivers/mysql.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/terracotta/drivers/mysql.py b/terracotta/drivers/mysql.py index 00ec392e..4b3b456c 100644 --- a/terracotta/drivers/mysql.py +++ b/terracotta/drivers/mysql.py @@ -4,6 +4,7 @@ to be present on disk. """ +from typing import Mapping, Sequence from urllib.parse import ParseResult import sqlalchemy as sqla @@ -29,9 +30,9 @@ class MySQLDriver(RelationalDriver): """ SQL_URL_SCHEME = 'mysql' SQL_DRIVER_TYPE = 'pymysql' - SQL_KEY_SIZE = 50 SQL_TIMEOUT_KEY = 'connect_timeout' + MAX_PRIMARY_KEY_SIZE = 767 // 4 DEFAULT_PORT = 3306 def __init__(self, mysql_path: str) -> None: @@ -79,3 +80,11 @@ def _create_database(self) -> None: db_name = self._parse_db_name(self._CONNECTION_PARAMETERS) connection.execute(sqla.text(f'CREATE DATABASE {db_name}')) connection.commit() + + def _initialize_database( + self, + keys: Sequence[str], + key_descriptions: Mapping[str, str] = None + ) -> None: + self.SQL_KEY_SIZE = self.MAX_PRIMARY_KEY_SIZE // len(keys) + super()._initialize_database(keys, key_descriptions) From f391f861e7b5c5d91943e88cb72149b23af2ba31 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 7 Jan 2022 15:35:45 +0100 Subject: [PATCH 039/107] Describe the max primary key length --- terracotta/drivers/mysql.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/terracotta/drivers/mysql.py b/terracotta/drivers/mysql.py index 4b3b456c..85f689fa 100644 --- a/terracotta/drivers/mysql.py +++ b/terracotta/drivers/mysql.py @@ -32,7 +32,7 @@ class MySQLDriver(RelationalDriver): SQL_DRIVER_TYPE = 'pymysql' SQL_TIMEOUT_KEY = 
'connect_timeout' - MAX_PRIMARY_KEY_SIZE = 767 // 4 + MAX_PRIMARY_KEY_SIZE = 767 // 4 # Max key length for MySQL is at least 767B DEFAULT_PORT = 3306 def __init__(self, mysql_path: str) -> None: @@ -86,5 +86,6 @@ def _initialize_database( keys: Sequence[str], key_descriptions: Mapping[str, str] = None ) -> None: + # total primary key length has an upper limit in MySQL self.SQL_KEY_SIZE = self.MAX_PRIMARY_KEY_SIZE // len(keys) super()._initialize_database(keys, key_descriptions) From bcd9a521f434a067829e9b19dc23f0ff9bfe39d7 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 7 Jan 2022 16:11:40 +0100 Subject: [PATCH 040/107] Improve exception convertions --- terracotta/drivers/relational_base.py | 35 +++++++++------------------ terracotta/drivers/sqlite_remote.py | 6 ++--- 2 files changed, 15 insertions(+), 26 deletions(-) diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index 4b7d915b..30b81f80 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -10,8 +10,8 @@ import urllib.parse as urlparse from abc import ABC, abstractmethod from collections import OrderedDict -from typing import (Any, Callable, Dict, Iterator, List, Mapping, Optional, - Sequence, Tuple, Type, Union) +from typing import (Any, Dict, Iterator, List, Mapping, Optional, Sequence, + Tuple, Type, Union) import numpy as np import sqlalchemy as sqla @@ -27,24 +27,20 @@ 'to a valid Terracotta database, and that you ran driver.create().' ) +DATABASE_DRIVER_EXCEPTIONS_TO_CONVERT: Tuple[Type[Exception], ...] 
= ( + sqla.exc.OperationalError, + sqla.exc.InternalError, + sqla.exc.ProgrammingError, + sqla.exc.InvalidRequestError, +) -def convert_exceptions(error_message: str) -> Callable[..., Any]: - def decorator(fun: Callable[..., Any]) -> Callable[..., Any]: - @functools.wraps(fun) - def inner(self: 'RelationalDriver', *args: Any, **kwargs: Any) -> Any: - with convert_exceptions_context( - error_message, - self.DATABASE_DRIVER_EXCEPTIONS_TO_CONVERT - ): - return fun(self, *args, **kwargs) - return inner - return decorator +ExceptionType = Union[Type[Exception], Tuple[Type[Exception], ...]] @contextlib.contextmanager -def convert_exceptions_context( +def convert_exceptions( error_message: str, - exceptions_to_convert: Union[Type[Exception], Tuple[Type[Exception], ...]] + exceptions_to_convert: ExceptionType = DATABASE_DRIVER_EXCEPTIONS_TO_CONVERT, ) -> Iterator: try: yield @@ -60,13 +56,6 @@ class RelationalDriver(RasterDriver, ABC): FILE_BASED_DATABASE: bool = False - DATABASE_DRIVER_EXCEPTIONS_TO_CONVERT: Tuple[Type[Exception], ...] 
= ( - sqla.exc.OperationalError, - sqla.exc.InternalError, - sqla.exc.ProgrammingError, - sqla.exc.InvalidRequestError, - ) - SQLA_REAL = functools.partial(sqla.types.Float, precision=8) SQLA_TEXT = sqla.types.Text SQLA_BLOB = sqla.types.LargeBinary @@ -141,7 +130,7 @@ def _resolve_path(cls, path: str) -> str: def connect(self, verify: bool = True) -> Iterator: if not self.connected: def _connect_with_exceptions_converted() -> Connection: - with convert_exceptions_context(_ERROR_ON_CONNECT, sqla.exc.OperationalError): + with convert_exceptions(_ERROR_ON_CONNECT, sqla.exc.OperationalError): connection = self.sqla_engine.connect().execution_options( isolation_level='READ UNCOMMITTED' ) diff --git a/terracotta/drivers/sqlite_remote.py b/terracotta/drivers/sqlite_remote.py index 04271cb7..d1a4e08b 100644 --- a/terracotta/drivers/sqlite_remote.py +++ b/terracotta/drivers/sqlite_remote.py @@ -23,7 +23,7 @@ @contextlib.contextmanager def convert_exceptions(msg: str) -> Iterator: - """Convert internal sqlite and boto exceptions to our InvalidDatabaseError""" + """Convert internal boto exceptions to our InvalidDatabaseError""" import botocore.exceptions try: yield @@ -137,10 +137,10 @@ def _verify_db_version(self) -> None: def create(self, *args: Any, **kwargs: Any) -> None: raise NotImplementedError('Remote SQLite databases are read-only') - def insert(self, *args: Any, **kwargs: Any) -> None: # type: ignore + def insert(self, *args: Any, **kwargs: Any) -> None: raise NotImplementedError('Remote SQLite databases are read-only') - def delete(self, *args: Any, **kwargs: Any) -> None: # type: ignore + def delete(self, *args: Any, **kwargs: Any) -> None: raise NotImplementedError('Remote SQLite databases are read-only') def __del__(self) -> None: From 0d104c4b38ce90943a646fadffd7438705e47d6b Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 7 Jan 2022 16:47:01 +0100 Subject: [PATCH 041/107] Use cleaner exception handling in connect() --- 
terracotta/drivers/relational_base.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index 30b81f80..c78bd154 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -128,15 +128,13 @@ def _resolve_path(cls, path: str) -> str: @contextlib.contextmanager def connect(self, verify: bool = True) -> Iterator: + @convert_exceptions(_ERROR_ON_CONNECT, sqla.exc.OperationalError) + def get_connection() -> Connection: + return self.sqla_engine.connect().execution_options(isolation_level='READ UNCOMMITTED') + if not self.connected: - def _connect_with_exceptions_converted() -> Connection: - with convert_exceptions(_ERROR_ON_CONNECT, sqla.exc.OperationalError): - connection = self.sqla_engine.connect().execution_options( - isolation_level='READ UNCOMMITTED' - ) - return connection try: - with _connect_with_exceptions_converted() as connection: + with get_connection() as connection: self.connection = connection self.connected = True if verify: From 813ebc73fc8a23a6b2740ced506b6335a6812592 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 7 Jan 2022 16:51:20 +0100 Subject: [PATCH 042/107] Undo renaming of _connection_callback --- terracotta/drivers/relational_base.py | 4 ++-- terracotta/drivers/sqlite_remote.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index c78bd154..e04ef4ab 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -138,7 +138,7 @@ def get_connection() -> Connection: self.connection = connection self.connected = True if verify: - self._verify_db_version() + self._connection_callback() yield self.connection.commit() @@ -152,7 +152,7 @@ def get_connection() -> Connection: self.connection.rollback() raise exception - def _verify_db_version(self) -> None: + def 
_connection_callback(self) -> None: if not self.db_version_verified: # check for version compatibility def version_tuple(version_string: str) -> Sequence[str]: diff --git a/terracotta/drivers/sqlite_remote.py b/terracotta/drivers/sqlite_remote.py index d1a4e08b..e680673e 100644 --- a/terracotta/drivers/sqlite_remote.py +++ b/terracotta/drivers/sqlite_remote.py @@ -130,9 +130,9 @@ def _update_db(self, remote_path: str, local_path: str) -> None: _update_from_s3(remote_path, local_path) self._last_updated = time.time() - def _verify_db_version(self) -> None: + def _connection_callback(self) -> None: self._update_db(self._remote_path, self._local_path) - super()._verify_db_version() + super()._connection_callback() def create(self, *args: Any, **kwargs: Any) -> None: raise NotImplementedError('Remote SQLite databases are read-only') From 1b6a7f53f19b1a46dc812352b9daaeb51eee8d6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Fri, 7 Jan 2022 20:57:56 +0100 Subject: [PATCH 043/107] fix mypy error --- terracotta/cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terracotta/cache.py b/terracotta/cache.py index b7e6cab1..9072dbe7 100644 --- a/terracotta/cache.py +++ b/terracotta/cache.py @@ -35,7 +35,7 @@ def __setitem__(self, key: Any, def _compress_ma(arr: np.ma.MaskedArray, compression_level: int) -> CompressionTuple: compressed_data = zlib.compress(arr.data, compression_level) mask_to_int = np.packbits(arr.mask.astype(np.uint8)) - compressed_mask = zlib.compress(mask_to_int, compression_level) + compressed_mask = zlib.compress(mask_to_int.data, compression_level) out = ( compressed_data, compressed_mask, From 25c594b0d3e1d1738700bb48ae63f562b3641d84 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 11 Jan 2022 13:31:34 +0100 Subject: [PATCH 044/107] Reimplement lazy loading max shape --- terracotta/drivers/relational_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git 
a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index e04ef4ab..a4ff1b89 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -77,6 +77,7 @@ class RelationalDriver(RasterDriver, ABC): def __init__(self, path: str) -> None: settings = terracotta.get_settings() db_connection_timeout: int = settings.DB_CONNECTION_TIMEOUT + self.LAZY_LOADING_MAX_SHAPE: Tuple[int, int] = settings.LAZY_LOADING_MAX_SHAPE assert self.SQL_DRIVER_TYPE is not None self._CONNECTION_PARAMETERS = self._parse_connection_string(path) @@ -353,7 +354,8 @@ def get_metadata(self, keys: Union[Sequence[str], Mapping[str, str]]) -> Dict[st assert len(filepath) == 1 # compute metadata and try again - self.insert(keys, filepath[keys], skip_metadata=False) + metadata = self.compute_metadata(filepath[keys], max_shape=self.LAZY_LOADING_MAX_SHAPE) + self.insert(keys, filepath[keys], metadata=metadata) row = self.connection.execute(stmt).first() assert row From e850659c5c23fe6a9bad8fa20e2be54097fd3b48 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 11 Jan 2022 14:06:51 +0100 Subject: [PATCH 045/107] Don't echo anywhere --- terracotta/drivers/mysql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terracotta/drivers/mysql.py b/terracotta/drivers/mysql.py index 85f689fa..c130b188 100644 --- a/terracotta/drivers/mysql.py +++ b/terracotta/drivers/mysql.py @@ -73,7 +73,7 @@ def _create_database(self) -> None: engine = sqla.create_engine( f'{self._CONNECTION_PARAMETERS.scheme}+{self.SQL_DRIVER_TYPE}://' f'{self._CONNECTION_PARAMETERS.netloc}', - echo=True, + echo=False, future=True ) with engine.connect() as connection: From 2f7eff287c53b4c437e83cf33185ecaf4f027342 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 11 Jan 2022 15:04:53 +0100 Subject: [PATCH 046/107] Explicitly define mysql tables with specific charset --- terracotta/drivers/mysql.py | 6 ++++ terracotta/drivers/relational_base.py | 43 
++++++++++++++------------- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/terracotta/drivers/mysql.py b/terracotta/drivers/mysql.py index c130b188..5818e9f9 100644 --- a/terracotta/drivers/mysql.py +++ b/terracotta/drivers/mysql.py @@ -4,10 +4,12 @@ to be present on disk. """ +import functools from typing import Mapping, Sequence from urllib.parse import ParseResult import sqlalchemy as sqla +from sqlalchemy.dialects.mysql import VARCHAR, TEXT from terracotta.drivers.relational_base import RelationalDriver @@ -32,6 +34,10 @@ class MySQLDriver(RelationalDriver): SQL_DRIVER_TYPE = 'pymysql' SQL_TIMEOUT_KEY = 'connect_timeout' + _CHARSET = 'utf8mb4' + SQLA_STRING = functools.partial(VARCHAR, charset=_CHARSET) + SQLA_TEXT = functools.partial(TEXT, charset=_CHARSET) + MAX_PRIMARY_KEY_SIZE = 767 // 4 # Max key length for MySQL is at least 767B DEFAULT_PORT = 3306 diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index a4ff1b89..a784ada6 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -56,25 +56,28 @@ class RelationalDriver(RasterDriver, ABC): FILE_BASED_DATABASE: bool = False + SQLA_STRING = sqla.types.String SQLA_REAL = functools.partial(sqla.types.Float, precision=8) SQLA_TEXT = sqla.types.Text SQLA_BLOB = sqla.types.LargeBinary - _METADATA_COLUMNS: Tuple[Tuple[str, sqla.types.TypeEngine], ...] = ( - ('bounds_north', SQLA_REAL), - ('bounds_east', SQLA_REAL), - ('bounds_south', SQLA_REAL), - ('bounds_west', SQLA_REAL), - ('convex_hull', SQLA_TEXT), - ('valid_percentage', SQLA_REAL), - ('min', SQLA_REAL), - ('max', SQLA_REAL), - ('mean', SQLA_REAL), - ('stdev', SQLA_REAL), - ('percentiles', SQLA_BLOB), - ('metadata', SQLA_TEXT) - ) def __init__(self, path: str) -> None: + # it is sadly necessary to define this in here, in order to let subclasses redefine types + self._METADATA_COLUMNS: Tuple[Tuple[str, sqla.types.TypeEngine], ...] 
= ( + ('bounds_north', self.SQLA_REAL), + ('bounds_east', self.SQLA_REAL), + ('bounds_south', self.SQLA_REAL), + ('bounds_west', self.SQLA_REAL), + ('convex_hull', self.SQLA_TEXT), + ('valid_percentage', self.SQLA_REAL), + ('min', self.SQLA_REAL), + ('max', self.SQLA_REAL), + ('mean', self.SQLA_REAL), + ('stdev', self.SQLA_REAL), + ('percentiles', self.SQLA_BLOB), + ('metadata', self.SQLA_TEXT) + ) + settings = terracotta.get_settings() db_connection_timeout: int = settings.DB_CONNECTION_TIMEOUT self.LAZY_LOADING_MAX_SHAPE: Tuple[int, int] = settings.LAZY_LOADING_MAX_SHAPE @@ -228,22 +231,22 @@ def _initialize_database( terracotta_table = sqla.Table( 'terracotta', self.sqla_metadata, - sqla.Column('version', sqla.types.String(255), primary_key=True) + sqla.Column('version', self.SQLA_STRING(255), primary_key=True) ) key_names_table = sqla.Table( 'key_names', self.sqla_metadata, - sqla.Column('key_name', sqla.types.String(self.SQL_KEY_SIZE), primary_key=True), - sqla.Column('description', sqla.types.String(8000)), + sqla.Column('key_name', self.SQLA_STRING(self.SQL_KEY_SIZE), primary_key=True), + sqla.Column('description', self.SQLA_STRING(8000)), sqla.Column('index', sqla.types.Integer, unique=True) ) _ = sqla.Table( 'datasets', self.sqla_metadata, - *[sqla.Column(key, sqla.types.String(self.SQL_KEY_SIZE), primary_key=True) for key in keys], # noqa: E501 - sqla.Column('filepath', sqla.types.String(8000)) + *[sqla.Column(key, self.SQLA_STRING(self.SQL_KEY_SIZE), primary_key=True) for key in keys], # noqa: E501 + sqla.Column('filepath', self.SQLA_STRING(8000)) ) _ = sqla.Table( 'metadata', self.sqla_metadata, - *[sqla.Column(key, sqla.types.String(self.SQL_KEY_SIZE), primary_key=True) for key in keys], # noqa: E501 + *[sqla.Column(key, self.SQLA_STRING(self.SQL_KEY_SIZE), primary_key=True) for key in keys], # noqa: E501 *[sqla.Column(name, column_type()) for name, column_type in self._METADATA_COLUMNS] ) self.sqla_metadata.create_all(self.sqla_engine) From 
a634b398f60af437dc0efa8fe8316395b3bb2db6 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 11 Jan 2022 15:22:06 +0100 Subject: [PATCH 047/107] Just use Pathlib resolve method --- terracotta/drivers/sqlite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terracotta/drivers/sqlite.py b/terracotta/drivers/sqlite.py index 16f84402..811a0904 100644 --- a/terracotta/drivers/sqlite.py +++ b/terracotta/drivers/sqlite.py @@ -63,7 +63,7 @@ def __init__(self, path: Union[str, Path]) -> None: @classmethod def _resolve_path(cls, path: str) -> str: - return os.path.realpath(Path(path).resolve()) + return str(Path(path).resolve()) @classmethod def _normalize_path(cls, path: str) -> str: From 229bcc4e99842cf24f95a6c12468f6924f60fe17 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 11 Jan 2022 16:21:35 +0100 Subject: [PATCH 048/107] Move _METADATA_COLUMNS out as a class variable again --- terracotta/drivers/mysql.py | 4 +- terracotta/drivers/relational_base.py | 53 ++++++++++++++++----------- 2 files changed, 34 insertions(+), 23 deletions(-) diff --git a/terracotta/drivers/mysql.py b/terracotta/drivers/mysql.py index 5818e9f9..a0072e83 100644 --- a/terracotta/drivers/mysql.py +++ b/terracotta/drivers/mysql.py @@ -36,7 +36,6 @@ class MySQLDriver(RelationalDriver): _CHARSET = 'utf8mb4' SQLA_STRING = functools.partial(VARCHAR, charset=_CHARSET) - SQLA_TEXT = functools.partial(TEXT, charset=_CHARSET) MAX_PRIMARY_KEY_SIZE = 767 // 4 # Max key length for MySQL is at least 767B DEFAULT_PORT = 3306 @@ -53,6 +52,9 @@ def __init__(self, mysql_path: str) -> None: """ super().__init__(mysql_path) + + self.SQLA_METADATA_TYPE_LOOKUP['text'] = functools.partial(TEXT, charset=self._CHARSET) + # raises an exception if path is invalid self._parse_db_name(self._CONNECTION_PARAMETERS) diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index a784ada6..aad330c5 100644 --- a/terracotta/drivers/relational_base.py +++ 
b/terracotta/drivers/relational_base.py @@ -57,27 +57,28 @@ class RelationalDriver(RasterDriver, ABC): FILE_BASED_DATABASE: bool = False SQLA_STRING = sqla.types.String - SQLA_REAL = functools.partial(sqla.types.Float, precision=8) - SQLA_TEXT = sqla.types.Text - SQLA_BLOB = sqla.types.LargeBinary + SQLA_METADATA_TYPE_LOOKUP: Dict[str, sqla.types.TypeEngine] = { + 'real': functools.partial(sqla.types.Float, precision=8), + 'text': sqla.types.Text, + 'blob': sqla.types.LargeBinary + } + + _METADATA_COLUMNS: Tuple[Tuple[str, str], ...] = ( + ('bounds_north', 'real'), + ('bounds_east', 'real'), + ('bounds_south', 'real'), + ('bounds_west', 'real'), + ('convex_hull', 'text'), + ('valid_percentage', 'real'), + ('min', 'real'), + ('max', 'real'), + ('mean', 'real'), + ('stdev', 'real'), + ('percentiles', 'blob'), + ('metadata', 'text') + ) def __init__(self, path: str) -> None: - # it is sadly necessary to define this in here, in order to let subclasses redefine types - self._METADATA_COLUMNS: Tuple[Tuple[str, sqla.types.TypeEngine], ...] 
= ( - ('bounds_north', self.SQLA_REAL), - ('bounds_east', self.SQLA_REAL), - ('bounds_south', self.SQLA_REAL), - ('bounds_west', self.SQLA_REAL), - ('convex_hull', self.SQLA_TEXT), - ('valid_percentage', self.SQLA_REAL), - ('min', self.SQLA_REAL), - ('max', self.SQLA_REAL), - ('mean', self.SQLA_REAL), - ('stdev', self.SQLA_REAL), - ('percentiles', self.SQLA_BLOB), - ('metadata', self.SQLA_TEXT) - ) - settings = terracotta.get_settings() db_connection_timeout: int = settings.DB_CONNECTION_TIMEOUT self.LAZY_LOADING_MAX_SHAPE: Tuple[int, int] = settings.LAZY_LOADING_MAX_SHAPE @@ -241,13 +242,21 @@ def _initialize_database( ) _ = sqla.Table( 'datasets', self.sqla_metadata, - *[sqla.Column(key, self.SQLA_STRING(self.SQL_KEY_SIZE), primary_key=True) for key in keys], # noqa: E501 + *[ + sqla.Column(key, self.SQLA_STRING(self.SQL_KEY_SIZE), primary_key=True) + for key in keys + ], sqla.Column('filepath', self.SQLA_STRING(8000)) ) _ = sqla.Table( 'metadata', self.sqla_metadata, - *[sqla.Column(key, self.SQLA_STRING(self.SQL_KEY_SIZE), primary_key=True) for key in keys], # noqa: E501 - *[sqla.Column(name, column_type()) for name, column_type in self._METADATA_COLUMNS] + *[ + sqla.Column(key, self.SQLA_STRING(self.SQL_KEY_SIZE), primary_key=True) + for key in keys], + *[ + sqla.Column(name, self.SQLA_METADATA_TYPE_LOOKUP[column_type]()) + for name, column_type in self._METADATA_COLUMNS + ] ) self.sqla_metadata.create_all(self.sqla_engine) From 36e9e8efea173d4aee4b51e78bebd244ea30eb57 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 11 Jan 2022 17:56:45 +0100 Subject: [PATCH 049/107] Improve parsing of paths --- terracotta/drivers/mysql.py | 31 ++++++++---------------- terracotta/drivers/relational_base.py | 34 ++++++++++++--------------- terracotta/drivers/sqlite.py | 6 +---- tests/drivers/test_mysql.py | 15 ++++++------ 4 files changed, 33 insertions(+), 53 deletions(-) diff --git a/terracotta/drivers/mysql.py b/terracotta/drivers/mysql.py index a0072e83..6e11d170 
100644 --- a/terracotta/drivers/mysql.py +++ b/terracotta/drivers/mysql.py @@ -6,10 +6,9 @@ import functools from typing import Mapping, Sequence -from urllib.parse import ParseResult import sqlalchemy as sqla -from sqlalchemy.dialects.mysql import VARCHAR, TEXT +from sqlalchemy.dialects.mysql import TEXT, VARCHAR from terracotta.drivers.relational_base import RelationalDriver @@ -55,38 +54,28 @@ def __init__(self, mysql_path: str) -> None: self.SQLA_METADATA_TYPE_LOOKUP['text'] = functools.partial(TEXT, charset=self._CHARSET) - # raises an exception if path is invalid - self._parse_db_name(self._CONNECTION_PARAMETERS) + # raise an exception if database name is invalid + if not self.url.database: + raise ValueError('database must be specified in MySQL path') + if '/' in self.url.database.strip('/'): + raise ValueError('invalid database path') @classmethod def _normalize_path(cls, path: str) -> str: - parts = cls._parse_connection_string(path) + url = cls._parse_path(path) - path = f'{parts.scheme}://{parts.hostname}:{parts.port or cls.DEFAULT_PORT}{parts.path}' + path = f'{url.drivername}://{url.host}:{url.port or cls.DEFAULT_PORT}/{url.database}' path = path.rstrip('/') return path - @staticmethod - def _parse_db_name(con_params: ParseResult) -> str: - if not con_params.path: - raise ValueError('database must be specified in MySQL path') - - path = con_params.path.strip('/') - if '/' in path: - raise ValueError('invalid database path') - - return path - def _create_database(self) -> None: engine = sqla.create_engine( - f'{self._CONNECTION_PARAMETERS.scheme}+{self.SQL_DRIVER_TYPE}://' - f'{self._CONNECTION_PARAMETERS.netloc}', + self.url.set(database=''), # `.set()` returns a copy with changed parameters echo=False, future=True ) with engine.connect() as connection: - db_name = self._parse_db_name(self._CONNECTION_PARAMETERS) - connection.execute(sqla.text(f'CREATE DATABASE {db_name}')) + connection.execute(sqla.text(f'CREATE DATABASE {self.url.database}')) 
connection.commit() def _initialize_database( diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index aad330c5..3b77c57f 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -17,6 +17,7 @@ import sqlalchemy as sqla import terracotta from sqlalchemy.engine.base import Connection +from sqlalchemy.engine.url import URL from terracotta import exceptions from terracotta.drivers.base import requires_connection from terracotta.drivers.raster_base import RasterDriver @@ -83,14 +84,9 @@ def __init__(self, path: str) -> None: db_connection_timeout: int = settings.DB_CONNECTION_TIMEOUT self.LAZY_LOADING_MAX_SHAPE: Tuple[int, int] = settings.LAZY_LOADING_MAX_SHAPE - assert self.SQL_DRIVER_TYPE is not None - self._CONNECTION_PARAMETERS = self._parse_connection_string(path) - cp = self._CONNECTION_PARAMETERS - resolved_path = self._resolve_path(cp.path[1:]) # Remove the leading '/' - connection_string = f'{cp.scheme}+{self.SQL_DRIVER_TYPE}://{cp.netloc}/{resolved_path}' - + self.url = self._parse_path(path) self.sqla_engine = sqla.create_engine( - connection_string, + self.url, echo=False, future=True, connect_args={self.SQL_TIMEOUT_KEY: db_connection_timeout} @@ -104,32 +100,32 @@ def __init__(self, path: str) -> None: self.db_version_verified: bool = False # use normalized path to make sure username and password don't leak into __repr__ - qualified_path = self._normalize_path(path) - super().__init__(qualified_path) + super().__init__(self._normalize_path(path)) @classmethod - def _parse_connection_string(cls, connection_string: str) -> urlparse.ParseResult: + def _parse_path(cls, connection_string: str) -> URL: if "//" not in connection_string: connection_string = f"//{connection_string}" con_params = urlparse.urlparse(connection_string) - if con_params.scheme == 'file' and cls.FILE_BASED_DATABASE: - file_connection_string = connection_string[len('file://'):] - con_params = 
urlparse.urlparse(f'{cls.SQL_URL_SCHEME}://{file_connection_string}') - if not con_params.scheme: con_params = urlparse.urlparse(f'{cls.SQL_URL_SCHEME}:{connection_string}') if con_params.scheme != cls.SQL_URL_SCHEME: raise ValueError(f'unsupported URL scheme "{con_params.scheme}"') - return con_params + url = URL.create( + drivername=f'{cls.SQL_URL_SCHEME}+{cls.SQL_DRIVER_TYPE}', + username=con_params.username, + password=con_params.password, + host=con_params.hostname, + port=con_params.port, + database=con_params.path[1:], # remove leading '/' from urlparse + query=con_params.query + ) - @classmethod - def _resolve_path(cls, path: str) -> str: - # Default to do nothing; may be overriden to actually handle file paths according to OS - return path + return url @contextlib.contextmanager def connect(self, verify: bool = True) -> Iterator: diff --git a/terracotta/drivers/sqlite.py b/terracotta/drivers/sqlite.py index 811a0904..79b2a039 100644 --- a/terracotta/drivers/sqlite.py +++ b/terracotta/drivers/sqlite.py @@ -59,11 +59,7 @@ def __init__(self, path: Union[str, Path]) -> None: path: File path to target SQLite database (may or may not exist yet) """ - super().__init__(f'file:///{path}') - - @classmethod - def _resolve_path(cls, path: str) -> str: - return str(Path(path).resolve()) + super().__init__(f'{self.SQL_URL_SCHEME}:///{path}') @classmethod def _normalize_path(cls, path: str) -> str: diff --git a/tests/drivers/test_mysql.py b/tests/drivers/test_mysql.py index 38bd0480..2e4153f5 100644 --- a/tests/drivers/test_mysql.py +++ b/tests/drivers/test_mysql.py @@ -2,19 +2,19 @@ TEST_CASES = { 'mysql://root@localhost:5000/test': dict( - username='root', password=None, hostname='localhost', port=5000, path='/test' + username='root', password=None, host='localhost', port=5000, database='test' ), 'root@localhost:5000/test': dict( - username='root', password=None, hostname='localhost', port=5000, path='/test' + username='root', password=None, host='localhost', 
port=5000, database='test' ), 'mysql://root:foo@localhost/test': dict( - username='root', password='foo', hostname='localhost', port=None, path='/test' + username='root', password='foo', host='localhost', port=None, database='test' ), 'mysql://localhost/test': dict( - password=None, hostname='localhost', port=None, path='/test' + password=None, host='localhost', port=None, database='test' ), 'localhost/test': dict( - password=None, hostname='localhost', port=None, path='/test' + password=None, host='localhost', port=None, database='test' ) } @@ -32,9 +32,8 @@ def test_path_parsing(case): drivers._DRIVER_CACHE = {} db = drivers.get_driver(case, provider='mysql') - db_args = db._CONNECTION_PARAMETERS - for attr in ('username', 'password', 'hostname', 'port', 'path'): - assert getattr(db_args, attr) == TEST_CASES[case].get(attr, None) + for attr in ('username', 'password', 'host', 'port', 'database'): + assert getattr(db.url, attr) == TEST_CASES[case].get(attr, None) @pytest.mark.parametrize('case', INVALID_TEST_CASES) From 421acb328875223da2e900f17bda9b8862e67103 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 11 Jan 2022 18:47:59 +0100 Subject: [PATCH 050/107] Set mysql driver encoding to utf8mb4 as well --- terracotta/drivers/mysql.py | 2 +- terracotta/drivers/relational_base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/terracotta/drivers/mysql.py b/terracotta/drivers/mysql.py index 6e11d170..6bd6d97b 100644 --- a/terracotta/drivers/mysql.py +++ b/terracotta/drivers/mysql.py @@ -50,7 +50,7 @@ def __init__(self, mysql_path: str) -> None: ``mysql://username:password@hostname/database`` """ - super().__init__(mysql_path) + super().__init__(f'{mysql_path}?charset={self._CHARSET}') self.SQLA_METADATA_TYPE_LOOKUP['text'] = functools.partial(TEXT, charset=self._CHARSET) diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index 3b77c57f..0705a4f3 100644 --- a/terracotta/drivers/relational_base.py 
+++ b/terracotta/drivers/relational_base.py @@ -122,7 +122,7 @@ def _parse_path(cls, connection_string: str) -> URL: host=con_params.hostname, port=con_params.port, database=con_params.path[1:], # remove leading '/' from urlparse - query=con_params.query + query=dict(urlparse.parse_qsl(con_params.query)) ) return url From f455edd429003a3ce818be20aa0dbc62340771a9 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Thu, 13 Jan 2022 15:52:55 +0100 Subject: [PATCH 051/107] Remove now unused code --- terracotta/drivers/relational_base.py | 2 -- terracotta/drivers/sqlite.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index 0705a4f3..61b71d06 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -55,8 +55,6 @@ class RelationalDriver(RasterDriver, ABC): SQL_KEY_SIZE: int SQL_TIMEOUT_KEY: str - FILE_BASED_DATABASE: bool = False - SQLA_STRING = sqla.types.String SQLA_METADATA_TYPE_LOOKUP: Dict[str, sqla.types.TypeEngine] = { 'real': functools.partial(sqla.types.Float, precision=8), diff --git a/terracotta/drivers/sqlite.py b/terracotta/drivers/sqlite.py index 79b2a039..33d9e16f 100644 --- a/terracotta/drivers/sqlite.py +++ b/terracotta/drivers/sqlite.py @@ -47,8 +47,6 @@ class SQLiteDriver(RelationalDriver): SQL_KEY_SIZE = 256 SQL_TIMEOUT_KEY = 'timeout' - FILE_BASED_DATABASE = True - def __init__(self, path: Union[str, Path]) -> None: """Initialize the SQLiteDriver. 
From da2bc5bd63ca84fa36f64a7de1af357be0dc670f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Mon, 17 Jan 2022 21:11:44 +0100 Subject: [PATCH 052/107] update to most recent COG validate script --- terracotta/cog.py | 37 +++++++++++++++++++++++-------------- tests/test_cog.py | 6 +++--- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/terracotta/cog.py b/terracotta/cog.py index 14470604..08d612e4 100644 --- a/terracotta/cog.py +++ b/terracotta/cog.py @@ -25,7 +25,7 @@ def validate(src_path: str, strict: bool = True) -> bool: def check_raster_file(src_path: str) -> ValidationInfo: # pragma: no cover """ Implementation from - https://github.com/cogeotiff/rio-cogeo/blob/0f00a6ee1eff602014fbc88178a069bd9f4a10da/rio_cogeo/cogeo.py + https://github.com/cogeotiff/rio-cogeo/blob/a07d914e2d898878417638bbc089179f01eb5b28/rio_cogeo/cogeo.py#L385 This function is the rasterio equivalent of https://svn.osgeo.org/gdal/trunk/gdal/swig/python/samples/validate_cloud_optimized_geotiff.py @@ -44,15 +44,13 @@ def check_raster_file(src_path: str) -> ValidationInfo: # pragma: no cover errors.append('The file is not a GeoTIFF') return errors, warnings, details - filelist = [os.path.basename(f) for f in src.files] - src_bname = os.path.basename(src_path) - if len(filelist) > 1 and src_bname + '.ovr' in filelist: + if any(os.path.splitext(x)[-1] == '.ovr' for x in src.files): errors.append( 'Overviews found in external .ovr file. 
They should be internal' ) overviews = src.overviews(1) - if src.width >= 512 or src.height >= 512: + if src.width > 512 and src.height > 512: if not src.is_tiled: errors.append( 'The file is greater than 512xH or 512xW, but is not tiled' @@ -65,16 +63,28 @@ def check_raster_file(src_path: str) -> ValidationInfo: # pragma: no cover ) ifd_offset = int(src.get_tag_item('IFD_OFFSET', 'TIFF', bidx=1)) - ifd_offsets = [ifd_offset] + # Starting from GDAL 3.1, GeoTIFF and COG have ghost headers + # e.g: + # """ + # GDAL_STRUCTURAL_METADATA_SIZE=000140 bytes + # LAYOUT=IFDS_BEFORE_DATA + # BLOCK_ORDER=ROW_MAJOR + # BLOCK_LEADER=SIZE_AS_UINT4 + # BLOCK_TRAILER=LAST_4_BYTES_REPEATED + # KNOWN_INCOMPATIBLE_EDITION=NO + # """ + # + # This header should be < 200bytes if ifd_offset > 300: errors.append( f'The offset of the main IFD should be < 300. It is {ifd_offset} instead' ) + ifd_offsets = [ifd_offset] details['ifd_offsets'] = {} details['ifd_offsets']['main'] = ifd_offset - if not overviews == sorted(overviews): + if overviews and overviews != sorted(overviews): errors.append('Overviews should be sorted') for ix, dec in enumerate(overviews): @@ -111,9 +121,7 @@ def check_raster_file(src_path: str) -> ValidationInfo: # pragma: no cover ) ) - block_offset = int(src.get_tag_item('BLOCK_OFFSET_0_0', 'TIFF', bidx=1)) - if not block_offset: - errors.append('Missing BLOCK_OFFSET_0_0') + block_offset = src.get_tag_item('BLOCK_OFFSET_0_0', 'TIFF', bidx=1) data_offset = int(block_offset) if block_offset else 0 data_offsets = [data_offset] @@ -121,13 +129,14 @@ def check_raster_file(src_path: str) -> ValidationInfo: # pragma: no cover details['data_offsets']['main'] = data_offset for ix, dec in enumerate(overviews): - data_offset = int( - src.get_tag_item('BLOCK_OFFSET_0_0', 'TIFF', bidx=1, ovr=ix) + block_offset = src.get_tag_item( + 'BLOCK_OFFSET_0_0', 'TIFF', bidx=1, ovr=ix ) + data_offset = int(block_offset) if block_offset else 0 data_offsets.append(data_offset) 
details['data_offsets']['overview_{}'.format(ix)] = data_offset - if data_offsets[-1] < ifd_offsets[-1]: + if data_offsets[-1] != 0 and data_offsets[-1] < ifd_offsets[-1]: if len(overviews) > 0: errors.append( 'The offset of the first block of the smallest overview ' @@ -156,7 +165,7 @@ def check_raster_file(src_path: str) -> ValidationInfo: # pragma: no cover for ix, dec in enumerate(overviews): with rasterio.open(src_path, OVERVIEW_LEVEL=ix) as ovr_dst: - if ovr_dst.width >= 512 or ovr_dst.height >= 512: + if ovr_dst.width > 512 and ovr_dst.height > 512: if not ovr_dst.is_tiled: errors.append('Overview of index {} is not tiled'.format(ix)) diff --git a/tests/test_cog.py b/tests/test_cog.py index 448b301e..75f6e6aa 100644 --- a/tests/test_cog.py +++ b/tests/test_cog.py @@ -69,12 +69,12 @@ def test_validate_unoptimized(tmpdir): from terracotta import cog outfile = str(tmpdir / 'raster.tif') - raster_data = 1000 * np.random.rand(512, 512).astype(np.uint16) + raster_data = 1000 * np.random.rand(1024, 1024).astype(np.uint16) profile = BASE_PROFILE.copy() profile.update( height=raster_data.shape[0], - width=raster_data.shape[1] + width=raster_data.shape[1], ) with rasterio.open(outfile, 'w', **profile) as dst: @@ -87,7 +87,7 @@ def test_validate_no_overviews(tmpdir): from terracotta import cog outfile = str(tmpdir / 'raster.tif') - raster_data = 1000 * np.random.rand(512, 512).astype(np.uint16) + raster_data = 1000 * np.random.rand(1024, 1024).astype(np.uint16) profile = BASE_PROFILE.copy() profile.update( From c1b7bb0a2081d8a1fb1f522e1801db508c704d8b Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 31 Jan 2022 10:15:22 +0100 Subject: [PATCH 053/107] Split Driver --- terracotta/drivers/__init__.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/terracotta/drivers/__init__.py b/terracotta/drivers/__init__.py index 19631f28..e50635b5 100644 --- a/terracotta/drivers/__init__.py +++ b/terracotta/drivers/__init__.py @@ -8,12 
+8,14 @@ import urllib.parse as urlparse from pathlib import Path -from terracotta.drivers.base import Driver +from terracotta.drivers.base import MetaStore +from terracotta.drivers.driver import TerracottaDriver +from terracotta.drivers.raster_base import RasterDriver URLOrPathType = Union[str, Path] -def load_driver(provider: str) -> Type[Driver]: +def load_driver(provider: str) -> Type[MetaStore]: if provider == 'sqlite-remote': from terracotta.drivers.sqlite_remote import RemoteSQLiteDriver return RemoteSQLiteDriver @@ -42,10 +44,10 @@ def auto_detect_provider(url_or_path: Union[str, Path]) -> str: return 'sqlite' -_DRIVER_CACHE: Dict[Tuple[URLOrPathType, str, int], Driver] = {} +_DRIVER_CACHE: Dict[Tuple[URLOrPathType, str, int], TerracottaDriver] = {} -def get_driver(url_or_path: URLOrPathType, provider: str = None) -> Driver: +def get_driver(url_or_path: URLOrPathType, provider: str = None) -> TerracottaDriver: """Retrieve Terracotta driver instance for the given path. This function always returns the same instance for identical inputs. 
@@ -85,6 +87,10 @@ def get_driver(url_or_path: URLOrPathType, provider: str = None) -> Driver: cache_key = (normalized_path, provider, os.getpid()) if cache_key not in _DRIVER_CACHE: - _DRIVER_CACHE[cache_key] = DriverClass(url_or_path) + driver = TerracottaDriver( + metastore=DriverClass(url_or_path), + rasterstore=RasterDriver() + ) + _DRIVER_CACHE[cache_key] = driver return _DRIVER_CACHE[cache_key] From d5877596eeaf914149df9559fd78dd5f429a223f Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 31 Jan 2022 10:15:47 +0100 Subject: [PATCH 054/107] Create TerracottaDriver --- terracotta/drivers/driver.py | 175 +++++++++++++++++++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 terracotta/drivers/driver.py diff --git a/terracotta/drivers/driver.py b/terracotta/drivers/driver.py new file mode 100644 index 00000000..10041c9d --- /dev/null +++ b/terracotta/drivers/driver.py @@ -0,0 +1,175 @@ +import contextlib +import functools +from typing import (Any, Callable, Collection, Dict, Mapping, OrderedDict, + Sequence, Tuple, TypeVar, Union, cast) + +import terracotta +from terracotta import exceptions +from terracotta.drivers.base import (KeysType, MetaStore, MultiValueKeysType, + RasterStore, requires_connection) + +ExtendedKeysType = Union[Sequence[str], KeysType] +T = TypeVar('T') + + +def only_element(iterable: Collection[T]) -> T: + if not iterable: + raise exceptions.DatasetNotFoundError('No dataset found') + assert len(iterable) == 1 + return next(iter(iterable)) + + +def standardize_keys( + fun: Callable[..., T] = None, *, + requires_all_keys: bool = False +) -> Union[Callable[..., T], functools.partial]: + if fun is None: + return functools.partial(standardize_keys, requires_all_keys=requires_all_keys) + + @functools.wraps(fun) + def inner( + self: "TerracottaDriver", + keys: ExtendedKeysType = None, + *args: Any, **kwargs: Any + ) -> T: + if requires_all_keys and (keys is None or len(keys) != len(self.key_names)): + raise 
exceptions.InvalidKeyError( + f'Got wrong number of keys (available keys: {self.key_names})' + ) + + if isinstance(keys, Mapping): + keys = dict(keys.items()) + elif isinstance(keys, Sequence): + keys = dict(zip(self.key_names, keys)) + elif keys is None: + keys = {} + else: + raise exceptions.InvalidKeyError( + 'Encountered unknown key type, expected Mapping or Sequence' + ) + + if unknown_keys := set(keys) - set(self.key_names): + raise exceptions.InvalidKeyError( + f'Encountered unrecognized keys {unknown_keys} (available keys: {self.key_names})' + ) + + # Apparently mypy thinks fun might still be None, hence the ignore: + return fun(self, keys, *args, **kwargs) # type: ignore + return inner + + +class TerracottaDriver: + + def __init__(self, metastore: MetaStore, rasterstore: RasterStore) -> None: + self.metastore = metastore + self.rasterstore = rasterstore + + settings = terracotta.get_settings() + self.LAZY_LOADING_MAX_SHAPE: Tuple[int, int] = settings.LAZY_LOADING_MAX_SHAPE + + @property + def db_version(self) -> str: + return self.metastore.db_version + + @property + def key_names(self) -> Tuple[str, ...]: + return self.metastore.key_names + + def create(self, keys: Sequence[str], *, + key_descriptions: Mapping[str, str] = None) -> None: + self.metastore.create(keys=keys, key_descriptions=key_descriptions) + + def connect(self, verify: bool = True) -> contextlib.AbstractContextManager: + return self.metastore.connect(verify=verify) + + @requires_connection + def get_keys(self) -> OrderedDict: + return self.metastore.get_keys() + + @requires_connection + @standardize_keys + def get_datasets(self, keys: MultiValueKeysType = None, + page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], Any]: + return self.metastore.get_datasets( + where=keys, + page=page, + limit=limit + ) + + @requires_connection + @standardize_keys(requires_all_keys=True) + def get_metadata(self, keys: ExtendedKeysType) -> Dict[str, Any]: + keys = cast(KeysType, keys) + + metadata = 
self.metastore.get_metadata(keys) + + if metadata is None: + # metadata is not computed yet, trigger lazy loading + handle = only_element(self.get_datasets(keys).values()) + metadata = self.compute_metadata(handle, max_shape=self.LAZY_LOADING_MAX_SHAPE) + self.insert(keys, handle, metadata=metadata) + + # this is necessary to make the lazy loading tests pass... + metadata = self.metastore.get_metadata(keys) + assert metadata is not None + + return metadata + + @requires_connection + @standardize_keys(requires_all_keys=True) + def insert( + self, keys: ExtendedKeysType, + handle: Any, *, + override_path: str = None, + metadata: Mapping[str, Any] = None, + skip_metadata: bool = False, + **kwargs: Any + ) -> None: + keys = cast(KeysType, keys) + + if metadata is None and not skip_metadata: + metadata = self.compute_metadata(handle) + + self.metastore.insert( + keys=keys, + handle=override_path or handle, + metadata=metadata, + **kwargs + ) + + @requires_connection + @standardize_keys(requires_all_keys=True) + def delete(self, keys: ExtendedKeysType) -> None: + keys = cast(KeysType, keys) + + self.metastore.delete(keys) + + # @standardize_keys(requires_all_keys=True) + def get_raster_tile(self, keys: ExtendedKeysType, *, + tile_bounds: Sequence[float] = None, + tile_size: Sequence[int] = (256, 256), + preserve_values: bool = False, + asynchronous: bool = False) -> Any: + handle = only_element(self.get_datasets(keys).values()) + + return self.rasterstore.get_raster_tile( + handle=handle, + tile_bounds=tile_bounds, + tile_size=tile_size, + preserve_values=preserve_values, + asynchronous=asynchronous, + ) + + def compute_metadata(self, handle: str, *, + extra_metadata: Any = None, + use_chunks: bool = None, + max_shape: Sequence[int] = None) -> Dict[str, Any]: + return self.rasterstore.compute_metadata( + handle=handle, + extra_metadata=extra_metadata, + use_chunks=use_chunks, + max_shape=max_shape, + ) + + def __repr__(self) -> str: + return self.metastore.__repr__() 
From 959ecbf2275c4338a1d2b63afd4729f13e45125c Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 31 Jan 2022 10:16:09 +0100 Subject: [PATCH 055/107] Move functionality up into Driver --- terracotta/drivers/raster_base.py | 110 ++++---------------------- terracotta/drivers/relational_base.py | 101 ++++++----------------- 2 files changed, 40 insertions(+), 171 deletions(-) diff --git a/terracotta/drivers/raster_base.py b/terracotta/drivers/raster_base.py index ccbcd669..003a156f 100644 --- a/terracotta/drivers/raster_base.py +++ b/terracotta/drivers/raster_base.py @@ -3,9 +3,8 @@ Base class for drivers operating on physical raster files. """ -from typing import (Any, Callable, Union, Mapping, Sequence, Dict, List, Tuple, - TypeVar, Optional, cast, TYPE_CHECKING) -from abc import abstractmethod +from typing import (Any, Callable, Sequence, Dict, Tuple, + TypeVar, Optional, TYPE_CHECKING) from concurrent.futures import Future, Executor, ProcessPoolExecutor, ThreadPoolExecutor from concurrent.futures.process import BrokenProcessPool @@ -28,7 +27,7 @@ from terracotta import get_settings, exceptions from terracotta.cache import CompressedLFUCache -from terracotta.drivers.base import requires_connection, Driver +from terracotta.drivers.base import RasterStore from terracotta.profile import trace Number = TypeVar('Number', int, float) @@ -76,7 +75,7 @@ def submit_to_executor(task: Callable[..., Any]) -> Future: return future -class RasterDriver(Driver): +class RasterDriver(RasterStore): """Mixin that implements methods to load raster data from disk. get_datasets has to return path to raster file as sole dict value. 
@@ -88,89 +87,13 @@ class RasterDriver(Driver): GDAL_DISABLE_READDIR_ON_OPEN='EMPTY_DIR' ) - @abstractmethod - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__(self) -> None: settings = get_settings() self._raster_cache = CompressedLFUCache( settings.RASTER_CACHE_SIZE, compression_level=settings.RASTER_CACHE_COMPRESS_LEVEL ) self._cache_lock = threading.RLock() - super().__init__(*args, **kwargs) - - # specify signature and docstring for insert - @abstractmethod - def insert(self, # type: ignore - keys: Union[Sequence[str], Mapping[str, str]], - filepath: str, *, - metadata: Mapping[str, Any] = None, - skip_metadata: bool = False, - override_path: str = None) -> None: - """Insert a raster file into the database. - - Arguments: - - keys: Keys identifying the new dataset. Can either be given as a sequence of key - values, or as a mapping ``{key_name: key_value}``. - filepath: Path to the GDAL-readable raster file. - metadata: If not given (default), call :meth:`compute_metadata` with default arguments - to compute raster metadata. Otherwise, use the given values. This can be used to - decouple metadata computation from insertion, or to use the optional arguments - of :meth:`compute_metadata`. - skip_metadata: Do not compute any raster metadata (will be computed during the first - request instead). Use sparingly; this option has a detrimental result on the end - user experience and might lead to surprising results. Has no effect if ``metadata`` - is given. - override_path: Override the path to the raster file in the database. Use this option if - you intend to copy the data somewhere else after insertion (e.g. when moving files - to a cloud storage later on). - - """ - pass - - # specify signature and docstring for get_datasets - @abstractmethod - def get_datasets(self, where: Mapping[str, Union[str, List[str]]] = None, - page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], str]: - """Retrieve keys and file paths of datasets. 
- - Arguments: - - where: Constraints on returned datasets in the form ``{key_name: allowed_key_value}``. - Returns all datasets if not given (default). - page: Current page of results. Has no effect if ``limit`` is not given. - limit: If given, return at most this many datasets. Unlimited by default. - - - Returns: - - :class:`dict` containing - ``{(key_value1, key_value2, ...): raster_file_path}`` - - Example: - - >>> import terracotta as tc - >>> driver = tc.get_driver('tc.sqlite') - >>> driver.get_datasets() - { - ('reflectance', '20180101', 'B04'): 'reflectance_20180101_B04.tif', - ('reflectance', '20180102', 'B04'): 'reflectance_20180102_B04.tif', - } - >>> driver.get_datasets({'date': '20180101'}) - {('reflectance', '20180101', 'B04'): 'reflectance_20180101_B04.tif'} - - """ - pass - - def _key_dict_to_sequence(self, keys: Union[Mapping[str, Any], Sequence[Any]]) -> List[Any]: - """Convert {key_name: key_value} to [key_value] with the correct key order.""" - try: - keys_as_mapping = cast(Mapping[str, Any], keys) - return [keys_as_mapping[key] for key in self.key_names] - except TypeError: # not a mapping - return list(keys) - except KeyError as exc: - raise exceptions.InvalidKeyError('Encountered unknown key') from exc @staticmethod def _hull_candidate_mask(mask: np.ndarray) -> np.ndarray: @@ -323,7 +246,7 @@ def _compute_image_stats(dataset: 'DatasetReader', @classmethod @trace('compute_metadata') - def compute_metadata(cls, raster_path: str, *, # type: ignore[override] # noqa: F821 + def compute_metadata(cls, handle: str, *, # noqa: F821 extra_metadata: Any = None, use_chunks: bool = None, max_shape: Sequence[int] = None) -> Dict[str, Any]: @@ -359,18 +282,18 @@ def compute_metadata(cls, raster_path: str, *, # type: ignore[override] # noqa raise ValueError('Cannot use both use_chunks and max_shape arguments') with rasterio.Env(**cls._RIO_ENV_KEYS): - if not validate(raster_path): + if not validate(handle): warnings.warn( - f'Raster file {raster_path} is 
not a valid cloud-optimized GeoTIFF. ' + f'Raster file {handle} is not a valid cloud-optimized GeoTIFF. ' 'Any interaction with it will be significantly slower. Consider optimizing ' 'it through `terracotta optimize-rasters` before ingestion.', exceptions.PerformanceWarning, stacklevel=3 ) - with rasterio.open(raster_path) as src: + with rasterio.open(handle) as src: if src.nodata is None and not cls._has_alpha_band(src): warnings.warn( - f'Raster file {raster_path} does not have a valid nodata value, ' + f'Raster file {handle} does not have a valid nodata value, ' 'and does not contain an alpha band. No data will be masked.' ) @@ -383,7 +306,7 @@ def compute_metadata(cls, raster_path: str, *, # type: ignore[override] # noqa if use_chunks: logger.debug( - f'Computing metadata for file {raster_path} using more than ' + f'Computing metadata for file {handle} using more than ' f'{RasterDriver._LARGE_RASTER_THRESHOLD // 10**6}M pixels, iterating ' 'over chunks' ) @@ -401,7 +324,7 @@ def compute_metadata(cls, raster_path: str, *, # type: ignore[override] # noqa raster_stats = RasterDriver._compute_image_stats(src, max_shape) if raster_stats is None: - raise ValueError(f'Raster file {raster_path} does not contain any valid data') + raise ValueError(f'Raster file {handle} does not contain any valid data') row_data.update(raster_stats) @@ -541,9 +464,8 @@ def _get_raster_tile(cls, path: str, *, return np.ma.masked_array(tile_data, mask=mask) # return type has to be Any until mypy supports conditional return types - @requires_connection def get_raster_tile(self, - keys: Union[Sequence[str], Mapping[str, str]], *, + handle: str, *, tile_bounds: Sequence[float] = None, tile_size: Sequence[int] = None, preserve_values: bool = False, @@ -555,17 +477,13 @@ def get_raster_tile(self, result: np.ma.MaskedArray settings = get_settings() - key_tuple = tuple(self._key_dict_to_sequence(keys)) - datasets = self.get_datasets(dict(zip(self.key_names, key_tuple))) - assert len(datasets) == 
1 - path = datasets[key_tuple] if tile_size is None: tile_size = settings.DEFAULT_TILE_SIZE # make sure all arguments are hashable kwargs = dict( - path=path, + path=handle, tile_bounds=tuple(tile_bounds) if tile_bounds else None, tile_size=tuple(tile_size), preserve_values=preserve_values, diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index 61b71d06..3a37a337 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -10,8 +10,8 @@ import urllib.parse as urlparse from abc import ABC, abstractmethod from collections import OrderedDict -from typing import (Any, Dict, Iterator, List, Mapping, Optional, Sequence, - Tuple, Type, Union) +from typing import (Any, Dict, Iterator, Mapping, Optional, Sequence, Tuple, + Type, Union) import numpy as np import sqlalchemy as sqla @@ -19,8 +19,8 @@ from sqlalchemy.engine.base import Connection from sqlalchemy.engine.url import URL from terracotta import exceptions -from terracotta.drivers.base import requires_connection -from terracotta.drivers.raster_base import RasterDriver +from terracotta.drivers.base import (KeysType, MetaStore, MultiValueKeysType, + requires_connection) from terracotta.profile import trace _ERROR_ON_CONNECT = ( @@ -49,7 +49,7 @@ def convert_exceptions( raise exceptions.InvalidDatabaseError(error_message) from exception -class RelationalDriver(RasterDriver, ABC): +class RelationalDriver(MetaStore, ABC): SQL_URL_SCHEME: str # The database flavour, eg mysql, sqlite, etc SQL_DRIVER_TYPE: str # The actual database driver, eg pymysql, sqlite3, etc SQL_KEY_SIZE: int @@ -80,7 +80,6 @@ class RelationalDriver(RasterDriver, ABC): def __init__(self, path: str) -> None: settings = terracotta.get_settings() db_connection_timeout: int = settings.DB_CONNECTION_TIMEOUT - self.LAZY_LOADING_MAX_SHAPE: Tuple[int, int] = settings.LAZY_LOADING_MAX_SHAPE self.url = self._parse_path(path) self.sqla_engine = sqla.create_engine( @@ -289,22 +288,14 
@@ def key_names(self) -> Tuple[str, ...]: @convert_exceptions('Could not retrieve datasets') def get_datasets( self, - where: Mapping[str, Union[str, List[str]]] = None, + where: MultiValueKeysType = None, page: int = 0, limit: int = None ) -> Dict[Tuple[str, ...], str]: - # Ensure standardized structure of where items - if where is None: - where = {} - else: - where = dict(where) - - if not all(key in self.key_names for key in where.keys()): - raise exceptions.InvalidKeyError('Encountered unrecognized keys in where clause') - - for key, value in where.items(): - if not isinstance(value, list): - where[key] = [value] + where = { + key: value if isinstance(value, list) else [value] + for key, value in (where or {}).items() + } datasets_table = sqla.Table('datasets', self.sqla_metadata, autoload_with=self.sqla_engine) stmt = ( @@ -332,13 +323,7 @@ def keytuple(row: Dict[str, Any]) -> Tuple[str, ...]: @trace('get_metadata') @requires_connection @convert_exceptions('Could not retrieve metadata') - def get_metadata(self, keys: Union[Sequence[str], Mapping[str, str]]) -> Dict[str, Any]: - keys = tuple(self._key_dict_to_sequence(keys)) - if len(keys) != len(self.key_names): - raise exceptions.InvalidKeyError( - f'Got wrong number of keys (available keys: {self.key_names})' - ) - + def get_metadata(self, keys: KeysType) -> Optional[Dict[str, Any]]: metadata_table = sqla.Table('metadata', self.sqla_metadata, autoload_with=self.sqla_engine) stmt = ( metadata_table @@ -346,25 +331,14 @@ def get_metadata(self, keys: Union[Sequence[str], Mapping[str, str]]) -> Dict[st .where( *[ metadata_table.c.get(key) == value - for key, value in zip(self.key_names, keys) + for key, value in keys.items() ] ) ) row = self.connection.execute(stmt).first() - - if not row: # support lazy loading - filepath = self.get_datasets(dict(zip(self.key_names, keys))) - if not filepath: - raise exceptions.DatasetNotFoundError(f'No dataset found for given keys {keys}') - assert len(filepath) == 1 - - # 
compute metadata and try again - metadata = self.compute_metadata(filepath[keys], max_shape=self.LAZY_LOADING_MAX_SHAPE) - self.insert(keys, filepath[keys], metadata=metadata) - row = self.connection.execute(stmt).first() - - assert row + if not row: + return None data_columns, _ = zip(*self._METADATA_COLUMNS) encoded_data = {col: row[col] for col in self.key_names + data_columns} @@ -375,78 +349,55 @@ def get_metadata(self, keys: Union[Sequence[str], Mapping[str, str]]) -> Dict[st @convert_exceptions('Could not write to database') def insert( self, - keys: Union[Sequence[str], Mapping[str, str]], - filepath: str, *, - metadata: Mapping[str, Any] = None, - skip_metadata: bool = False, - override_path: str = None + keys: KeysType, + handle: str, *, + metadata: Mapping[str, Any] = None ) -> None: - if len(keys) != len(self.key_names): - raise exceptions.InvalidKeyError( - f'Got wrong number of keys (available keys: {self.key_names})' - ) - - if override_path is None: - override_path = filepath - - keys = self._key_dict_to_sequence(keys) - key_dict = dict(zip(self.key_names, keys)) - datasets_table = sqla.Table('datasets', self.sqla_metadata, autoload_with=self.sqla_engine) metadata_table = sqla.Table('metadata', self.sqla_metadata, autoload_with=self.sqla_engine) self.connection.execute( datasets_table .delete() - .where(*[datasets_table.c.get(column) == value for column, value in key_dict.items()]) + .where(*[datasets_table.c.get(column) == value for column, value in keys.items()]) ) self.connection.execute( - datasets_table.insert().values(**key_dict, filepath=override_path) + datasets_table.insert().values(**keys, filepath=handle) ) - if metadata is None and not skip_metadata: - metadata = self.compute_metadata(filepath) - if metadata is not None: encoded_data = self._encode_data(metadata) self.connection.execute( metadata_table .delete() .where( - *[metadata_table.c.get(column) == value for column, value in key_dict.items()] + *[metadata_table.c.get(column) == 
value for column, value in keys.items()] ) ) self.connection.execute( - metadata_table.insert().values(**key_dict, **encoded_data) + metadata_table.insert().values(**keys, **encoded_data) ) @trace('delete') @requires_connection @convert_exceptions('Could not write to database') - def delete(self, keys: Union[Sequence[str], Mapping[str, str]]) -> None: - if len(keys) != len(self.key_names): - raise exceptions.InvalidKeyError( - f'Got wrong number of keys (available keys: {self.key_names})' - ) - - keys = self._key_dict_to_sequence(keys) - key_dict = dict(zip(self.key_names, keys)) - - if not self.get_datasets(key_dict): + def delete(self, keys: KeysType) -> None: + if not self.get_datasets(keys): raise exceptions.DatasetNotFoundError(f'No dataset found with keys {keys}') datasets_table = sqla.Table('datasets', self.sqla_metadata, autoload_with=self.sqla_engine) metadata_table = sqla.Table('metadata', self.sqla_metadata, autoload_with=self.sqla_engine) + print(keys) self.connection.execute( datasets_table .delete() - .where(*[datasets_table.c.get(column) == value for column, value in key_dict.items()]) + .where(*[datasets_table.c.get(column) == value for column, value in keys.items()]) ) self.connection.execute( metadata_table .delete() - .where(*[metadata_table.c.get(column) == value for column, value in key_dict.items()]) + .where(*[metadata_table.c.get(column) == value for column, value in keys.items()]) ) @staticmethod From e4fff5d63563e1b31ad6798066a945ac44fe6a2a Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 31 Jan 2022 10:16:40 +0100 Subject: [PATCH 056/107] Split Driver --- terracotta/drivers/base.py | 89 ++++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 42 deletions(-) diff --git a/terracotta/drivers/base.py b/terracotta/drivers/base.py index 97eeef6a..2c6a6294 100644 --- a/terracotta/drivers/base.py +++ b/terracotta/drivers/base.py @@ -7,9 +7,11 @@ import functools from abc import ABC, abstractmethod from collections import 
OrderedDict -from typing import (Any, Callable, Dict, List, Mapping, Sequence, Tuple, - TypeVar, Union) +from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence, + Tuple, TypeVar, Union) +KeysType = Dict[str, str] +MultiValueKeysType = Dict[str, Union[str, List[str]]] Number = TypeVar('Number', int, float) T = TypeVar('T') @@ -22,14 +24,14 @@ def requires_connection( return functools.partial(requires_connection, verify=verify) @functools.wraps(fun) - def inner(self: Driver, *args: Any, **kwargs: Any) -> T: + def inner(self: MetaStore, *args: Any, **kwargs: Any) -> T: with self.connect(verify=verify): # Apparently mypy thinks fun might still be None, hence the ignore: return fun(self, *args, **kwargs) # type: ignore return inner -class Driver(ABC): +class MetaStore(ABC): """Abstract base class for all Terracotta data backends. Defines a common interface for all drivers. @@ -105,14 +107,14 @@ def get_keys(self) -> OrderedDict: pass @abstractmethod - def get_datasets(self, where: Mapping[str, Union[str, List[str]]] = None, + def get_datasets(self, where: MultiValueKeysType = None, page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], Any]: # Get all known dataset key combinations matching the given constraints, # and a handle to retrieve the data (driver dependent) pass @abstractmethod - def get_metadata(self, keys: Union[Sequence[str], Mapping[str, str]]) -> Dict[str, Any]: + def get_metadata(self, keys: KeysType) -> Optional[Dict[str, Any]]: """Return all stored metadata for given keys. Arguments: @@ -136,19 +138,50 @@ def get_metadata(self, keys: Union[Sequence[str], Mapping[str, str]]) -> Dict[st """ pass + @abstractmethod + def insert(self, keys: KeysType, + handle: Any, **kwargs: Any) -> None: + """Register a new dataset. Used to populate metadata database. + + Arguments: + + keys: Keys of the dataset. Can either be given as a sequence of key values, or + as a mapping ``{key_name: key_value}``. 
+ handle: Handle to access dataset (driver dependent). + + """ + pass + + @abstractmethod + def delete(self, keys: KeysType) -> None: + """Remove a dataset from the metadata database. + + Arguments: + + keys: Keys of the dataset. Can either be given as a sequence of key values, or + as a mapping ``{key_name: key_value}``. + + """ + pass + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(\'{self.path}\')' + + +class RasterStore(ABC): + @abstractmethod # TODO: add accurate signature if mypy ever supports conditional return types - def get_raster_tile(self, keys: Union[Sequence[str], Mapping[str, str]], *, + def get_raster_tile(self, handle: str, *, tile_bounds: Sequence[float] = None, tile_size: Sequence[int] = (256, 256), preserve_values: bool = False, asynchronous: bool = False) -> Any: - """Load a raster tile with given keys and bounds. + """Load a raster tile with given handle and bounds. Arguments: - keys: Keys of the requested dataset. Can either be given as a sequence of key values, - or as a mapping ``{key_name: key_value}``. + handle: Handle of the requested dataset. tile_bounds: Physical bounds of the tile to read, in Web Mercator projection (EPSG3857). Reads the whole dataset if not given. tile_size: Shape of the output array to return. Must be two-dimensional. @@ -168,39 +201,11 @@ def get_raster_tile(self, keys: Union[Sequence[str], Mapping[str, str]], *, """ pass - @staticmethod + @classmethod @abstractmethod - def compute_metadata(data: Any, *, + def compute_metadata(cls, handle: str, *, extra_metadata: Any = None, - **kwargs: Any) -> Dict[str, Any]: + use_chunks: bool = None, + max_shape: Sequence[int] = None) -> Dict[str, Any]: # Compute metadata for a given input file (driver dependent) pass - - @abstractmethod - def insert(self, keys: Union[Sequence[str], Mapping[str, str]], - handle: Any, **kwargs: Any) -> None: - """Register a new dataset. Used to populate metadata database. - - Arguments: - - keys: Keys of the dataset. 
Can either be given as a sequence of key values, or - as a mapping ``{key_name: key_value}``. - handle: Handle to access dataset (driver dependent). - - """ - pass - - @abstractmethod - def delete(self, keys: Union[Sequence[str], Mapping[str, str]]) -> None: - """Remove a dataset from the metadata database. - - Arguments: - - keys: Keys of the dataset. Can either be given as a sequence of key values, or - as a mapping ``{key_name: key_value}``. - - """ - pass - - def __repr__(self) -> str: - return f'{self.__class__.__name__}(\'{self.path}\')' From c936905660fbfabb8ced84587d82a739e37f9df0 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 31 Jan 2022 10:17:28 +0100 Subject: [PATCH 057/107] Rename accordingly to Driver refactor --- terracotta/handlers/datasets.py | 2 +- terracotta/xyz.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/terracotta/handlers/datasets.py b/terracotta/handlers/datasets.py index 75012bc5..4d458937 100644 --- a/terracotta/handlers/datasets.py +++ b/terracotta/handlers/datasets.py @@ -19,7 +19,7 @@ def datasets(some_keys: Mapping[str, Union[str, List[str]]] = None, with driver.connect(): dataset_keys = driver.get_datasets( - where=some_keys, page=page, limit=limit + keys=some_keys, page=page, limit=limit ).keys() key_names = driver.key_names diff --git a/terracotta/xyz.py b/terracotta/xyz.py index b73253e1..97cedef4 100644 --- a/terracotta/xyz.py +++ b/terracotta/xyz.py @@ -8,11 +8,11 @@ import mercantile from terracotta import exceptions -from terracotta.drivers.base import Driver +from terracotta.drivers.driver import TerracottaDriver # TODO: add accurate signature if mypy ever supports conditional return types -def get_tile_data(driver: Driver, +def get_tile_data(driver: TerracottaDriver, keys: Union[Sequence[str], Mapping[str, str]], tile_xyz: Tuple[int, int, int] = None, *, tile_size: Tuple[int, int] = (256, 256), From 230d08122c75363325e5ee12f34517616b3657f8 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup 
Date: Mon, 31 Jan 2022 10:18:11 +0100 Subject: [PATCH 058/107] Update tests according to Driver refactor --- tests/drivers/test_drivers.py | 12 ++++----- tests/drivers/test_mysql.py | 2 +- tests/drivers/test_raster_drivers.py | 37 ++++++++++++++++------------ tests/drivers/test_sqlite_remote.py | 32 +++++++++++------------- 4 files changed, 43 insertions(+), 40 deletions(-) diff --git a/tests/drivers/test_drivers.py b/tests/drivers/test_drivers.py index e0ec7eb7..b61170e3 100644 --- a/tests/drivers/test_drivers.py +++ b/tests/drivers/test_drivers.py @@ -12,14 +12,14 @@ def test_auto_detect(driver_path, provider): from terracotta import drivers db = drivers.get_driver(driver_path) - assert db.__class__.__name__ == DRIVER_CLASSES[provider] + assert db.metastore.__class__.__name__ == DRIVER_CLASSES[provider] assert drivers.get_driver(driver_path, provider=provider) is db def test_normalize_base(tmpdir): - from terracotta.drivers import Driver + from terracotta.drivers import MetaStore # base class normalize is noop - assert Driver._normalize_path(str(tmpdir)) == str(tmpdir) + assert MetaStore._normalize_path(str(tmpdir)) == str(tmpdir) @pytest.mark.parametrize('provider', ['sqlite']) @@ -177,10 +177,10 @@ def __getattribute__(self, key): with pytest.raises(RuntimeError): with db.connect(): - db.connection = Evanescence() + db.metastore.connection = Evanescence() db.get_keys() - assert not db.connected + assert not db.metastore.connected with db.connect(): db.get_keys() @@ -222,7 +222,7 @@ def test_version_conflict(driver_path, provider, raster_file, monkeypatch): with monkeypatch.context() as m: fake_version = '0.0.0' m.setattr('terracotta.__version__', fake_version) - db.db_version_verified = False + db.metastore.db_version_verified = False with pytest.raises(exceptions.InvalidDatabaseError) as exc: with db.connect(): diff --git a/tests/drivers/test_mysql.py b/tests/drivers/test_mysql.py index 2e4153f5..a86c8477 100644 --- a/tests/drivers/test_mysql.py +++ 
b/tests/drivers/test_mysql.py @@ -33,7 +33,7 @@ def test_path_parsing(case): db = drivers.get_driver(case, provider='mysql') for attr in ('username', 'password', 'host', 'port', 'database'): - assert getattr(db.url, attr) == TEST_CASES[case].get(attr, None) + assert getattr(db.metastore.url, attr) == TEST_CASES[case].get(attr, None) @pytest.mark.parametrize('case', INVALID_TEST_CASES) diff --git a/tests/drivers/test_raster_drivers.py b/tests/drivers/test_raster_drivers.py index f0b7e1d5..3323c4c0 100644 --- a/tests/drivers/test_raster_drivers.py +++ b/tests/drivers/test_raster_drivers.py @@ -61,18 +61,18 @@ def test_where(driver_path, provider, raster_file): data = db.get_datasets() assert len(data) == 3 - data = db.get_datasets(where=dict(some='some')) + data = db.get_datasets(keys=dict(some='some')) assert len(data) == 2 - data = db.get_datasets(where=dict(some='some', keynames='value')) + data = db.get_datasets(keys=dict(some='some', keynames='value')) assert list(data.keys()) == [('some', 'value')] assert data[('some', 'value')] == str(raster_file) - data = db.get_datasets(where=dict(some='unknown')) + data = db.get_datasets(keys=dict(some='unknown')) assert data == {} with pytest.raises(exceptions.InvalidKeyError) as exc: - db.get_datasets(where=dict(unknown='foo')) + db.get_datasets(keys=dict(unknown='foo')) assert 'unrecognized keys' in str(exc.value) @@ -90,17 +90,17 @@ def test_where_with_multiquery(driver_path, provider, raster_file): data = db.get_datasets() assert len(data) == 3 - data = db.get_datasets(where=dict(some=['some'])) + data = db.get_datasets(keys=dict(some=['some'])) assert len(data) == 2 - data = db.get_datasets(where=dict(keynames=['value', 'other_value'])) + data = db.get_datasets(keys=dict(keynames=['value', 'other_value'])) assert len(data) == 2 - data = db.get_datasets(where=dict(some='some', keynames=['value', 'third_value'])) + data = db.get_datasets(keys=dict(some='some', keynames=['value', 'third_value'])) assert list(data.keys()) 
== [('some', 'value')] assert data[('some', 'value')] == str(raster_file) - data = db.get_datasets(where=dict(some=['unknown'])) + data = db.get_datasets(keys=dict(some=['unknown'])) assert data == {} @@ -124,7 +124,7 @@ def test_pagination(driver_path, provider, raster_file): data = db.get_datasets(limit=2, page=1) assert len(data) == 1 - data = db.get_datasets(where=dict(some='some'), limit=1, page=0) + data = db.get_datasets(keys=dict(some='some'), limit=1, page=0) assert len(data) == 1 @@ -143,7 +143,12 @@ def test_lazy_loading(driver_path, provider, raster_file): data1 = db.get_metadata(['some', 'value']) data2 = db.get_metadata({'some': 'some', 'keynames': 'other_value'}) - assert list(data1.keys()) == list(data2.keys()) + assert set(data1.keys()) == set(data2.keys()) + for k in data1.keys(): + if not np.all(data1[k] == data2[k]): + print(k) + print(data1[k]) + print(data2[k]) assert all(np.all(data1[k] == data2[k]) for k in data1.keys()) @@ -208,7 +213,7 @@ def test_wrong_key_number(driver_path, provider, raster_file): assert 'wrong number of keys' in str(exc.value) with pytest.raises(exceptions.InvalidKeyError) as exc: - db.insert(['a', 'b'], '') + db.insert(['a', 'b'], '', skip_metadata=True) assert 'wrong number of keys' in str(exc.value) with pytest.raises(exceptions.InvalidKeyError) as exc: @@ -331,7 +336,7 @@ def test_raster_cache(driver_path, provider, raster_file, asynchronous): db.insert(['some', 'value'], str(raster_file)) db.insert(['some', 'other_value'], str(raster_file)) - assert len(db._raster_cache) == 0 + assert len(db.rasterstore._raster_cache) == 0 data1 = db.get_raster_tile(['some', 'value'], tile_size=(256, 256), asynchronous=asynchronous) @@ -339,7 +344,7 @@ def test_raster_cache(driver_path, provider, raster_file, asynchronous): data1 = data1.result() time.sleep(1) # allow callback to finish - assert len(db._raster_cache) == 1 + assert len(db.rasterstore._raster_cache) == 1 data2 = db.get_raster_tile(['some', 'value'], tile_size=(256, 
256), asynchronous=asynchronous) @@ -347,7 +352,7 @@ def test_raster_cache(driver_path, provider, raster_file, asynchronous): data2 = data2.result() np.testing.assert_array_equal(data1, data2) - assert len(db._raster_cache) == 1 + assert len(db.rasterstore._raster_cache) == 1 @pytest.mark.parametrize('provider', DRIVERS) @@ -363,7 +368,7 @@ def test_raster_cache_fail(driver_path, provider, raster_file, asynchronous): db.create(keys) db.insert(['some', 'value'], str(raster_file)) - assert len(db._raster_cache) == 0 + assert len(db.rasterstore._raster_cache) == 0 data1 = db.get_raster_tile(['some', 'value'], tile_size=(256, 256), asynchronous=asynchronous) @@ -371,7 +376,7 @@ def test_raster_cache_fail(driver_path, provider, raster_file, asynchronous): data1 = data1.result() time.sleep(1) # allow callback to finish - assert len(db._raster_cache) == 0 + assert len(db.rasterstore._raster_cache) == 0 @pytest.mark.parametrize('provider', DRIVERS) diff --git a/tests/drivers/test_sqlite_remote.py b/tests/drivers/test_sqlite_remote.py index 8f6857ae..e0b2a787 100644 --- a/tests/drivers/test_sqlite_remote.py +++ b/tests/drivers/test_sqlite_remote.py @@ -4,15 +4,13 @@ """ import os +import tempfile import time import uuid -import tempfile from pathlib import Path import pytest -from terracotta.drivers.sqlite_remote import RemoteSQLiteDriver - boto3 = pytest.importorskip('boto3') moto = pytest.importorskip('moto') @@ -92,7 +90,7 @@ def test_invalid_url(): @moto.mock_s3 def test_nonexisting_url(): - from terracotta import get_driver, exceptions + from terracotta import exceptions, get_driver driver = get_driver('s3://foo/db.sqlite') with pytest.raises(exceptions.InvalidDatabaseError): with driver.connect(): @@ -106,33 +104,33 @@ def test_remote_database_cache(s3_db_factory, raster_file, monkeypatch): from terracotta import get_driver - driver: RemoteSQLiteDriver = get_driver(dbpath) - driver._last_updated = -float('inf') + driver = get_driver(dbpath) + 
driver.metastore._last_updated = -float('inf') with driver.connect(): assert driver.key_names == keys assert driver.get_datasets() == {} - modification_date = os.path.getmtime(driver._local_path) + modification_date = os.path.getmtime(driver.metastore._local_path) s3_db_factory(keys, datasets={('some', 'value'): str(raster_file)}) # no change yet assert driver.get_datasets() == {} - assert os.path.getmtime(driver._local_path) == modification_date + assert os.path.getmtime(driver.metastore._local_path) == modification_date # check if remote db is cached correctly - driver._last_updated = time.time() + driver.metastore._last_updated = time.time() with driver.connect(): # db connection is cached; so still no change assert driver.get_datasets() == {} - assert os.path.getmtime(driver._local_path) == modification_date + assert os.path.getmtime(driver.metastore._local_path) == modification_date # invalidate cache - driver._last_updated = -float('inf') + driver.metastore._last_updated = -float('inf') with driver.connect(): # now db is updated on reconnect assert list(driver.get_datasets().keys()) == [('some', 'value')] - assert os.path.getmtime(driver._local_path) != modification_date + assert os.path.getmtime(driver.metastore._local_path) != modification_date @moto.mock_s3 @@ -161,15 +159,15 @@ def test_destructor(s3_db_factory, raster_file, capsys): from terracotta import get_driver - driver: RemoteSQLiteDriver = get_driver(dbpath) - assert os.path.isfile(driver._local_path) + driver = get_driver(dbpath) + assert os.path.isfile(driver.metastore._local_path) - driver.__del__() - assert not os.path.isfile(driver._local_path) + driver.metastore.__del__() + assert not os.path.isfile(driver.metastore._local_path) captured = capsys.readouterr() assert 'Exception ignored' not in captured.err # re-create file to prevent actual destructor from failing - with open(driver._local_path, 'w'): + with open(driver.metastore._local_path, 'w'): pass From 
19b84dc0e4c140ec5b7b04807a2870681d041000 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 31 Jan 2022 10:31:53 +0100 Subject: [PATCH 059/107] Remove leftover debugging prints --- tests/drivers/test_raster_drivers.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/drivers/test_raster_drivers.py b/tests/drivers/test_raster_drivers.py index 3323c4c0..a265ae22 100644 --- a/tests/drivers/test_raster_drivers.py +++ b/tests/drivers/test_raster_drivers.py @@ -144,11 +144,6 @@ def test_lazy_loading(driver_path, provider, raster_file): data1 = db.get_metadata(['some', 'value']) data2 = db.get_metadata({'some': 'some', 'keynames': 'other_value'}) assert set(data1.keys()) == set(data2.keys()) - for k in data1.keys(): - if not np.all(data1[k] == data2[k]): - print(k) - print(data1[k]) - print(data2[k]) assert all(np.all(data1[k] == data2[k]) for k in data1.keys()) From 5161aec5c17465c5671015e824719b769f7a304e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Mon, 31 Jan 2022 13:11:03 +0100 Subject: [PATCH 060/107] move most logic from raster driver to raster.py module --- .github/workflows/test.yml | 3 +- terracotta/drivers/base.py | 7 +- terracotta/drivers/raster_base.py | 398 ++------------------------- terracotta/raster.py | 385 ++++++++++++++++++++++++++ tests/benchmarks.py | 6 +- tests/drivers/test_raster_drivers.py | 198 ------------- tests/test_raster.py | 201 ++++++++++++++ 7 files changed, 610 insertions(+), 588 deletions(-) create mode 100644 terracotta/raster.py create mode 100644 tests/test_raster.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c14bcd9a..910dd92d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -72,8 +72,7 @@ jobs: - name: Initialize mypy run: | - mypy . > /dev/null || true - mypy --install-types --non-interactive + mypy --install-types --non-interactive . 
|| true - name: Run tests run: | diff --git a/terracotta/drivers/base.py b/terracotta/drivers/base.py index 2c6a6294..49fa031d 100644 --- a/terracotta/drivers/base.py +++ b/terracotta/drivers/base.py @@ -10,8 +10,8 @@ from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union) -KeysType = Dict[str, str] -MultiValueKeysType = Dict[str, Union[str, List[str]]] +KeysType = Mapping[str, str] +MultiValueKeysType = Mapping[str, Union[str, List[str]]] Number = TypeVar('Number', int, float) T = TypeVar('T') @@ -201,9 +201,8 @@ def get_raster_tile(self, handle: str, *, """ pass - @classmethod @abstractmethod - def compute_metadata(cls, handle: str, *, + def compute_metadata(self, handle: str, *, extra_metadata: Any = None, use_chunks: bool = None, max_shape: Sequence[int] = None) -> Dict[str, Any]: diff --git a/terracotta/drivers/raster_base.py b/terracotta/drivers/raster_base.py index 003a156f..3fad7181 100644 --- a/terracotta/drivers/raster_base.py +++ b/terracotta/drivers/raster_base.py @@ -3,12 +3,10 @@ Base class for drivers operating on physical raster files. 
""" -from typing import (Any, Callable, Sequence, Dict, Tuple, - TypeVar, Optional, TYPE_CHECKING) +from typing import Any, Callable, Sequence, Dict, TypeVar from concurrent.futures import Future, Executor, ProcessPoolExecutor, ThreadPoolExecutor from concurrent.futures.process import BrokenProcessPool -import contextlib import functools import logging import warnings @@ -16,19 +14,10 @@ import numpy as np -if TYPE_CHECKING: # pragma: no cover - from rasterio.io import DatasetReader # noqa: F401 - -try: - from crick import TDigest, SummaryStats - has_crick = True -except ImportError: # pragma: no cover - has_crick = False - -from terracotta import get_settings, exceptions +from terracotta import get_settings +from terracotta import raster from terracotta.cache import CompressedLFUCache from terracotta.drivers.base import RasterStore -from terracotta.profile import trace Number = TypeVar('Number', int, float) @@ -82,7 +71,7 @@ class RasterDriver(RasterStore): """ _TARGET_CRS: str = 'epsg:3857' _LARGE_RASTER_THRESHOLD: int = 10980 * 10980 - _RIO_ENV_KEYS = dict( + _RIO_ENV_OPTIONS = dict( GDAL_TIFF_INTERNAL_MASK=True, GDAL_DISABLE_READDIR_ON_OPEN='EMPTY_DIR' ) @@ -95,373 +84,17 @@ def __init__(self) -> None: ) self._cache_lock = threading.RLock() - @staticmethod - def _hull_candidate_mask(mask: np.ndarray) -> np.ndarray: - """Returns a reduced boolean mask to speed up convex hull computations. - - Exploits the fact that only the first and last elements of each row and column - can contribute to the convex hull of a dataset. 
- """ - assert mask.ndim == 2 - assert mask.dtype == np.bool_ - - nx, ny = mask.shape - out = np.zeros_like(mask) - - # these operations do not short-circuit, but seems to be the best we can do - # NOTE: argmax returns 0 if a slice is all True or all False - first_row = np.argmax(mask, axis=0) - last_row = nx - 1 - np.argmax(mask[::-1, :], axis=0) - first_col = np.argmax(mask, axis=1) - last_col = ny - 1 - np.argmax(mask[:, ::-1], axis=1) - - all_rows = np.arange(nx) - all_cols = np.arange(ny) - - out[first_row, all_cols] = out[last_row, all_cols] = True - out[all_rows, first_col] = out[all_rows, last_col] = True - - # filter all-False slices - out &= mask - - return out - - @staticmethod - def _compute_image_stats_chunked(dataset: 'DatasetReader') -> Optional[Dict[str, Any]]: - """Compute statistics for the given rasterio dataset by looping over chunks.""" - from rasterio import features, warp, windows - from shapely import geometry - - total_count = valid_data_count = 0 - tdigest = TDigest() - sstats = SummaryStats() - convex_hull = geometry.Polygon() - - block_windows = [w for _, w in dataset.block_windows(1)] - - for w in block_windows: - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', message='invalid value encountered.*') - block_data = dataset.read(1, window=w, masked=True) - - # handle NaNs for float rasters - block_data = np.ma.masked_invalid(block_data, copy=False) - - total_count += int(block_data.size) - valid_data = block_data.compressed() - - if valid_data.size == 0: - continue - - valid_data_count += int(valid_data.size) - - if np.any(block_data.mask): - hull_candidates = RasterDriver._hull_candidate_mask(~block_data.mask) - hull_shapes = [geometry.shape(s) for s, _ in features.shapes( - np.ones(hull_candidates.shape, 'uint8'), - mask=hull_candidates, - transform=windows.transform(w, dataset.transform) - )] - else: - w, s, e, n = windows.bounds(w, dataset.transform) - hull_shapes = [geometry.Polygon([(w, s), (e, s), (e, n), (w, 
n)])] - convex_hull = geometry.MultiPolygon([convex_hull, *hull_shapes]).convex_hull - - tdigest.update(valid_data) - sstats.update(valid_data) - - if sstats.count() == 0: - return None - - convex_hull_wgs = warp.transform_geom( - dataset.crs, 'epsg:4326', geometry.mapping(convex_hull) - ) - - return { - 'valid_percentage': valid_data_count / total_count * 100, - 'range': (sstats.min(), sstats.max()), - 'mean': sstats.mean(), - 'stdev': sstats.std(), - 'percentiles': tdigest.quantile(np.arange(0.01, 1, 0.01)), - 'convex_hull': convex_hull_wgs - } - - @staticmethod - def _compute_image_stats(dataset: 'DatasetReader', - max_shape: Sequence[int] = None) -> Optional[Dict[str, Any]]: - """Compute statistics for the given rasterio dataset by reading it into memory.""" - from rasterio import features, warp, transform - from shapely import geometry - - out_shape = (dataset.height, dataset.width) - - if max_shape is not None: - out_shape = ( - min(max_shape[0], out_shape[0]), - min(max_shape[1], out_shape[1]) - ) - - data_transform = transform.from_bounds( - *dataset.bounds, height=out_shape[0], width=out_shape[1] - ) - raster_data = dataset.read(1, out_shape=out_shape, masked=True) - - if dataset.nodata is not None: - # nodata values might slip into output array if out_shape < dataset.shape - raster_data = np.ma.masked_equal(raster_data, dataset.nodata, copy=False) - - # handle NaNs for float rasters - raster_data = np.ma.masked_invalid(raster_data, copy=False) - - valid_data = raster_data.compressed() - - if valid_data.size == 0: - return None + from rasterio import Env + self._rio_env = Env(**self._RIO_ENV_OPTIONS) - if np.any(raster_data.mask): - hull_candidates = RasterDriver._hull_candidate_mask(~raster_data.mask) - hull_shapes = (geometry.shape(s) for s, _ in features.shapes( - np.ones(hull_candidates.shape, 'uint8'), - mask=hull_candidates, - transform=data_transform - )) - convex_hull = geometry.MultiPolygon(hull_shapes).convex_hull - else: - # no masked entries -> 
convex hull == dataset bounds - w, s, e, n = dataset.bounds - convex_hull = geometry.Polygon([(w, s), (e, s), (e, n), (w, n)]) - - convex_hull_wgs = warp.transform_geom( - dataset.crs, 'epsg:4326', geometry.mapping(convex_hull) - ) - - return { - 'valid_percentage': valid_data.size / raster_data.size * 100, - 'range': (float(valid_data.min()), float(valid_data.max())), - 'mean': float(valid_data.mean()), - 'stdev': float(valid_data.std()), - 'percentiles': np.percentile(valid_data, np.arange(1, 100)), - 'convex_hull': convex_hull_wgs - } - - @classmethod - @trace('compute_metadata') - def compute_metadata(cls, handle: str, *, # noqa: F821 + def compute_metadata(self, handle: str, *, # noqa: F821 extra_metadata: Any = None, use_chunks: bool = None, max_shape: Sequence[int] = None) -> Dict[str, Any]: - """Read given raster file and compute metadata from it. - - This handles most of the heavy lifting during raster ingestion. The returned metadata can - be passed directly to :meth:`insert`. - - Arguments: - - raster_path: Path to GDAL-readable raster file - extra_metadata: Any additional metadata to attach to the dataset. Will be - JSON-serialized and returned verbatim by :meth:`get_metadata`. - use_chunks: Whether to process the image in chunks (slower, but uses less memory). - If not given, use chunks for large images only. - max_shape: Gives the maximum number of pixels used in each dimension to compute - metadata. Setting this to a relatively small size such as ``(1024, 1024)`` will - result in much faster metadata computation for large images, at the expense of - inaccurate results. 
- - """ - import rasterio - from rasterio import warp - from terracotta.cog import validate - - row_data: Dict[str, Any] = {} - extra_metadata = extra_metadata or {} - - if max_shape is not None and len(max_shape) != 2: - raise ValueError('max_shape argument must contain 2 values') - - if use_chunks and max_shape is not None: - raise ValueError('Cannot use both use_chunks and max_shape arguments') - - with rasterio.Env(**cls._RIO_ENV_KEYS): - if not validate(handle): - warnings.warn( - f'Raster file {handle} is not a valid cloud-optimized GeoTIFF. ' - 'Any interaction with it will be significantly slower. Consider optimizing ' - 'it through `terracotta optimize-rasters` before ingestion.', - exceptions.PerformanceWarning, stacklevel=3 - ) - - with rasterio.open(handle) as src: - if src.nodata is None and not cls._has_alpha_band(src): - warnings.warn( - f'Raster file {handle} does not have a valid nodata value, ' - 'and does not contain an alpha band. No data will be masked.' - ) - - bounds = warp.transform_bounds( - src.crs, 'epsg:4326', *src.bounds, densify_pts=21 - ) - - if use_chunks is None and max_shape is None: - use_chunks = src.width * src.height > RasterDriver._LARGE_RASTER_THRESHOLD - - if use_chunks: - logger.debug( - f'Computing metadata for file {handle} using more than ' - f'{RasterDriver._LARGE_RASTER_THRESHOLD // 10**6}M pixels, iterating ' - 'over chunks' - ) - - if use_chunks and not has_crick: - warnings.warn( - 'Processing a large raster file, but crick failed to import. 
' - 'Reading whole file into memory instead.', exceptions.PerformanceWarning - ) - use_chunks = False - - if use_chunks: - raster_stats = RasterDriver._compute_image_stats_chunked(src) - else: - raster_stats = RasterDriver._compute_image_stats(src, max_shape) - - if raster_stats is None: - raise ValueError(f'Raster file {handle} does not contain any valid data') - - row_data.update(raster_stats) - - row_data['bounds'] = bounds - row_data['metadata'] = extra_metadata - - return row_data - - @staticmethod - def _get_resampling_enum(method: str) -> Any: - from rasterio.enums import Resampling - - if method == 'nearest': - return Resampling.nearest - - if method == 'linear': - return Resampling.bilinear - - if method == 'cubic': - return Resampling.cubic - - if method == 'average': - return Resampling.average - - raise ValueError(f'unknown resampling method {method}') - - @staticmethod - def _has_alpha_band(src: 'DatasetReader') -> bool: - from rasterio.enums import MaskFlags, ColorInterp - return ( - any([MaskFlags.alpha in flags for flags in src.mask_flag_enums]) - or ColorInterp.alpha in src.colorinterp - ) - - @classmethod - @trace('get_raster_tile') - def _get_raster_tile(cls, path: str, *, - reprojection_method: str, - resampling_method: str, - tile_bounds: Tuple[float, float, float, float] = None, - tile_size: Tuple[int, int] = (256, 256), - preserve_values: bool = False) -> np.ma.MaskedArray: - """Load a raster dataset from a file through rasterio. 
- - Heavily inspired by mapbox/rio-tiler - """ - import rasterio - from rasterio import transform, windows, warp - from rasterio.vrt import WarpedVRT - from affine import Affine - - dst_bounds: Tuple[float, float, float, float] - - if preserve_values: - reproject_enum = resampling_enum = cls._get_resampling_enum('nearest') - else: - reproject_enum = cls._get_resampling_enum(reprojection_method) - resampling_enum = cls._get_resampling_enum(resampling_method) - - with contextlib.ExitStack() as es: - es.enter_context(rasterio.Env(**cls._RIO_ENV_KEYS)) - try: - with trace('open_dataset'): - src = es.enter_context(rasterio.open(path)) - except OSError: - raise IOError('error while reading file {}'.format(path)) - - # compute buonds in target CRS - dst_bounds = warp.transform_bounds(src.crs, cls._TARGET_CRS, *src.bounds) - - if tile_bounds is None: - tile_bounds = dst_bounds - - # prevent loads of very sparse data - cover_ratio = ( - (dst_bounds[2] - dst_bounds[0]) / (tile_bounds[2] - tile_bounds[0]) - * (dst_bounds[3] - dst_bounds[1]) / (tile_bounds[3] - tile_bounds[1]) - ) - - if cover_ratio < 0.01: - raise exceptions.TileOutOfBoundsError('dataset covers less than 1% of tile') - - # compute suggested resolution in target CRS - dst_transform, _, _ = warp.calculate_default_transform( - src.crs, cls._TARGET_CRS, src.width, src.height, *src.bounds - ) - dst_res = (abs(dst_transform.a), abs(dst_transform.e)) - - # make sure VRT resolves the entire tile - tile_transform = transform.from_bounds(*tile_bounds, *tile_size) - tile_res = (abs(tile_transform.a), abs(tile_transform.e)) - - if tile_res[0] < dst_res[0] or tile_res[1] < dst_res[1]: - dst_res = tile_res - resampling_enum = cls._get_resampling_enum('nearest') - - # pad tile bounds to prevent interpolation artefacts - num_pad_pixels = 2 - - # compute tile VRT shape and transform - dst_width = max(1, round((tile_bounds[2] - tile_bounds[0]) / dst_res[0])) - dst_height = max(1, round((tile_bounds[3] - tile_bounds[1]) / 
dst_res[1])) - vrt_transform = ( - transform.from_bounds(*tile_bounds, width=dst_width, height=dst_height) - * Affine.translation(-num_pad_pixels, -num_pad_pixels) - ) - vrt_height, vrt_width = dst_height + 2 * num_pad_pixels, dst_width + 2 * num_pad_pixels - - # remove padding in output - out_window = windows.Window( - col_off=num_pad_pixels, row_off=num_pad_pixels, width=dst_width, height=dst_height - ) - - # construct VRT - vrt = es.enter_context( - WarpedVRT( - src, crs=cls._TARGET_CRS, resampling=reproject_enum, - transform=vrt_transform, width=vrt_width, height=vrt_height, - add_alpha=not cls._has_alpha_band(src) - ) - ) - - # read data - with warnings.catch_warnings(), trace('read_from_vrt'): - warnings.filterwarnings('ignore', message='invalid value encountered.*') - tile_data = vrt.read( - 1, resampling=resampling_enum, window=out_window, out_shape=tile_size - ) - - # assemble alpha mask - mask_idx = vrt.count - mask = vrt.read(mask_idx, window=out_window, out_shape=tile_size) == 0 - - if src.nodata is not None: - mask |= tile_data == src.nodata - - return np.ma.masked_array(tile_data, mask=mask) + return raster.compute_metadata(handle, extra_metadata=extra_metadata, + use_chunks=use_chunks, max_shape=max_shape, + large_raster_threshold=self._LARGE_RASTER_THRESHOLD, + rio_env=self._rio_env) # return type has to be Any until mypy supports conditional return types def get_raster_tile(self, @@ -471,8 +104,7 @@ def get_raster_tile(self, preserve_values: bool = False, asynchronous: bool = False) -> Any: # This wrapper handles cache interaction and asynchronous tile retrieval. - # The real work is done in _get_raster_tile. - + # The real work is done in terracotta.raster.get_raster_tile. 
future: Future[np.ma.MaskedArray] result: np.ma.MaskedArray @@ -488,7 +120,9 @@ def get_raster_tile(self, tile_size=tuple(tile_size), preserve_values=preserve_values, reprojection_method=settings.REPROJECTION_METHOD, - resampling_method=settings.RESAMPLING_METHOD + resampling_method=settings.RESAMPLING_METHOD, + target_crs=self._TARGET_CRS, + rio_env=self._rio_env, ) cache_key = hash(tuple(kwargs.items())) @@ -507,7 +141,7 @@ def get_raster_tile(self, else: return result - retrieve_tile = functools.partial(self._get_raster_tile, **kwargs) + retrieve_tile = functools.partial(raster.get_raster_tile, **kwargs) future = submit_to_executor(retrieve_tile) diff --git a/terracotta/raster.py b/terracotta/raster.py new file mode 100644 index 00000000..420b511e --- /dev/null +++ b/terracotta/raster.py @@ -0,0 +1,385 @@ +"""raster.py + +Extract information from raster files through rasterio. +""" + +from typing import Optional, Any, Dict, Tuple, Sequence, TYPE_CHECKING +import contextlib +import warnings +import logging + +import numpy as np + +if TYPE_CHECKING: # pragma: no cover + from rasterio.io import DatasetReader # noqa: F401 + from rasterio import Env + +try: + from crick import TDigest, SummaryStats + has_crick = True +except ImportError: # pragma: no cover + has_crick = False + +from terracotta import exceptions +from terracotta.profile import trace + +logger = logging.getLogger(__name__) + + +def convex_hull_candidate_mask(mask: np.ndarray) -> np.ndarray: + """Returns a reduced boolean mask to speed up convex hull computations. + + Exploits the fact that only the first and last elements of each row and column + can contribute to the convex hull of a dataset. 
+ """ + assert mask.ndim == 2 + assert mask.dtype == np.bool_ + + nx, ny = mask.shape + out = np.zeros_like(mask) + + # these operations do not short-circuit, but seems to be the best we can do + # NOTE: argmax returns 0 if a slice is all True or all False + first_row = np.argmax(mask, axis=0) + last_row = nx - 1 - np.argmax(mask[::-1, :], axis=0) + first_col = np.argmax(mask, axis=1) + last_col = ny - 1 - np.argmax(mask[:, ::-1], axis=1) + + all_rows = np.arange(nx) + all_cols = np.arange(ny) + + out[first_row, all_cols] = out[last_row, all_cols] = True + out[all_rows, first_col] = out[all_rows, last_col] = True + + # filter all-False slices + out &= mask + + return out + + +def compute_image_stats_chunked(dataset: 'DatasetReader') -> Optional[Dict[str, Any]]: + """Compute statistics for the given rasterio dataset by looping over chunks.""" + from rasterio import features, warp, windows + from shapely import geometry + + total_count = valid_data_count = 0 + tdigest = TDigest() + sstats = SummaryStats() + convex_hull = geometry.Polygon() + + block_windows = [w for _, w in dataset.block_windows(1)] + + for w in block_windows: + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message='invalid value encountered.*') + block_data = dataset.read(1, window=w, masked=True) + + # handle NaNs for float rasters + block_data = np.ma.masked_invalid(block_data, copy=False) + + total_count += int(block_data.size) + valid_data = block_data.compressed() + + if valid_data.size == 0: + continue + + valid_data_count += int(valid_data.size) + + if np.any(block_data.mask): + hull_candidates = convex_hull_candidate_mask(~block_data.mask) + hull_shapes = [geometry.shape(s) for s, _ in features.shapes( + np.ones(hull_candidates.shape, 'uint8'), + mask=hull_candidates, + transform=windows.transform(w, dataset.transform) + )] + else: + w, s, e, n = windows.bounds(w, dataset.transform) + hull_shapes = [geometry.Polygon([(w, s), (e, s), (e, n), (w, n)])] + convex_hull = 
geometry.MultiPolygon([convex_hull, *hull_shapes]).convex_hull + + tdigest.update(valid_data) + sstats.update(valid_data) + + if sstats.count() == 0: + return None + + convex_hull_wgs = warp.transform_geom( + dataset.crs, 'epsg:4326', geometry.mapping(convex_hull) + ) + + return { + 'valid_percentage': valid_data_count / total_count * 100, + 'range': (sstats.min(), sstats.max()), + 'mean': sstats.mean(), + 'stdev': sstats.std(), + 'percentiles': tdigest.quantile(np.arange(0.01, 1, 0.01)), + 'convex_hull': convex_hull_wgs + } + + +def compute_image_stats(dataset: 'DatasetReader', + max_shape: Sequence[int] = None) -> Optional[Dict[str, Any]]: + """Compute statistics for the given rasterio dataset by reading it into memory.""" + from rasterio import features, warp, transform + from shapely import geometry + + out_shape = (dataset.height, dataset.width) + + if max_shape is not None: + out_shape = ( + min(max_shape[0], out_shape[0]), + min(max_shape[1], out_shape[1]) + ) + + data_transform = transform.from_bounds( + *dataset.bounds, height=out_shape[0], width=out_shape[1] + ) + raster_data = dataset.read(1, out_shape=out_shape, masked=True) + + if dataset.nodata is not None: + # nodata values might slip into output array if out_shape < dataset.shape + raster_data = np.ma.masked_equal(raster_data, dataset.nodata, copy=False) + + # handle NaNs for float rasters + raster_data = np.ma.masked_invalid(raster_data, copy=False) + + valid_data = raster_data.compressed() + + if valid_data.size == 0: + return None + + if np.any(raster_data.mask): + hull_candidates = convex_hull_candidate_mask(~raster_data.mask) + hull_shapes = (geometry.shape(s) for s, _ in features.shapes( + np.ones(hull_candidates.shape, 'uint8'), + mask=hull_candidates, + transform=data_transform + )) + convex_hull = geometry.MultiPolygon(hull_shapes).convex_hull + else: + # no masked entries -> convex hull == dataset bounds + w, s, e, n = dataset.bounds + convex_hull = geometry.Polygon([(w, s), (e, s), (e, 
n), (w, n)]) + + convex_hull_wgs = warp.transform_geom( + dataset.crs, 'epsg:4326', geometry.mapping(convex_hull) + ) + + return { + 'valid_percentage': valid_data.size / raster_data.size * 100, + 'range': (float(valid_data.min()), float(valid_data.max())), + 'mean': float(valid_data.mean()), + 'stdev': float(valid_data.std()), + 'percentiles': np.percentile(valid_data, np.arange(1, 100)), + 'convex_hull': convex_hull_wgs + } + + +@trace('compute_metadata') +def compute_metadata(handle: str, *, + extra_metadata: Any = None, + use_chunks: bool = None, + max_shape: Sequence[int] = None, + large_raster_threshold: int = None, + rio_env: 'Env' = None) -> Dict[str, Any]: + import rasterio + from rasterio import warp + from terracotta.cog import validate + + row_data: Dict[str, Any] = {} + extra_metadata = extra_metadata or {} + + if max_shape is not None and len(max_shape) != 2: + raise ValueError('max_shape argument must contain 2 values') + + if use_chunks and max_shape is not None: + raise ValueError('Cannot use both use_chunks and max_shape arguments') + + if rio_env is None: + rio_env = rasterio.Env() + + with rio_env: + if not validate(handle): + warnings.warn( + f'Raster file {handle} is not a valid cloud-optimized GeoTIFF. ' + 'Any interaction with it will be significantly slower. Consider optimizing ' + 'it through `terracotta optimize-rasters` before ingestion.', + exceptions.PerformanceWarning, stacklevel=3 + ) + + with rasterio.open(handle) as src: + if src.nodata is None and not has_alpha_band(src): + warnings.warn( + f'Raster file {handle} does not have a valid nodata value, ' + 'and does not contain an alpha band. No data will be masked.' 
+ ) + + bounds = warp.transform_bounds( + src.crs, 'epsg:4326', *src.bounds, densify_pts=21 + ) + + if use_chunks is None and max_shape is None and large_raster_threshold is not None: + use_chunks = src.width * src.height > large_raster_threshold + + if use_chunks: + logger.debug( + f'Computing metadata for file {handle} using more than ' + f'{large_raster_threshold // 10**6}M pixels, iterating ' + 'over chunks' + ) + + if use_chunks and not has_crick: + warnings.warn( + 'Processing a large raster file, but crick failed to import. ' + 'Reading whole file into memory instead.', exceptions.PerformanceWarning + ) + use_chunks = False + + if use_chunks: + raster_stats = compute_image_stats_chunked(src) + else: + raster_stats = compute_image_stats(src, max_shape) + + if raster_stats is None: + raise ValueError(f'Raster file {handle} does not contain any valid data') + + row_data.update(raster_stats) + + row_data['bounds'] = bounds + row_data['metadata'] = extra_metadata + return row_data + + +def get_resampling_enum(method: str) -> Any: + from rasterio.enums import Resampling + + if method == 'nearest': + return Resampling.nearest + + if method == 'linear': + return Resampling.bilinear + + if method == 'cubic': + return Resampling.cubic + + if method == 'average': + return Resampling.average + + raise ValueError(f'unknown resampling method {method}') + + +def has_alpha_band(src: 'DatasetReader') -> bool: + from rasterio.enums import MaskFlags, ColorInterp + return ( + any([MaskFlags.alpha in flags for flags in src.mask_flag_enums]) + or ColorInterp.alpha in src.colorinterp + ) + + +@trace("get_raster_tile") +def get_raster_tile(path: str, *, + reprojection_method: str, + resampling_method: str, + tile_bounds: Tuple[float, float, float, float] = None, + tile_size: Tuple[int, int] = (256, 256), + preserve_values: bool = False, + target_crs: str = 'epsg:3857', + rio_env: 'Env' = None) -> np.ma.MaskedArray: + """Load a raster dataset from a file through rasterio. 
+ + Heavily inspired by mapbox/rio-tiler + """ + import rasterio + from rasterio import transform, windows, warp + from rasterio.vrt import WarpedVRT + from affine import Affine + + dst_bounds: Tuple[float, float, float, float] + + if rio_env is None: + rio_env = rasterio.Env() + + if preserve_values: + reproject_enum = resampling_enum = get_resampling_enum('nearest') + else: + reproject_enum = get_resampling_enum(reprojection_method) + resampling_enum = get_resampling_enum(resampling_method) + + with contextlib.ExitStack() as es: + es.enter_context(rio_env) + try: + with trace('open_dataset'): + src = es.enter_context(rasterio.open(path)) + except OSError: + raise IOError('error while reading file {}'.format(path)) + + # compute buonds in target CRS + dst_bounds = warp.transform_bounds(src.crs, target_crs, *src.bounds) + + if tile_bounds is None: + tile_bounds = dst_bounds + + # prevent loads of very sparse data + cover_ratio = ( + (dst_bounds[2] - dst_bounds[0]) / (tile_bounds[2] - tile_bounds[0]) + * (dst_bounds[3] - dst_bounds[1]) / (tile_bounds[3] - tile_bounds[1]) + ) + + if cover_ratio < 0.01: + raise exceptions.TileOutOfBoundsError('dataset covers less than 1% of tile') + + # compute suggested resolution in target CRS + dst_transform, _, _ = warp.calculate_default_transform( + src.crs, target_crs, src.width, src.height, *src.bounds + ) + dst_res = (abs(dst_transform.a), abs(dst_transform.e)) + + # make sure VRT resolves the entire tile + tile_transform = transform.from_bounds(*tile_bounds, *tile_size) + tile_res = (abs(tile_transform.a), abs(tile_transform.e)) + + if tile_res[0] < dst_res[0] or tile_res[1] < dst_res[1]: + dst_res = tile_res + resampling_enum = get_resampling_enum('nearest') + + # pad tile bounds to prevent interpolation artefacts + num_pad_pixels = 2 + + # compute tile VRT shape and transform + dst_width = max(1, round((tile_bounds[2] - tile_bounds[0]) / dst_res[0])) + dst_height = max(1, round((tile_bounds[3] - tile_bounds[1]) / 
dst_res[1])) + vrt_transform = ( + transform.from_bounds(*tile_bounds, width=dst_width, height=dst_height) + * Affine.translation(-num_pad_pixels, -num_pad_pixels) + ) + vrt_height, vrt_width = dst_height + 2 * num_pad_pixels, dst_width + 2 * num_pad_pixels + + # remove padding in output + out_window = windows.Window( + col_off=num_pad_pixels, row_off=num_pad_pixels, width=dst_width, height=dst_height + ) + + # construct VRT + vrt = es.enter_context( + WarpedVRT( + src, crs=target_crs, resampling=reproject_enum, + transform=vrt_transform, width=vrt_width, height=vrt_height, + add_alpha=not has_alpha_band(src) + ) + ) + + # read data + with warnings.catch_warnings(), trace('read_from_vrt'): + warnings.filterwarnings('ignore', message='invalid value encountered.*') + tile_data = vrt.read( + 1, resampling=resampling_enum, window=out_window, out_shape=tile_size + ) + + # assemble alpha mask + mask_idx = vrt.count + mask = vrt.read(mask_idx, window=out_window, out_shape=tile_size) == 0 + + if src.nodata is not None: + mask |= tile_data == src.nodata + + return np.ma.masked_array(tile_data, mask=mask) diff --git a/tests/benchmarks.py b/tests/benchmarks.py index c2f64699..88694400 100644 --- a/tests/benchmarks.py +++ b/tests/benchmarks.py @@ -136,12 +136,14 @@ def test_bench_singleband_out_of_bounds(benchmark, benchmark_database): @pytest.mark.parametrize('raster_type', ['nodata', 'masked']) def test_bench_compute_metadata(benchmark, big_raster_file_nodata, big_raster_file_mask, chunks, raster_type): - from terracotta.drivers.raster_base import RasterDriver + from terracotta import raster + if raster_type == 'nodata': raster_file = big_raster_file_nodata elif raster_type == 'masked': raster_file = big_raster_file_mask - benchmark(RasterDriver.compute_metadata, str(raster_file), use_chunks=chunks) + + benchmark(raster.compute_metadata, str(raster_file), use_chunks=chunks) @pytest.mark.parametrize('in_memory', [False, True]) diff --git a/tests/drivers/test_raster_drivers.py 
b/tests/drivers/test_raster_drivers.py index a265ae22..89274b23 100644 --- a/tests/drivers/test_raster_drivers.py +++ b/tests/drivers/test_raster_drivers.py @@ -3,9 +3,6 @@ import platform import time -import rasterio -import rasterio.features -from shapely.geometry import shape, MultiPolygon import numpy as np DRIVERS = ['sqlite', 'mysql'] @@ -475,201 +472,6 @@ def test_nodata_consistency(driver_path, provider, big_raster_file_mask, big_ras np.testing.assert_array_equal(data_mask.mask, data_nodata.mask) -def geometry_mismatch(shape1, shape2): - """Compute relative mismatch of two shapes""" - return shape1.symmetric_difference(shape2).area / shape1.union(shape2).area - - -def convex_hull_exact(src): - kwargs = dict(bidx=1, band=False, as_mask=True, geographic=True) - - data = src.read() - if np.any(np.isnan(data)) and src.nodata is not None: - # hack: replace NaNs with nodata to make sure they are excluded - with rasterio.MemoryFile() as memfile, memfile.open(**src.profile) as tmpsrc: - data[np.isnan(data)] = src.nodata - tmpsrc.write(data) - dataset_shape = list(rasterio.features.dataset_features(tmpsrc, **kwargs)) - else: - dataset_shape = list(rasterio.features.dataset_features(src, **kwargs)) - - convex_hull = MultiPolygon([shape(s['geometry']) for s in dataset_shape]).convex_hull - return convex_hull - - -@pytest.mark.parametrize('use_chunks', [True, False]) -@pytest.mark.parametrize('nodata_type', ['nodata', 'masked', 'none', 'nan']) -def test_compute_metadata(big_raster_file_nodata, big_raster_file_nomask, - big_raster_file_mask, raster_file_float, nodata_type, use_chunks): - from terracotta.drivers.raster_base import RasterDriver - - if nodata_type == 'nodata': - raster_file = big_raster_file_nodata - elif nodata_type == 'masked': - raster_file = big_raster_file_mask - elif nodata_type == 'none': - raster_file = big_raster_file_nomask - elif nodata_type == 'nan': - raster_file = raster_file_float - - if use_chunks: - pytest.importorskip('crick') - - with 
rasterio.open(str(raster_file)) as src: - data = src.read(1, masked=True) - valid_data = np.ma.masked_invalid(data).compressed() - convex_hull = convex_hull_exact(src) - - # compare - if nodata_type == 'none': - with pytest.warns(UserWarning) as record: - mtd = RasterDriver.compute_metadata(str(raster_file), use_chunks=use_chunks) - assert 'does not have a valid nodata value' in str(record[0].message) - else: - mtd = RasterDriver.compute_metadata(str(raster_file), use_chunks=use_chunks) - - np.testing.assert_allclose(mtd['valid_percentage'], 100 * valid_data.size / data.size) - np.testing.assert_allclose(mtd['range'], (valid_data.min(), valid_data.max())) - np.testing.assert_allclose(mtd['mean'], valid_data.mean()) - np.testing.assert_allclose(mtd['stdev'], valid_data.std()) - - # allow some error margin since we only compute approximate quantiles - np.testing.assert_allclose( - mtd['percentiles'], - np.percentile(valid_data, np.arange(1, 100)), - rtol=2e-2, atol=valid_data.max() / 100 - ) - - assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 1e-6 - - -@pytest.mark.parametrize('nodata_type', ['nodata', 'masked', 'none', 'nan']) -def test_compute_metadata_approximate(nodata_type, big_raster_file_nodata, big_raster_file_mask, - big_raster_file_nomask, raster_file_float): - from terracotta.drivers.raster_base import RasterDriver - - if nodata_type == 'nodata': - raster_file = big_raster_file_nodata - elif nodata_type == 'masked': - raster_file = big_raster_file_mask - elif nodata_type == 'none': - raster_file = big_raster_file_nomask - elif nodata_type == 'nan': - raster_file = raster_file_float - - with rasterio.open(str(raster_file)) as src: - data = src.read(1, masked=True) - valid_data = np.ma.masked_invalid(data).compressed() - convex_hull = convex_hull_exact(src) - - # compare - if nodata_type == 'none': - with pytest.warns(UserWarning) as record: - mtd = RasterDriver.compute_metadata(str(raster_file), max_shape=(512, 512)) - assert 'does not 
have a valid nodata value' in str(record[0].message) - else: - mtd = RasterDriver.compute_metadata(str(raster_file), max_shape=(512, 512)) - - np.testing.assert_allclose(mtd['valid_percentage'], 100 * valid_data.size / data.size, atol=1) - np.testing.assert_allclose( - mtd['range'], (valid_data.min(), valid_data.max()), atol=valid_data.max() / 100 - ) - np.testing.assert_allclose(mtd['mean'], valid_data.mean(), rtol=0.02) - np.testing.assert_allclose(mtd['stdev'], valid_data.std(), rtol=0.02) - - np.testing.assert_allclose( - mtd['percentiles'], - np.percentile(valid_data, np.arange(1, 100)), - atol=valid_data.max() / 100, rtol=0.02 - ) - - assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 0.05 - - -def test_compute_metadata_invalid_options(big_raster_file_nodata): - from terracotta.drivers.raster_base import RasterDriver - - with pytest.raises(ValueError): - RasterDriver.compute_metadata( - str(big_raster_file_nodata), max_shape=(256, 256), use_chunks=True - ) - - with pytest.raises(ValueError): - RasterDriver.compute_metadata(str(big_raster_file_nodata), max_shape=(256, 256, 1)) - - -@pytest.mark.parametrize('use_chunks', [True, False]) -def test_compute_metadata_invalid_raster(invalid_raster_file, use_chunks): - from terracotta.drivers.raster_base import RasterDriver - - if use_chunks: - pytest.importorskip('crick') - - with pytest.raises(ValueError): - RasterDriver.compute_metadata(str(invalid_raster_file), use_chunks=use_chunks) - - -def test_compute_metadata_nocrick(big_raster_file_nodata, monkeypatch): - with rasterio.open(str(big_raster_file_nodata)) as src: - data = src.read(1, masked=True) - valid_data = np.ma.masked_invalid(data).compressed() - convex_hull = convex_hull_exact(src) - - from terracotta import exceptions - import terracotta.drivers.raster_base - - with monkeypatch.context() as m: - m.setattr(terracotta.drivers.raster_base, 'has_crick', False) - - with pytest.warns(exceptions.PerformanceWarning): - mtd = 
terracotta.drivers.raster_base.RasterDriver.compute_metadata( - str(big_raster_file_nodata), use_chunks=True - ) - - # compare - np.testing.assert_allclose(mtd['valid_percentage'], 100 * valid_data.size / data.size) - np.testing.assert_allclose(mtd['range'], (valid_data.min(), valid_data.max())) - np.testing.assert_allclose(mtd['mean'], valid_data.mean()) - np.testing.assert_allclose(mtd['stdev'], valid_data.std()) - - # allow error of 1%, since we only compute approximate quantiles - np.testing.assert_allclose( - mtd['percentiles'], - np.percentile(valid_data, np.arange(1, 100)), - rtol=2e-2 - ) - - assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 1e-6 - - -def test_compute_metadata_unoptimized(unoptimized_raster_file): - from terracotta import exceptions - from terracotta.drivers.raster_base import RasterDriver - - with rasterio.open(str(unoptimized_raster_file)) as src: - data = src.read(1, masked=True) - valid_data = np.ma.masked_invalid(data).compressed() - convex_hull = convex_hull_exact(src) - - # compare - with pytest.warns(exceptions.PerformanceWarning): - mtd = RasterDriver.compute_metadata(str(unoptimized_raster_file), use_chunks=False) - - np.testing.assert_allclose(mtd['valid_percentage'], 100 * valid_data.size / data.size) - np.testing.assert_allclose(mtd['range'], (valid_data.min(), valid_data.max())) - np.testing.assert_allclose(mtd['mean'], valid_data.mean()) - np.testing.assert_allclose(mtd['stdev'], valid_data.std()) - - # allow some error margin since we only compute approximate quantiles - np.testing.assert_allclose( - mtd['percentiles'], - np.percentile(valid_data, np.arange(1, 100)), - rtol=2e-2 - ) - - assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 1e-6 - - @pytest.mark.parametrize('provider', DRIVERS) def test_broken_process_pool(driver_path, provider, raster_file): import concurrent.futures diff --git a/tests/test_raster.py b/tests/test_raster.py new file mode 100644 index 00000000..69150c91 --- 
/dev/null +++ b/tests/test_raster.py @@ -0,0 +1,201 @@ +import pytest + +import numpy as np +import rasterio +import rasterio.features +from shapely.geometry import shape, MultiPolygon + + +def geometry_mismatch(shape1, shape2): + """Compute relative mismatch of two shapes""" + return shape1.symmetric_difference(shape2).area / shape1.union(shape2).area + + +def convex_hull_exact(src): + kwargs = dict(bidx=1, band=False, as_mask=True, geographic=True) + + data = src.read() + if np.any(np.isnan(data)) and src.nodata is not None: + # hack: replace NaNs with nodata to make sure they are excluded + with rasterio.MemoryFile() as memfile, memfile.open(**src.profile) as tmpsrc: + data[np.isnan(data)] = src.nodata + tmpsrc.write(data) + dataset_shape = list(rasterio.features.dataset_features(tmpsrc, **kwargs)) + else: + dataset_shape = list(rasterio.features.dataset_features(src, **kwargs)) + + convex_hull = MultiPolygon([shape(s['geometry']) for s in dataset_shape]).convex_hull + return convex_hull + + +@pytest.mark.parametrize('use_chunks', [True, False]) +@pytest.mark.parametrize('nodata_type', ['nodata', 'masked', 'none', 'nan']) +def test_compute_metadata(big_raster_file_nodata, big_raster_file_nomask, + big_raster_file_mask, raster_file_float, nodata_type, use_chunks): + from terracotta import raster + + if nodata_type == 'nodata': + raster_file = big_raster_file_nodata + elif nodata_type == 'masked': + raster_file = big_raster_file_mask + elif nodata_type == 'none': + raster_file = big_raster_file_nomask + elif nodata_type == 'nan': + raster_file = raster_file_float + + if use_chunks: + pytest.importorskip('crick') + + with rasterio.open(str(raster_file)) as src: + data = src.read(1, masked=True) + valid_data = np.ma.masked_invalid(data).compressed() + convex_hull = convex_hull_exact(src) + + # compare + if nodata_type == 'none': + with pytest.warns(UserWarning) as record: + mtd = raster.compute_metadata(str(raster_file), use_chunks=use_chunks) + assert 'does not 
have a valid nodata value' in str(record[0].message) + else: + mtd = raster.compute_metadata(str(raster_file), use_chunks=use_chunks) + + np.testing.assert_allclose(mtd['valid_percentage'], 100 * valid_data.size / data.size) + np.testing.assert_allclose(mtd['range'], (valid_data.min(), valid_data.max())) + np.testing.assert_allclose(mtd['mean'], valid_data.mean()) + np.testing.assert_allclose(mtd['stdev'], valid_data.std()) + + # allow some error margin since we only compute approximate quantiles + np.testing.assert_allclose( + mtd['percentiles'], + np.percentile(valid_data, np.arange(1, 100)), + rtol=2e-2, atol=valid_data.max() / 100 + ) + + assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 1e-6 + + +@pytest.mark.parametrize('nodata_type', ['nodata', 'masked', 'none', 'nan']) +def test_compute_metadata_approximate(nodata_type, big_raster_file_nodata, big_raster_file_mask, + big_raster_file_nomask, raster_file_float): + from terracotta import raster + + if nodata_type == 'nodata': + raster_file = big_raster_file_nodata + elif nodata_type == 'masked': + raster_file = big_raster_file_mask + elif nodata_type == 'none': + raster_file = big_raster_file_nomask + elif nodata_type == 'nan': + raster_file = raster_file_float + + with rasterio.open(str(raster_file)) as src: + data = src.read(1, masked=True) + valid_data = np.ma.masked_invalid(data).compressed() + convex_hull = convex_hull_exact(src) + + # compare + if nodata_type == 'none': + with pytest.warns(UserWarning) as record: + mtd = raster.compute_metadata(str(raster_file), max_shape=(512, 512)) + assert 'does not have a valid nodata value' in str(record[0].message) + else: + mtd = raster.compute_metadata(str(raster_file), max_shape=(512, 512)) + + np.testing.assert_allclose(mtd['valid_percentage'], 100 * valid_data.size / data.size, atol=1) + np.testing.assert_allclose( + mtd['range'], (valid_data.min(), valid_data.max()), atol=valid_data.max() / 100 + ) + np.testing.assert_allclose(mtd['mean'], 
valid_data.mean(), rtol=0.02) + np.testing.assert_allclose(mtd['stdev'], valid_data.std(), rtol=0.02) + + np.testing.assert_allclose( + mtd['percentiles'], + np.percentile(valid_data, np.arange(1, 100)), + atol=valid_data.max() / 100, rtol=0.02 + ) + + assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 0.05 + + +def test_compute_metadata_invalid_options(big_raster_file_nodata): + from terracotta import raster + + with pytest.raises(ValueError): + raster.compute_metadata( + str(big_raster_file_nodata), max_shape=(256, 256), use_chunks=True + ) + + with pytest.raises(ValueError): + raster.compute_metadata(str(big_raster_file_nodata), max_shape=(256, 256, 1)) + + +@pytest.mark.parametrize('use_chunks', [True, False]) +def test_compute_metadata_invalid_raster(invalid_raster_file, use_chunks): + from terracotta import raster + + if use_chunks: + pytest.importorskip('crick') + + with pytest.raises(ValueError): + raster.compute_metadata(str(invalid_raster_file), use_chunks=use_chunks) + + +def test_compute_metadata_nocrick(big_raster_file_nodata, monkeypatch): + with rasterio.open(str(big_raster_file_nodata)) as src: + data = src.read(1, masked=True) + valid_data = np.ma.masked_invalid(data).compressed() + convex_hull = convex_hull_exact(src) + + from terracotta import exceptions + import terracotta.drivers.raster_base + + with monkeypatch.context() as m: + m.setattr(terracotta.raster, 'has_crick', False) + + with pytest.warns(exceptions.PerformanceWarning): + mtd = terracotta.drivers.raster_base.raster.compute_metadata( + str(big_raster_file_nodata), use_chunks=True + ) + + # compare + np.testing.assert_allclose(mtd['valid_percentage'], 100 * valid_data.size / data.size) + np.testing.assert_allclose(mtd['range'], (valid_data.min(), valid_data.max())) + np.testing.assert_allclose(mtd['mean'], valid_data.mean()) + np.testing.assert_allclose(mtd['stdev'], valid_data.std()) + + # allow error of 1%, since we only compute approximate quantiles + 
np.testing.assert_allclose( + mtd['percentiles'], + np.percentile(valid_data, np.arange(1, 100)), + rtol=2e-2 + ) + + assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 1e-6 + + +def test_compute_metadata_unoptimized(unoptimized_raster_file): + from terracotta import exceptions + from terracotta import raster + + with rasterio.open(str(unoptimized_raster_file)) as src: + data = src.read(1, masked=True) + valid_data = np.ma.masked_invalid(data).compressed() + convex_hull = convex_hull_exact(src) + + # compare + with pytest.warns(exceptions.PerformanceWarning): + mtd = raster.compute_metadata(str(unoptimized_raster_file), use_chunks=False) + + np.testing.assert_allclose(mtd['valid_percentage'], 100 * valid_data.size / data.size) + np.testing.assert_allclose(mtd['range'], (valid_data.min(), valid_data.max())) + np.testing.assert_allclose(mtd['mean'], valid_data.mean()) + np.testing.assert_allclose(mtd['stdev'], valid_data.std()) + + # allow some error margin since we only compute approximate quantiles + np.testing.assert_allclose( + mtd['percentiles'], + np.percentile(valid_data, np.arange(1, 100)), + rtol=2e-2 + ) + + assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 1e-6 From 9477bf79aa6a16e9862be367a88649b78330f012 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Mon, 31 Jan 2022 13:27:10 +0100 Subject: [PATCH 061/107] go straight to :walrus: jail --- terracotta/drivers/driver.py | 3 ++- tests/benchmarks.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/terracotta/drivers/driver.py b/terracotta/drivers/driver.py index 10041c9d..061a5eaa 100644 --- a/terracotta/drivers/driver.py +++ b/terracotta/drivers/driver.py @@ -48,7 +48,8 @@ def inner( 'Encountered unknown key type, expected Mapping or Sequence' ) - if unknown_keys := set(keys) - set(self.key_names): + unknown_keys = set(keys) - set(self.key_names) + if unknown_keys: raise exceptions.InvalidKeyError( f'Encountered unrecognized keys 
{unknown_keys} (available keys: {self.key_names})' ) diff --git a/tests/benchmarks.py b/tests/benchmarks.py index 88694400..786b00f6 100644 --- a/tests/benchmarks.py +++ b/tests/benchmarks.py @@ -114,7 +114,7 @@ def test_bench_singleband(benchmark, zoom, resampling, big_raster_file_nodata, b rv = benchmark(client.get, '/singleband/nodata/1/preview.png') assert rv.status_code == 200 - assert not len(get_driver(str(benchmark_database))._raster_cache) + assert not len(get_driver(str(benchmark_database)).rasterstore._raster_cache) def test_bench_singleband_out_of_bounds(benchmark, benchmark_database): From 56256ac19dcef70c84f7d8645c40b5b31b7ce3d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Mon, 31 Jan 2022 13:38:40 +0100 Subject: [PATCH 062/107] ... and to py3.6 jail --- terracotta/drivers/driver.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/terracotta/drivers/driver.py b/terracotta/drivers/driver.py index 061a5eaa..8787542c 100644 --- a/terracotta/drivers/driver.py +++ b/terracotta/drivers/driver.py @@ -1,7 +1,8 @@ +from typing import (Any, Callable, Collection, Dict, Mapping, + Sequence, Tuple, TypeVar, Union, cast) import contextlib import functools -from typing import (Any, Callable, Collection, Dict, Mapping, OrderedDict, - Sequence, Tuple, TypeVar, Union, cast) +from collections import OrderedDict import terracotta from terracotta import exceptions From afd865ae19c91bd8d6ffe2c5ccd3ef44a1c5557a Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 31 Jan 2022 15:15:32 +0100 Subject: [PATCH 063/107] Add test for key standardization --- tests/drivers/test_drivers.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/drivers/test_drivers.py b/tests/drivers/test_drivers.py index b61170e3..88f26faf 100644 --- a/tests/drivers/test_drivers.py +++ b/tests/drivers/test_drivers.py @@ -229,3 +229,34 @@ def test_version_conflict(driver_path, provider, raster_file, monkeypatch): pass 
assert fake_version in str(exc.value) + + +@pytest.mark.parametrize('provider', TESTABLE_DRIVERS) +def test_invalid_key_types(driver_path, provider): + from terracotta import drivers, exceptions + + db = drivers.get_driver(driver_path, provider) + keys = ('some', 'keys') + + db.create(keys) + + db.get_datasets() + db.get_datasets(['a', 'b']) + db.get_datasets({'some': 'a', 'keys': 'b'}) + db.get_datasets(None) + + with pytest.raises(exceptions.InvalidKeyError) as exc: + db.get_datasets(45) + assert 'unknown key type' in str(exc) + + with pytest.raises(exceptions.InvalidKeyError) as exc: + db.delete(['a']) + assert 'wrong number of keys' in str(exc) + + with pytest.raises(exceptions.InvalidKeyError) as exc: + db.delete(None) + assert 'wrong number of keys' in str(exc) + + with pytest.raises(exceptions.InvalidKeyError) as exc: + db.get_datasets({'not-a-key': 'val'}) + assert 'unrecognized keys' in str(exc) From 6e5b95aff41f3907ac4cf124e1bcfc09ec45358b Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 31 Jan 2022 15:38:42 +0100 Subject: [PATCH 064/107] Test raster retrieval with all resampling methods --- tests/drivers/test_raster_drivers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/drivers/test_raster_drivers.py b/tests/drivers/test_raster_drivers.py index 89274b23..f7b2b1c9 100644 --- a/tests/drivers/test_raster_drivers.py +++ b/tests/drivers/test_raster_drivers.py @@ -299,7 +299,11 @@ def test_insertion_invalid_raster(driver_path, provider, invalid_raster_file): @pytest.mark.parametrize('provider', DRIVERS) -def test_raster_retrieval(driver_path, provider, raster_file): +@pytest.mark.parametrize('resampling_method', ['nearest', 'linear', 'cubic', 'average']) +def test_raster_retrieval(driver_path, provider, raster_file, resampling_method): + import terracotta + terracotta.update_settings(RESAMPLING_METHOD=resampling_method) + from terracotta import drivers db = drivers.get_driver(driver_path, provider=provider) keys = 
('some', 'keynames') From 37cb0d0c4bcc5f84d44a6c02b52e22f52e1797a5 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 31 Jan 2022 15:51:12 +0100 Subject: [PATCH 065/107] Add test for raster.get_raster_tile --- tests/test_raster.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_raster.py b/tests/test_raster.py index 69150c91..0a92ddce 100644 --- a/tests/test_raster.py +++ b/tests/test_raster.py @@ -199,3 +199,18 @@ def test_compute_metadata_unoptimized(unoptimized_raster_file): ) assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 1e-6 + + +@pytest.mark.parametrize('preserve_values', [True, False]) +@pytest.mark.parametrize('resampling_method', ['nearest', 'linear', 'cubic', 'average']) +def test_get_raster_tile(raster_file, preserve_values, resampling_method): + from terracotta import raster + + data = raster.get_raster_tile( + str(raster_file), + reprojection_method=resampling_method, + resampling_method=resampling_method, + preserve_values=preserve_values, + tile_size=(256, 256) + ) + assert data.shape == (256, 256) From 76185e3e192910a2707e7bc387793f4d0087c4fc Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 31 Jan 2022 16:02:23 +0100 Subject: [PATCH 066/107] Test unknown resampling method --- tests/test_raster.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_raster.py b/tests/test_raster.py index 0a92ddce..741ff36b 100644 --- a/tests/test_raster.py +++ b/tests/test_raster.py @@ -214,3 +214,11 @@ def test_get_raster_tile(raster_file, preserve_values, resampling_method): tile_size=(256, 256) ) assert data.shape == (256, 256) + + +def test_invalid_resampling_method(): + from terracotta import raster + + with pytest.raises(ValueError) as exc: + raster.get_resampling_enum('not-a-resampling-method') + assert 'unknown resampling method' in str(exc) From 0d1096f047fdda4a57bdc42d423ff36d93a37563 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 31 Jan 2022 16:08:35 +0100 
Subject: [PATCH 067/107] Test raster.get_metadata with large_raster_threshold exceeded --- tests/test_raster.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/test_raster.py b/tests/test_raster.py index 741ff36b..b4fb5883 100644 --- a/tests/test_raster.py +++ b/tests/test_raster.py @@ -28,10 +28,11 @@ def convex_hull_exact(src): return convex_hull -@pytest.mark.parametrize('use_chunks', [True, False]) +@pytest.mark.parametrize('large_raster_threshold', [None, 0]) +@pytest.mark.parametrize('use_chunks', [True, False, None]) @pytest.mark.parametrize('nodata_type', ['nodata', 'masked', 'none', 'nan']) -def test_compute_metadata(big_raster_file_nodata, big_raster_file_nomask, - big_raster_file_mask, raster_file_float, nodata_type, use_chunks): +def test_compute_metadata(big_raster_file_nodata, big_raster_file_nomask, big_raster_file_mask, + raster_file_float, nodata_type, use_chunks, large_raster_threshold): from terracotta import raster if nodata_type == 'nodata': @@ -54,10 +55,18 @@ def test_compute_metadata(big_raster_file_nodata, big_raster_file_nomask, # compare if nodata_type == 'none': with pytest.warns(UserWarning) as record: - mtd = raster.compute_metadata(str(raster_file), use_chunks=use_chunks) + mtd = raster.compute_metadata( + str(raster_file), + use_chunks=use_chunks, + large_raster_threshold=large_raster_threshold + ) assert 'does not have a valid nodata value' in str(record[0].message) else: - mtd = raster.compute_metadata(str(raster_file), use_chunks=use_chunks) + mtd = raster.compute_metadata( + str(raster_file), + use_chunks=use_chunks, + large_raster_threshold=large_raster_threshold + ) np.testing.assert_allclose(mtd['valid_percentage'], 100 * valid_data.size / data.size) np.testing.assert_allclose(mtd['range'], (valid_data.min(), valid_data.max())) From 551ae7c9e4d9cea84eebf45a73f065ee8b076d67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Tue, 1 Feb 2022 12:23:03 +0100 Subject: 
[PATCH 068/107] bump coverage --- terracotta/raster.py | 7 ++++--- tests/test_raster.py | 23 +++++++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/terracotta/raster.py b/terracotta/raster.py index 420b511e..6d887db0 100644 --- a/terracotta/raster.py +++ b/terracotta/raster.py @@ -278,8 +278,8 @@ def has_alpha_band(src: 'DatasetReader') -> bool: @trace("get_raster_tile") def get_raster_tile(path: str, *, - reprojection_method: str, - resampling_method: str, + reprojection_method: str = "nearest", + resampling_method: str = "nearest", tile_bounds: Tuple[float, float, float, float] = None, tile_size: Tuple[int, int] = (256, 256), preserve_values: bool = False, @@ -334,7 +334,8 @@ def get_raster_tile(path: str, *, ) dst_res = (abs(dst_transform.a), abs(dst_transform.e)) - # make sure VRT resolves the entire tile + # in some cases (e.g. at extreme latitudes), the default transform + # suggests very coarse resolutions - in this case, fall back to native tile res tile_transform = transform.from_bounds(*tile_bounds, *tile_size) tile_res = (abs(tile_transform.a), abs(tile_transform.e)) diff --git a/tests/test_raster.py b/tests/test_raster.py index 69150c91..625cf545 100644 --- a/tests/test_raster.py +++ b/tests/test_raster.py @@ -199,3 +199,26 @@ def test_compute_metadata_unoptimized(unoptimized_raster_file): ) assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 1e-6 + + +def test_get_raster_tile_out_of_bounds(raster_file): + from terracotta import exceptions + from terracotta import raster + + bounds = ( + -1e30, + -1e30, + 1e30, + 1e30, + ) + + with pytest.raises(exceptions.TileOutOfBoundsError): + raster.get_raster_tile(str(raster_file), tile_bounds=bounds) + + +def test_get_raster_no_nodata(big_raster_file_nomask): + from terracotta import raster + + tile_size = (256, 256) + out = raster.get_raster_tile(str(big_raster_file_nomask), tile_size=tile_size) + assert out.shape == tile_size From 
8d7ad06010f7e95baec6198a4cc979856ee4869d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Tue, 1 Feb 2022 13:17:55 +0100 Subject: [PATCH 069/107] replace type ignore with assertion --- terracotta/drivers/base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/terracotta/drivers/base.py b/terracotta/drivers/base.py index 49fa031d..464ed441 100644 --- a/terracotta/drivers/base.py +++ b/terracotta/drivers/base.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence, - Tuple, TypeVar, Union) + Tuple, TypeVar, Union, cast) KeysType = Mapping[str, str] MultiValueKeysType = Mapping[str, Union[str, List[str]]] @@ -25,9 +25,10 @@ def requires_connection( @functools.wraps(fun) def inner(self: MetaStore, *args: Any, **kwargs: Any) -> T: + assert fun is not None with self.connect(verify=verify): - # Apparently mypy thinks fun might still be None, hence the ignore: - return fun(self, *args, **kwargs) # type: ignore + return fun(self, *args, **kwargs) + return inner From 33373e132b5f31cd38eae46bb073321737f3e341 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Tue, 1 Feb 2022 14:05:47 +0100 Subject: [PATCH 070/107] :lipstick: --- terracotta/drivers/base.py | 2 +- terracotta/drivers/raster_base.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/terracotta/drivers/base.py b/terracotta/drivers/base.py index 464ed441..c7db2a04 100644 --- a/terracotta/drivers/base.py +++ b/terracotta/drivers/base.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence, - Tuple, TypeVar, Union, cast) + Tuple, TypeVar, Union) KeysType = Mapping[str, str] MultiValueKeysType = Mapping[str, Union[str, List[str]]] diff --git a/terracotta/drivers/raster_base.py b/terracotta/drivers/raster_base.py index 
3fad7181..f85a3bd6 100644 --- a/terracotta/drivers/raster_base.py +++ b/terracotta/drivers/raster_base.py @@ -87,7 +87,7 @@ def __init__(self) -> None: from rasterio import Env self._rio_env = Env(**self._RIO_ENV_OPTIONS) - def compute_metadata(self, handle: str, *, # noqa: F821 + def compute_metadata(self, handle: str, *, extra_metadata: Any = None, use_chunks: bool = None, max_shape: Sequence[int] = None) -> Dict[str, Any]: @@ -103,8 +103,6 @@ def get_raster_tile(self, tile_size: Sequence[int] = None, preserve_values: bool = False, asynchronous: bool = False) -> Any: - # This wrapper handles cache interaction and asynchronous tile retrieval. - # The real work is done in terracotta.raster.get_raster_tile. future: Future[np.ma.MaskedArray] result: np.ma.MaskedArray From 36eaf1dd4bdea037820083d4c8ffc1dc1faa1af9 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 1 Feb 2022 17:00:53 +0100 Subject: [PATCH 071/107] Rename driver files and make key standardization a method --- terracotta/drivers/__init__.py | 4 +- terracotta/drivers/base.py | 211 -------------------------- terracotta/drivers/driver.py | 177 --------------------- terracotta/drivers/raster_base.py | 2 +- terracotta/drivers/relational_base.py | 5 +- terracotta/xyz.py | 2 +- 6 files changed, 7 insertions(+), 394 deletions(-) delete mode 100644 terracotta/drivers/base.py delete mode 100644 terracotta/drivers/driver.py diff --git a/terracotta/drivers/__init__.py b/terracotta/drivers/__init__.py index e50635b5..6bfc09a7 100644 --- a/terracotta/drivers/__init__.py +++ b/terracotta/drivers/__init__.py @@ -8,8 +8,8 @@ import urllib.parse as urlparse from pathlib import Path -from terracotta.drivers.base import MetaStore -from terracotta.drivers.driver import TerracottaDriver +from terracotta.drivers.base_classes import MetaStore +from terracotta.drivers.terracotta_driver import TerracottaDriver from terracotta.drivers.raster_base import RasterDriver URLOrPathType = Union[str, Path] diff --git 
a/terracotta/drivers/base.py b/terracotta/drivers/base.py deleted file mode 100644 index c7db2a04..00000000 --- a/terracotta/drivers/base.py +++ /dev/null @@ -1,211 +0,0 @@ -"""drivers/base.py - -Base class for drivers. -""" - -import contextlib -import functools -from abc import ABC, abstractmethod -from collections import OrderedDict -from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence, - Tuple, TypeVar, Union) - -KeysType = Mapping[str, str] -MultiValueKeysType = Mapping[str, Union[str, List[str]]] -Number = TypeVar('Number', int, float) -T = TypeVar('T') - - -def requires_connection( - fun: Callable[..., T] = None, *, - verify: bool = True -) -> Union[Callable[..., T], functools.partial]: - if fun is None: - return functools.partial(requires_connection, verify=verify) - - @functools.wraps(fun) - def inner(self: MetaStore, *args: Any, **kwargs: Any) -> T: - assert fun is not None - with self.connect(verify=verify): - return fun(self, *args, **kwargs) - - return inner - - -class MetaStore(ABC): - """Abstract base class for all Terracotta data backends. - - Defines a common interface for all drivers. 
- """ - _RESERVED_KEYS = ('limit', 'page') - - @property - @abstractmethod - def db_version(self) -> str: - """Terracotta version used to create the database.""" - pass - - @property - @abstractmethod - def key_names(self) -> Tuple[str, ...]: - """Names of all keys defined by the database.""" - pass - - @abstractmethod - def __init__(self, url_or_path: str) -> None: - self.path = url_or_path - - @classmethod - def _normalize_path(cls, path: str) -> str: - """Convert given path to normalized version (that can be used for caching)""" - return path - - @abstractmethod - def create(self, keys: Sequence[str], *, - key_descriptions: Mapping[str, str] = None) -> None: - # Create a new, empty database (driver dependent) - pass - - @abstractmethod - def connect(self, verify: bool = True) -> contextlib.AbstractContextManager: - """Context manager to connect to a given database and clean up on exit. - - This allows you to pool interactions with the database to prevent possibly - expensive reconnects, or to roll back several interactions if one of them fails. - - Arguments: - - verify: Whether to verify the database (primarily its version) when connecting. - Should be `true` unless absolutely necessary, such as when instantiating the - database during creation of it. - - Note: - - Make sure to call :meth:`create` on a fresh database before using this method. - - Example: - - >>> import terracotta as tc - >>> driver = tc.get_driver('tc.sqlite') - >>> with driver.connect(): - ... for keys, dataset in datasets.items(): - ... # connection will be kept open between insert operations - ... driver.insert(keys, dataset) - - """ - pass - - @abstractmethod - def get_keys(self) -> OrderedDict: - """Get all known keys and their fulltext descriptions. 
- - Returns: - - An :class:`~collections.OrderedDict` in the form - ``{key_name: key_description}`` - - """ - pass - - @abstractmethod - def get_datasets(self, where: MultiValueKeysType = None, - page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], Any]: - # Get all known dataset key combinations matching the given constraints, - # and a handle to retrieve the data (driver dependent) - pass - - @abstractmethod - def get_metadata(self, keys: KeysType) -> Optional[Dict[str, Any]]: - """Return all stored metadata for given keys. - - Arguments: - - keys: Keys of the requested dataset. Can either be given as a sequence of key values, - or as a mapping ``{key_name: key_value}``. - - Returns: - - A :class:`dict` with the values - - - ``range``: global minimum and maximum value in dataset - - ``bounds``: physical bounds covered by dataset in latitude-longitude projection - - ``convex_hull``: GeoJSON shape specifying total data coverage in latitude-longitude - projection - - ``percentiles``: array of pre-computed percentiles from 1% through 99% - - ``mean``: global mean - - ``stdev``: global standard deviation - - ``metadata``: any additional client-relevant metadata - - """ - pass - - @abstractmethod - def insert(self, keys: KeysType, - handle: Any, **kwargs: Any) -> None: - """Register a new dataset. Used to populate metadata database. - - Arguments: - - keys: Keys of the dataset. Can either be given as a sequence of key values, or - as a mapping ``{key_name: key_value}``. - handle: Handle to access dataset (driver dependent). - - """ - pass - - @abstractmethod - def delete(self, keys: KeysType) -> None: - """Remove a dataset from the metadata database. - - Arguments: - - keys: Keys of the dataset. Can either be given as a sequence of key values, or - as a mapping ``{key_name: key_value}``. 
- - """ - pass - - def __repr__(self) -> str: - return f'{self.__class__.__name__}(\'{self.path}\')' - - -class RasterStore(ABC): - - @abstractmethod - # TODO: add accurate signature if mypy ever supports conditional return types - def get_raster_tile(self, handle: str, *, - tile_bounds: Sequence[float] = None, - tile_size: Sequence[int] = (256, 256), - preserve_values: bool = False, - asynchronous: bool = False) -> Any: - """Load a raster tile with given handle and bounds. - - Arguments: - - handle: Handle of the requested dataset. - tile_bounds: Physical bounds of the tile to read, in Web Mercator projection (EPSG3857). - Reads the whole dataset if not given. - tile_size: Shape of the output array to return. Must be two-dimensional. - Defaults to :attr:`~terracotta.config.TerracottaSettings.DEFAULT_TILE_SIZE`. - preserve_values: Whether to preserve exact numerical values (e.g. when reading - categorical data). Sets all interpolation to nearest neighbor. - asynchronous: If given, the tile will be read asynchronously in a separate thread. - This function will return immediately with a :class:`~concurrent.futures.Future` - that can be used to retrieve the result. - - Returns: - - Requested tile as :class:`~numpy.ma.MaskedArray` of shape ``tile_size`` if - ``asynchronous=False``, otherwise a :class:`~concurrent.futures.Future` containing - the result. 
- - """ - pass - - @abstractmethod - def compute_metadata(self, handle: str, *, - extra_metadata: Any = None, - use_chunks: bool = None, - max_shape: Sequence[int] = None) -> Dict[str, Any]: - # Compute metadata for a given input file (driver dependent) - pass diff --git a/terracotta/drivers/driver.py b/terracotta/drivers/driver.py deleted file mode 100644 index 8787542c..00000000 --- a/terracotta/drivers/driver.py +++ /dev/null @@ -1,177 +0,0 @@ -from typing import (Any, Callable, Collection, Dict, Mapping, - Sequence, Tuple, TypeVar, Union, cast) -import contextlib -import functools -from collections import OrderedDict - -import terracotta -from terracotta import exceptions -from terracotta.drivers.base import (KeysType, MetaStore, MultiValueKeysType, - RasterStore, requires_connection) - -ExtendedKeysType = Union[Sequence[str], KeysType] -T = TypeVar('T') - - -def only_element(iterable: Collection[T]) -> T: - if not iterable: - raise exceptions.DatasetNotFoundError('No dataset found') - assert len(iterable) == 1 - return next(iter(iterable)) - - -def standardize_keys( - fun: Callable[..., T] = None, *, - requires_all_keys: bool = False -) -> Union[Callable[..., T], functools.partial]: - if fun is None: - return functools.partial(standardize_keys, requires_all_keys=requires_all_keys) - - @functools.wraps(fun) - def inner( - self: "TerracottaDriver", - keys: ExtendedKeysType = None, - *args: Any, **kwargs: Any - ) -> T: - if requires_all_keys and (keys is None or len(keys) != len(self.key_names)): - raise exceptions.InvalidKeyError( - f'Got wrong number of keys (available keys: {self.key_names})' - ) - - if isinstance(keys, Mapping): - keys = dict(keys.items()) - elif isinstance(keys, Sequence): - keys = dict(zip(self.key_names, keys)) - elif keys is None: - keys = {} - else: - raise exceptions.InvalidKeyError( - 'Encountered unknown key type, expected Mapping or Sequence' - ) - - unknown_keys = set(keys) - set(self.key_names) - if unknown_keys: - raise 
exceptions.InvalidKeyError( - f'Encountered unrecognized keys {unknown_keys} (available keys: {self.key_names})' - ) - - # Apparently mypy thinks fun might still be None, hence the ignore: - return fun(self, keys, *args, **kwargs) # type: ignore - return inner - - -class TerracottaDriver: - - def __init__(self, metastore: MetaStore, rasterstore: RasterStore) -> None: - self.metastore = metastore - self.rasterstore = rasterstore - - settings = terracotta.get_settings() - self.LAZY_LOADING_MAX_SHAPE: Tuple[int, int] = settings.LAZY_LOADING_MAX_SHAPE - - @property - def db_version(self) -> str: - return self.metastore.db_version - - @property - def key_names(self) -> Tuple[str, ...]: - return self.metastore.key_names - - def create(self, keys: Sequence[str], *, - key_descriptions: Mapping[str, str] = None) -> None: - self.metastore.create(keys=keys, key_descriptions=key_descriptions) - - def connect(self, verify: bool = True) -> contextlib.AbstractContextManager: - return self.metastore.connect(verify=verify) - - @requires_connection - def get_keys(self) -> OrderedDict: - return self.metastore.get_keys() - - @requires_connection - @standardize_keys - def get_datasets(self, keys: MultiValueKeysType = None, - page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], Any]: - return self.metastore.get_datasets( - where=keys, - page=page, - limit=limit - ) - - @requires_connection - @standardize_keys(requires_all_keys=True) - def get_metadata(self, keys: ExtendedKeysType) -> Dict[str, Any]: - keys = cast(KeysType, keys) - - metadata = self.metastore.get_metadata(keys) - - if metadata is None: - # metadata is not computed yet, trigger lazy loading - handle = only_element(self.get_datasets(keys).values()) - metadata = self.compute_metadata(handle, max_shape=self.LAZY_LOADING_MAX_SHAPE) - self.insert(keys, handle, metadata=metadata) - - # this is necessary to make the lazy loading tests pass... 
- metadata = self.metastore.get_metadata(keys) - assert metadata is not None - - return metadata - - @requires_connection - @standardize_keys(requires_all_keys=True) - def insert( - self, keys: ExtendedKeysType, - handle: Any, *, - override_path: str = None, - metadata: Mapping[str, Any] = None, - skip_metadata: bool = False, - **kwargs: Any - ) -> None: - keys = cast(KeysType, keys) - - if metadata is None and not skip_metadata: - metadata = self.compute_metadata(handle) - - self.metastore.insert( - keys=keys, - handle=override_path or handle, - metadata=metadata, - **kwargs - ) - - @requires_connection - @standardize_keys(requires_all_keys=True) - def delete(self, keys: ExtendedKeysType) -> None: - keys = cast(KeysType, keys) - - self.metastore.delete(keys) - - # @standardize_keys(requires_all_keys=True) - def get_raster_tile(self, keys: ExtendedKeysType, *, - tile_bounds: Sequence[float] = None, - tile_size: Sequence[int] = (256, 256), - preserve_values: bool = False, - asynchronous: bool = False) -> Any: - handle = only_element(self.get_datasets(keys).values()) - - return self.rasterstore.get_raster_tile( - handle=handle, - tile_bounds=tile_bounds, - tile_size=tile_size, - preserve_values=preserve_values, - asynchronous=asynchronous, - ) - - def compute_metadata(self, handle: str, *, - extra_metadata: Any = None, - use_chunks: bool = None, - max_shape: Sequence[int] = None) -> Dict[str, Any]: - return self.rasterstore.compute_metadata( - handle=handle, - extra_metadata=extra_metadata, - use_chunks=use_chunks, - max_shape=max_shape, - ) - - def __repr__(self) -> str: - return self.metastore.__repr__() diff --git a/terracotta/drivers/raster_base.py b/terracotta/drivers/raster_base.py index f85a3bd6..42ff1f7f 100644 --- a/terracotta/drivers/raster_base.py +++ b/terracotta/drivers/raster_base.py @@ -17,7 +17,7 @@ from terracotta import get_settings from terracotta import raster from terracotta.cache import CompressedLFUCache -from terracotta.drivers.base import 
RasterStore +from terracotta.drivers.base_classes import RasterStore Number = TypeVar('Number', int, float) diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index 3a37a337..164bf41e 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -19,8 +19,9 @@ from sqlalchemy.engine.base import Connection from sqlalchemy.engine.url import URL from terracotta import exceptions -from terracotta.drivers.base import (KeysType, MetaStore, MultiValueKeysType, - requires_connection) +from terracotta.drivers.base_classes import (KeysType, MetaStore, + MultiValueKeysType, + requires_connection) from terracotta.profile import trace _ERROR_ON_CONNECT = ( diff --git a/terracotta/xyz.py b/terracotta/xyz.py index 97cedef4..dcbfe031 100644 --- a/terracotta/xyz.py +++ b/terracotta/xyz.py @@ -8,7 +8,7 @@ import mercantile from terracotta import exceptions -from terracotta.drivers.driver import TerracottaDriver +from terracotta.drivers.terracotta_driver import TerracottaDriver # TODO: add accurate signature if mypy ever supports conditional return types From 046720df73fe3731983b7c6fe5bd350a7623ac4c Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 1 Feb 2022 17:01:21 +0100 Subject: [PATCH 072/107] Remember the new/renamed files! --- terracotta/drivers/base_classes.py | 211 ++++++++++++++++++++++++ terracotta/drivers/terracotta_driver.py | 164 ++++++++++++++++++ 2 files changed, 375 insertions(+) create mode 100644 terracotta/drivers/base_classes.py create mode 100644 terracotta/drivers/terracotta_driver.py diff --git a/terracotta/drivers/base_classes.py b/terracotta/drivers/base_classes.py new file mode 100644 index 00000000..5bd1134a --- /dev/null +++ b/terracotta/drivers/base_classes.py @@ -0,0 +1,211 @@ +"""drivers/base_classes.py + +Base class for drivers. 
+""" + +import contextlib +import functools +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence, + Tuple, TypeVar, Union) + +KeysType = Mapping[str, str] +MultiValueKeysType = Mapping[str, Union[str, List[str]]] +Number = TypeVar('Number', int, float) +T = TypeVar('T') + + +def requires_connection( + fun: Callable[..., T] = None, *, + verify: bool = True +) -> Union[Callable[..., T], functools.partial]: + if fun is None: + return functools.partial(requires_connection, verify=verify) + + @functools.wraps(fun) + def inner(self: MetaStore, *args: Any, **kwargs: Any) -> T: + assert fun is not None + with self.connect(verify=verify): + return fun(self, *args, **kwargs) + + return inner + + +class MetaStore(ABC): + """Abstract base class for all Terracotta data backends. + + Defines a common interface for all drivers. + """ + _RESERVED_KEYS = ('limit', 'page') + + @property + @abstractmethod + def db_version(self) -> str: + """Terracotta version used to create the database.""" + pass + + @property + @abstractmethod + def key_names(self) -> Tuple[str, ...]: + """Names of all keys defined by the database.""" + pass + + @abstractmethod + def __init__(self, url_or_path: str) -> None: + self.path = url_or_path + + @classmethod + def _normalize_path(cls, path: str) -> str: + """Convert given path to normalized version (that can be used for caching)""" + return path + + @abstractmethod + def create(self, keys: Sequence[str], *, + key_descriptions: Mapping[str, str] = None) -> None: + # Create a new, empty database (driver dependent) + pass + + @abstractmethod + def connect(self, verify: bool = True) -> contextlib.AbstractContextManager: + """Context manager to connect to a given database and clean up on exit. + + This allows you to pool interactions with the database to prevent possibly + expensive reconnects, or to roll back several interactions if one of them fails. 
+ + Arguments: + + verify: Whether to verify the database (primarily its version) when connecting. + Should be `true` unless absolutely necessary, such as when instantiating the + database during creation of it. + + Note: + + Make sure to call :meth:`create` on a fresh database before using this method. + + Example: + + >>> import terracotta as tc + >>> driver = tc.get_driver('tc.sqlite') + >>> with driver.connect(): + ... for keys, dataset in datasets.items(): + ... # connection will be kept open between insert operations + ... driver.insert(keys, dataset) + + """ + pass + + @abstractmethod + def get_keys(self) -> OrderedDict: + """Get all known keys and their fulltext descriptions. + + Returns: + + An :class:`~collections.OrderedDict` in the form + ``{key_name: key_description}`` + + """ + pass + + @abstractmethod + def get_datasets(self, where: MultiValueKeysType = None, + page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], Any]: + # Get all known dataset key combinations matching the given constraints, + # and a handle to retrieve the data (driver dependent) + pass + + @abstractmethod + def get_metadata(self, keys: KeysType) -> Optional[Dict[str, Any]]: + """Return all stored metadata for given keys. + + Arguments: + + keys: Keys of the requested dataset. Can either be given as a sequence of key values, + or as a mapping ``{key_name: key_value}``. 
+ + Returns: + + A :class:`dict` with the values + + - ``range``: global minimum and maximum value in dataset + - ``bounds``: physical bounds covered by dataset in latitude-longitude projection + - ``convex_hull``: GeoJSON shape specifying total data coverage in latitude-longitude + projection + - ``percentiles``: array of pre-computed percentiles from 1% through 99% + - ``mean``: global mean + - ``stdev``: global standard deviation + - ``metadata``: any additional client-relevant metadata + + """ + pass + + @abstractmethod + def insert(self, keys: KeysType, + handle: Any, **kwargs: Any) -> None: + """Register a new dataset. Used to populate metadata database. + + Arguments: + + keys: Keys of the dataset. Can either be given as a sequence of key values, or + as a mapping ``{key_name: key_value}``. + handle: Handle to access dataset (driver dependent). + + """ + pass + + @abstractmethod + def delete(self, keys: KeysType) -> None: + """Remove a dataset from the metadata database. + + Arguments: + + keys: Keys of the dataset. Can either be given as a sequence of key values, or + as a mapping ``{key_name: key_value}``. + + """ + pass + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(\'{self.path}\')' + + +class RasterStore(ABC): + + @abstractmethod + # TODO: add accurate signature if mypy ever supports conditional return types + def get_raster_tile(self, handle: str, *, + tile_bounds: Sequence[float] = None, + tile_size: Sequence[int] = (256, 256), + preserve_values: bool = False, + asynchronous: bool = False) -> Any: + """Load a raster tile with given handle and bounds. + + Arguments: + + handle: Handle of the requested dataset. + tile_bounds: Physical bounds of the tile to read, in Web Mercator projection (EPSG3857). + Reads the whole dataset if not given. + tile_size: Shape of the output array to return. Must be two-dimensional. + Defaults to :attr:`~terracotta.config.TerracottaSettings.DEFAULT_TILE_SIZE`. 
+ preserve_values: Whether to preserve exact numerical values (e.g. when reading + categorical data). Sets all interpolation to nearest neighbor. + asynchronous: If given, the tile will be read asynchronously in a separate thread. + This function will return immediately with a :class:`~concurrent.futures.Future` + that can be used to retrieve the result. + + Returns: + + Requested tile as :class:`~numpy.ma.MaskedArray` of shape ``tile_size`` if + ``asynchronous=False``, otherwise a :class:`~concurrent.futures.Future` containing + the result. + + """ + pass + + @abstractmethod + def compute_metadata(self, handle: str, *, + extra_metadata: Any = None, + use_chunks: bool = None, + max_shape: Sequence[int] = None) -> Dict[str, Any]: + # Compute metadata for a given input file (driver dependent) + pass diff --git a/terracotta/drivers/terracotta_driver.py b/terracotta/drivers/terracotta_driver.py new file mode 100644 index 00000000..e51c88db --- /dev/null +++ b/terracotta/drivers/terracotta_driver.py @@ -0,0 +1,164 @@ +"""drivers/terracotta_driver.py + +The driver to interact with. 
+""" + +import contextlib +from collections import OrderedDict +from typing import (Any, Collection, Dict, Mapping, Sequence, Tuple, TypeVar, + Union) + +import terracotta +from terracotta import exceptions +from terracotta.drivers.base_classes import (KeysType, MetaStore, + MultiValueKeysType, RasterStore, + requires_connection) + +ExtendedKeysType = Union[Sequence[str], KeysType] +T = TypeVar('T') + + +def only_element(iterable: Collection[T]) -> T: + if not iterable: + raise exceptions.DatasetNotFoundError('No dataset found') + assert len(iterable) == 1 + return next(iter(iterable)) + + +class TerracottaDriver: + + def __init__(self, metastore: MetaStore, rasterstore: RasterStore) -> None: + self.metastore = metastore + self.rasterstore = rasterstore + + settings = terracotta.get_settings() + self.LAZY_LOADING_MAX_SHAPE: Tuple[int, int] = settings.LAZY_LOADING_MAX_SHAPE + + @property + def db_version(self) -> str: + return self.metastore.db_version + + @property + def key_names(self) -> Tuple[str, ...]: + return self.metastore.key_names + + def create(self, keys: Sequence[str], *, + key_descriptions: Mapping[str, str] = None) -> None: + self.metastore.create(keys=keys, key_descriptions=key_descriptions) + + def connect(self, verify: bool = True) -> contextlib.AbstractContextManager: + return self.metastore.connect(verify=verify) + + @requires_connection + def get_keys(self) -> OrderedDict: + return self.metastore.get_keys() + + @requires_connection + def get_datasets(self, keys: MultiValueKeysType = None, + page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], Any]: + return self.metastore.get_datasets( + where=keys, + page=page, + limit=limit + ) + + @requires_connection + def get_metadata(self, keys: ExtendedKeysType) -> Dict[str, Any]: + keys = self._standardize_keys(keys) + + metadata = self.metastore.get_metadata(keys) + + if metadata is None: + # metadata is not computed yet, trigger lazy loading + handle = 
only_element(self.get_datasets(keys).values()) + metadata = self.compute_metadata(handle, max_shape=self.LAZY_LOADING_MAX_SHAPE) + self.insert(keys, handle, metadata=metadata) + + # this is necessary to make the lazy loading tests pass... + metadata = self.metastore.get_metadata(keys) + assert metadata is not None + + return metadata + + @requires_connection + def insert( + self, keys: ExtendedKeysType, + handle: Any, *, + override_path: str = None, + metadata: Mapping[str, Any] = None, + skip_metadata: bool = False, + **kwargs: Any + ) -> None: + keys = self._standardize_keys(keys) + + if metadata is None and not skip_metadata: + metadata = self.compute_metadata(handle) + + self.metastore.insert( + keys=keys, + handle=override_path or handle, + metadata=metadata, + **kwargs + ) + + @requires_connection + def delete(self, keys: ExtendedKeysType) -> None: + keys = self._standardize_keys(keys) + + self.metastore.delete(keys) + + def get_raster_tile(self, keys: ExtendedKeysType, *, + tile_bounds: Sequence[float] = None, + tile_size: Sequence[int] = (256, 256), + preserve_values: bool = False, + asynchronous: bool = False) -> Any: + handle = only_element(self.get_datasets(keys).values()) + + return self.rasterstore.get_raster_tile( + handle=handle, + tile_bounds=tile_bounds, + tile_size=tile_size, + preserve_values=preserve_values, + asynchronous=asynchronous, + ) + + def compute_metadata(self, handle: str, *, + extra_metadata: Any = None, + use_chunks: bool = None, + max_shape: Sequence[int] = None) -> Dict[str, Any]: + return self.rasterstore.compute_metadata( + handle=handle, + extra_metadata=extra_metadata, + use_chunks=use_chunks, + max_shape=max_shape, + ) + + def _standardize_keys( + self, + keys: ExtendedKeysType, + requires_all_keys: bool = True + ) -> KeysType: + if requires_all_keys and (keys is None or len(keys) != len(self.key_names)): + raise exceptions.InvalidKeyError( + f'Got wrong number of keys (available keys: {self.key_names})' + ) + + if 
isinstance(keys, Mapping): + keys = dict(keys.items()) + elif isinstance(keys, Sequence): + keys = dict(zip(self.key_names, keys)) + else: + raise exceptions.InvalidKeyError( + 'Encountered unknown key type, expected Mapping or Sequence' + ) + + unknown_keys = set(keys) - set(self.key_names) + if unknown_keys: + raise exceptions.InvalidKeyError( + f'Encountered unrecognized keys {unknown_keys} (available keys: {self.key_names})' + ) + + return keys + + def __repr__(self) -> str: + return self.metastore.__repr__() From bd45e00e85f6ebebd6d0ea076eec68b61db50296 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 1 Feb 2022 17:09:29 +0100 Subject: [PATCH 073/107] Use underscores in meta_store and raster_store --- terracotta/drivers/__init__.py | 4 ++-- terracotta/drivers/terracotta_driver.py | 32 ++++++++++++------------- tests/benchmarks.py | 2 +- tests/drivers/test_drivers.py | 8 +++---- tests/drivers/test_mysql.py | 2 +- tests/drivers/test_raster_drivers.py | 10 ++++---- tests/drivers/test_sqlite_remote.py | 22 ++++++++--------- 7 files changed, 40 insertions(+), 40 deletions(-) diff --git a/terracotta/drivers/__init__.py b/terracotta/drivers/__init__.py index 6bfc09a7..b8f55927 100644 --- a/terracotta/drivers/__init__.py +++ b/terracotta/drivers/__init__.py @@ -88,8 +88,8 @@ def get_driver(url_or_path: URLOrPathType, provider: str = None) -> TerracottaDr if cache_key not in _DRIVER_CACHE: driver = TerracottaDriver( - metastore=DriverClass(url_or_path), - rasterstore=RasterDriver() + meta_store=DriverClass(url_or_path), + raster_store=RasterDriver() ) _DRIVER_CACHE[cache_key] = driver diff --git a/terracotta/drivers/terracotta_driver.py b/terracotta/drivers/terracotta_driver.py index e51c88db..6c885b59 100644 --- a/terracotta/drivers/terracotta_driver.py +++ b/terracotta/drivers/terracotta_driver.py @@ -27,36 +27,36 @@ def only_element(iterable: Collection[T]) -> T: class TerracottaDriver: - def __init__(self, metastore: MetaStore, rasterstore: RasterStore) -> 
None: - self.metastore = metastore - self.rasterstore = rasterstore + def __init__(self, meta_store: MetaStore, raster_store: RasterStore) -> None: + self.meta_store = meta_store + self.raster_store = raster_store settings = terracotta.get_settings() self.LAZY_LOADING_MAX_SHAPE: Tuple[int, int] = settings.LAZY_LOADING_MAX_SHAPE @property def db_version(self) -> str: - return self.metastore.db_version + return self.meta_store.db_version @property def key_names(self) -> Tuple[str, ...]: - return self.metastore.key_names + return self.meta_store.key_names def create(self, keys: Sequence[str], *, key_descriptions: Mapping[str, str] = None) -> None: - self.metastore.create(keys=keys, key_descriptions=key_descriptions) + self.meta_store.create(keys=keys, key_descriptions=key_descriptions) def connect(self, verify: bool = True) -> contextlib.AbstractContextManager: - return self.metastore.connect(verify=verify) + return self.meta_store.connect(verify=verify) @requires_connection def get_keys(self) -> OrderedDict: - return self.metastore.get_keys() + return self.meta_store.get_keys() @requires_connection def get_datasets(self, keys: MultiValueKeysType = None, page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], Any]: - return self.metastore.get_datasets( + return self.meta_store.get_datasets( where=keys, page=page, limit=limit @@ -66,7 +66,7 @@ def get_datasets(self, keys: MultiValueKeysType = None, def get_metadata(self, keys: ExtendedKeysType) -> Dict[str, Any]: keys = self._standardize_keys(keys) - metadata = self.metastore.get_metadata(keys) + metadata = self.meta_store.get_metadata(keys) if metadata is None: # metadata is not computed yet, trigger lazy loading @@ -75,7 +75,7 @@ def get_metadata(self, keys: ExtendedKeysType) -> Dict[str, Any]: self.insert(keys, handle, metadata=metadata) # this is necessary to make the lazy loading tests pass... 
- metadata = self.metastore.get_metadata(keys) + metadata = self.meta_store.get_metadata(keys) assert metadata is not None return metadata @@ -94,7 +94,7 @@ def insert( if metadata is None and not skip_metadata: metadata = self.compute_metadata(handle) - self.metastore.insert( + self.meta_store.insert( keys=keys, handle=override_path or handle, metadata=metadata, @@ -105,7 +105,7 @@ def insert( def delete(self, keys: ExtendedKeysType) -> None: keys = self._standardize_keys(keys) - self.metastore.delete(keys) + self.meta_store.delete(keys) def get_raster_tile(self, keys: ExtendedKeysType, *, tile_bounds: Sequence[float] = None, @@ -114,7 +114,7 @@ def get_raster_tile(self, keys: ExtendedKeysType, *, asynchronous: bool = False) -> Any: handle = only_element(self.get_datasets(keys).values()) - return self.rasterstore.get_raster_tile( + return self.raster_store.get_raster_tile( handle=handle, tile_bounds=tile_bounds, tile_size=tile_size, @@ -126,7 +126,7 @@ def compute_metadata(self, handle: str, *, extra_metadata: Any = None, use_chunks: bool = None, max_shape: Sequence[int] = None) -> Dict[str, Any]: - return self.rasterstore.compute_metadata( + return self.raster_store.compute_metadata( handle=handle, extra_metadata=extra_metadata, use_chunks=use_chunks, @@ -161,4 +161,4 @@ def _standardize_keys( return keys def __repr__(self) -> str: - return self.metastore.__repr__() + return self.meta_store.__repr__() diff --git a/tests/benchmarks.py b/tests/benchmarks.py index 786b00f6..adc9ec34 100644 --- a/tests/benchmarks.py +++ b/tests/benchmarks.py @@ -114,7 +114,7 @@ def test_bench_singleband(benchmark, zoom, resampling, big_raster_file_nodata, b rv = benchmark(client.get, '/singleband/nodata/1/preview.png') assert rv.status_code == 200 - assert not len(get_driver(str(benchmark_database)).rasterstore._raster_cache) + assert not len(get_driver(str(benchmark_database)).raster_store._raster_cache) def test_bench_singleband_out_of_bounds(benchmark, benchmark_database): diff 
--git a/tests/drivers/test_drivers.py b/tests/drivers/test_drivers.py index 88f26faf..d7b97302 100644 --- a/tests/drivers/test_drivers.py +++ b/tests/drivers/test_drivers.py @@ -12,7 +12,7 @@ def test_auto_detect(driver_path, provider): from terracotta import drivers db = drivers.get_driver(driver_path) - assert db.metastore.__class__.__name__ == DRIVER_CLASSES[provider] + assert db.meta_store.__class__.__name__ == DRIVER_CLASSES[provider] assert drivers.get_driver(driver_path, provider=provider) is db @@ -177,10 +177,10 @@ def __getattribute__(self, key): with pytest.raises(RuntimeError): with db.connect(): - db.metastore.connection = Evanescence() + db.meta_store.connection = Evanescence() db.get_keys() - assert not db.metastore.connected + assert not db.meta_store.connected with db.connect(): db.get_keys() @@ -222,7 +222,7 @@ def test_version_conflict(driver_path, provider, raster_file, monkeypatch): with monkeypatch.context() as m: fake_version = '0.0.0' m.setattr('terracotta.__version__', fake_version) - db.metastore.db_version_verified = False + db.meta_store.db_version_verified = False with pytest.raises(exceptions.InvalidDatabaseError) as exc: with db.connect(): diff --git a/tests/drivers/test_mysql.py b/tests/drivers/test_mysql.py index a86c8477..e51aec44 100644 --- a/tests/drivers/test_mysql.py +++ b/tests/drivers/test_mysql.py @@ -33,7 +33,7 @@ def test_path_parsing(case): db = drivers.get_driver(case, provider='mysql') for attr in ('username', 'password', 'host', 'port', 'database'): - assert getattr(db.metastore.url, attr) == TEST_CASES[case].get(attr, None) + assert getattr(db.meta_store.url, attr) == TEST_CASES[case].get(attr, None) @pytest.mark.parametrize('case', INVALID_TEST_CASES) diff --git a/tests/drivers/test_raster_drivers.py b/tests/drivers/test_raster_drivers.py index f7b2b1c9..398378df 100644 --- a/tests/drivers/test_raster_drivers.py +++ b/tests/drivers/test_raster_drivers.py @@ -332,7 +332,7 @@ def test_raster_cache(driver_path, 
provider, raster_file, asynchronous): db.insert(['some', 'value'], str(raster_file)) db.insert(['some', 'other_value'], str(raster_file)) - assert len(db.rasterstore._raster_cache) == 0 + assert len(db.raster_store._raster_cache) == 0 data1 = db.get_raster_tile(['some', 'value'], tile_size=(256, 256), asynchronous=asynchronous) @@ -340,7 +340,7 @@ def test_raster_cache(driver_path, provider, raster_file, asynchronous): data1 = data1.result() time.sleep(1) # allow callback to finish - assert len(db.rasterstore._raster_cache) == 1 + assert len(db.raster_store._raster_cache) == 1 data2 = db.get_raster_tile(['some', 'value'], tile_size=(256, 256), asynchronous=asynchronous) @@ -348,7 +348,7 @@ def test_raster_cache(driver_path, provider, raster_file, asynchronous): data2 = data2.result() np.testing.assert_array_equal(data1, data2) - assert len(db.rasterstore._raster_cache) == 1 + assert len(db.raster_store._raster_cache) == 1 @pytest.mark.parametrize('provider', DRIVERS) @@ -364,7 +364,7 @@ def test_raster_cache_fail(driver_path, provider, raster_file, asynchronous): db.create(keys) db.insert(['some', 'value'], str(raster_file)) - assert len(db.rasterstore._raster_cache) == 0 + assert len(db.raster_store._raster_cache) == 0 data1 = db.get_raster_tile(['some', 'value'], tile_size=(256, 256), asynchronous=asynchronous) @@ -372,7 +372,7 @@ def test_raster_cache_fail(driver_path, provider, raster_file, asynchronous): data1 = data1.result() time.sleep(1) # allow callback to finish - assert len(db.rasterstore._raster_cache) == 0 + assert len(db.raster_store._raster_cache) == 0 @pytest.mark.parametrize('provider', DRIVERS) diff --git a/tests/drivers/test_sqlite_remote.py b/tests/drivers/test_sqlite_remote.py index e0b2a787..d4852617 100644 --- a/tests/drivers/test_sqlite_remote.py +++ b/tests/drivers/test_sqlite_remote.py @@ -105,32 +105,32 @@ def test_remote_database_cache(s3_db_factory, raster_file, monkeypatch): from terracotta import get_driver driver = get_driver(dbpath) 
- driver.metastore._last_updated = -float('inf') + driver.meta_store._last_updated = -float('inf') with driver.connect(): assert driver.key_names == keys assert driver.get_datasets() == {} - modification_date = os.path.getmtime(driver.metastore._local_path) + modification_date = os.path.getmtime(driver.meta_store._local_path) s3_db_factory(keys, datasets={('some', 'value'): str(raster_file)}) # no change yet assert driver.get_datasets() == {} - assert os.path.getmtime(driver.metastore._local_path) == modification_date + assert os.path.getmtime(driver.meta_store._local_path) == modification_date # check if remote db is cached correctly - driver.metastore._last_updated = time.time() + driver.meta_store._last_updated = time.time() with driver.connect(): # db connection is cached; so still no change assert driver.get_datasets() == {} - assert os.path.getmtime(driver.metastore._local_path) == modification_date + assert os.path.getmtime(driver.meta_store._local_path) == modification_date # invalidate cache - driver.metastore._last_updated = -float('inf') + driver.meta_store._last_updated = -float('inf') with driver.connect(): # now db is updated on reconnect assert list(driver.get_datasets().keys()) == [('some', 'value')] - assert os.path.getmtime(driver.metastore._local_path) != modification_date + assert os.path.getmtime(driver.meta_store._local_path) != modification_date @moto.mock_s3 @@ -160,14 +160,14 @@ def test_destructor(s3_db_factory, raster_file, capsys): from terracotta import get_driver driver = get_driver(dbpath) - assert os.path.isfile(driver.metastore._local_path) + assert os.path.isfile(driver.meta_store._local_path) - driver.metastore.__del__() - assert not os.path.isfile(driver.metastore._local_path) + driver.meta_store.__del__() + assert not os.path.isfile(driver.meta_store._local_path) captured = capsys.readouterr() assert 'Exception ignored' not in captured.err # re-create file to prevent actual destructor from failing - with 
open(driver.metastore._local_path, 'w'): + with open(driver.meta_store._local_path, 'w'): pass From d5d1b09e6478e5e1590a814318bf06af5b2a7f7b Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 1 Feb 2022 17:09:55 +0100 Subject: [PATCH 074/107] Also standardize the where/keys for get_datasets() --- terracotta/drivers/terracotta_driver.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/terracotta/drivers/terracotta_driver.py b/terracotta/drivers/terracotta_driver.py index 6c885b59..e757d295 100644 --- a/terracotta/drivers/terracotta_driver.py +++ b/terracotta/drivers/terracotta_driver.py @@ -57,7 +57,7 @@ def get_keys(self) -> OrderedDict: def get_datasets(self, keys: MultiValueKeysType = None, page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], Any]: return self.meta_store.get_datasets( - where=keys, + where=self._standardize_keys(keys, requires_all_keys=False), page=page, limit=limit ) @@ -135,7 +135,7 @@ def compute_metadata(self, handle: str, *, def _standardize_keys( self, - keys: ExtendedKeysType, + keys: Union[ExtendedKeysType, MultiValueKeysType], requires_all_keys: bool = True ) -> KeysType: if requires_all_keys and (keys is None or len(keys) != len(self.key_names)): @@ -147,6 +147,8 @@ def _standardize_keys( keys = dict(keys.items()) elif isinstance(keys, Sequence): keys = dict(zip(self.key_names, keys)) + elif keys is None: + keys = {} else: raise exceptions.InvalidKeyError( 'Encountered unknown key type, expected Mapping or Sequence' From 3fdd90a7dae3c3831e44e574c5a717121686fb3b Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 1 Feb 2022 17:10:29 +0100 Subject: [PATCH 075/107] Rename to squeeze --- terracotta/drivers/terracotta_driver.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/terracotta/drivers/terracotta_driver.py b/terracotta/drivers/terracotta_driver.py index e757d295..a9c44938 100644 --- a/terracotta/drivers/terracotta_driver.py +++ b/terracotta/drivers/terracotta_driver.py 
@@ -18,7 +18,7 @@ T = TypeVar('T') -def only_element(iterable: Collection[T]) -> T: +def squeeze(iterable: Collection[T]) -> T: if not iterable: raise exceptions.DatasetNotFoundError('No dataset found') assert len(iterable) == 1 @@ -70,7 +70,7 @@ def get_metadata(self, keys: ExtendedKeysType) -> Dict[str, Any]: if metadata is None: # metadata is not computed yet, trigger lazy loading - handle = only_element(self.get_datasets(keys).values()) + handle = squeeze(self.get_datasets(keys).values()) metadata = self.compute_metadata(handle, max_shape=self.LAZY_LOADING_MAX_SHAPE) self.insert(keys, handle, metadata=metadata) @@ -112,7 +112,7 @@ def get_raster_tile(self, keys: ExtendedKeysType, *, tile_size: Sequence[int] = (256, 256), preserve_values: bool = False, asynchronous: bool = False) -> Any: - handle = only_element(self.get_datasets(keys).values()) + handle = squeeze(self.get_datasets(keys).values()) return self.raster_store.get_raster_tile( handle=handle, From 84a5219d622352785a822d5f6126a3bd38610239 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 1 Feb 2022 17:15:52 +0100 Subject: [PATCH 076/107] Improve repr --- terracotta/drivers/terracotta_driver.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/terracotta/drivers/terracotta_driver.py b/terracotta/drivers/terracotta_driver.py index a9c44938..25539b6e 100644 --- a/terracotta/drivers/terracotta_driver.py +++ b/terracotta/drivers/terracotta_driver.py @@ -163,4 +163,9 @@ def _standardize_keys( return keys def __repr__(self) -> str: - return self.meta_store.__repr__() + return ( + f'{self.__class__.__name__}(\n' + f'\tmeta_store={self.meta_store.__class__.__name__}(path="{self.meta_store.path}"),\n' + f'\traster_store={self.raster_store.__class__.__name__}()\n' + ')' + ) From 1086d529d750e83780d9ef03371cedde0bc1c9c6 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 1 Feb 2022 17:17:13 +0100 Subject: [PATCH 077/107] Rename to GeoTiffRasterStore --- 
terracotta/drivers/__init__.py | 4 ++-- terracotta/drivers/raster_base.py | 2 +- terracotta/scripts/optimize_rasters.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/terracotta/drivers/__init__.py b/terracotta/drivers/__init__.py index b8f55927..ebebfb0f 100644 --- a/terracotta/drivers/__init__.py +++ b/terracotta/drivers/__init__.py @@ -10,7 +10,7 @@ from terracotta.drivers.base_classes import MetaStore from terracotta.drivers.terracotta_driver import TerracottaDriver -from terracotta.drivers.raster_base import RasterDriver +from terracotta.drivers.raster_base import GeoTiffRasterStore URLOrPathType = Union[str, Path] @@ -89,7 +89,7 @@ def get_driver(url_or_path: URLOrPathType, provider: str = None) -> TerracottaDr if cache_key not in _DRIVER_CACHE: driver = TerracottaDriver( meta_store=DriverClass(url_or_path), - raster_store=RasterDriver() + raster_store=GeoTiffRasterStore() ) _DRIVER_CACHE[cache_key] = driver diff --git a/terracotta/drivers/raster_base.py b/terracotta/drivers/raster_base.py index 42ff1f7f..7d30461d 100644 --- a/terracotta/drivers/raster_base.py +++ b/terracotta/drivers/raster_base.py @@ -64,7 +64,7 @@ def submit_to_executor(task: Callable[..., Any]) -> Future: return future -class RasterDriver(RasterStore): +class GeoTiffRasterStore(RasterStore): """Mixin that implements methods to load raster data from disk. get_datasets has to return path to raster file as sole dict value. 
diff --git a/terracotta/scripts/optimize_rasters.py b/terracotta/scripts/optimize_rasters.py index cc52f48f..693c14dc 100644 --- a/terracotta/scripts/optimize_rasters.py +++ b/terracotta/scripts/optimize_rasters.py @@ -85,8 +85,8 @@ def _prefered_compression_method() -> str: def _get_vrt(src: DatasetReader, rs_method: int) -> WarpedVRT: - from terracotta.drivers.raster_base import RasterDriver - target_crs = RasterDriver._TARGET_CRS + from terracotta.drivers.raster_base import GeoTiffRasterStore + target_crs = GeoTiffRasterStore._TARGET_CRS vrt_transform, vrt_width, vrt_height = calculate_default_transform( src.crs, target_crs, src.width, src.height, *src.bounds ) From 65cd29a96e24e8883ae2a7d7526c6dd9bdf8c764 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 1 Feb 2022 17:18:19 +0100 Subject: [PATCH 078/107] Rename to RelationalMetaStore --- terracotta/drivers/mysql.py | 4 ++-- terracotta/drivers/relational_base.py | 2 +- terracotta/drivers/sqlite.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/terracotta/drivers/mysql.py b/terracotta/drivers/mysql.py index 6bd6d97b..597e8075 100644 --- a/terracotta/drivers/mysql.py +++ b/terracotta/drivers/mysql.py @@ -9,10 +9,10 @@ import sqlalchemy as sqla from sqlalchemy.dialects.mysql import TEXT, VARCHAR -from terracotta.drivers.relational_base import RelationalDriver +from terracotta.drivers.relational_base import RelationalMetaStore -class MySQLDriver(RelationalDriver): +class MySQLDriver(RelationalMetaStore): """A MySQL-backed raster driver. Assumes raster data to be present in separate GDAL-readable files on disk or remotely. 
diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index 164bf41e..84373ebe 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -50,7 +50,7 @@ def convert_exceptions( raise exceptions.InvalidDatabaseError(error_message) from exception -class RelationalDriver(MetaStore, ABC): +class RelationalMetaStore(MetaStore, ABC): SQL_URL_SCHEME: str # The database flavour, eg mysql, sqlite, etc SQL_DRIVER_TYPE: str # The actual database driver, eg pymysql, sqlite3, etc SQL_KEY_SIZE: int diff --git a/terracotta/drivers/sqlite.py b/terracotta/drivers/sqlite.py index 33d9e16f..dea07437 100644 --- a/terracotta/drivers/sqlite.py +++ b/terracotta/drivers/sqlite.py @@ -8,10 +8,10 @@ from pathlib import Path from typing import Union -from terracotta.drivers.relational_base import RelationalDriver +from terracotta.drivers.relational_base import RelationalMetaStore -class SQLiteDriver(RelationalDriver): +class SQLiteDriver(RelationalMetaStore): """An SQLite-backed raster driver. Assumes raster data to be present in separate GDAL-readable files on disk or remotely. 
From fed5a66e8f3102d3c00c503639adb07617ac5d3a Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 1 Feb 2022 17:18:45 +0100 Subject: [PATCH 079/107] Don't use too implicit hacks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Dion Häfner --- terracotta/drivers/relational_base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index 3a37a337..23672cc8 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -292,9 +292,12 @@ def get_datasets( page: int = 0, limit: int = None ) -> Dict[Tuple[str, ...], str]: + if where is None: + where = {} + where = { key: value if isinstance(value, list) else [value] - for key, value in (where or {}).items() + for key, value in where.items() } datasets_table = sqla.Table('datasets', self.sqla_metadata, autoload_with=self.sqla_engine) From ab6449fb98bf7f1ce1a6e5cd14ca8dc105ea90c0 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 1 Feb 2022 17:21:18 +0100 Subject: [PATCH 080/107] Update test to new repr --- tests/drivers/test_drivers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/drivers/test_drivers.py b/tests/drivers/test_drivers.py index d7b97302..43ebe0cf 100644 --- a/tests/drivers/test_drivers.py +++ b/tests/drivers/test_drivers.py @@ -190,7 +190,7 @@ def __getattribute__(self, key): def test_repr(driver_path, provider): from terracotta import drivers db = drivers.get_driver(driver_path, provider=provider) - assert repr(db).startswith(DRIVER_CLASSES[provider]) + assert f'meta_store={DRIVER_CLASSES[provider]}' in repr(db) @pytest.mark.parametrize('provider', TESTABLE_DRIVERS) From 4ab4bdda4f7d2af1db59d705892659355e2517a3 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 1 Feb 2022 17:25:06 +0100 Subject: [PATCH 081/107] Rename filepath to handle --- 
terracotta/drivers/relational_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index 1dc27013..1b2ef8cf 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -240,7 +240,7 @@ def _initialize_database( sqla.Column(key, self.SQLA_STRING(self.SQL_KEY_SIZE), primary_key=True) for key in keys ], - sqla.Column('filepath', self.SQLA_STRING(8000)) + sqla.Column('handle', self.SQLA_STRING(8000)) ) _ = sqla.Table( 'metadata', self.sqla_metadata, @@ -321,7 +321,7 @@ def get_datasets( def keytuple(row: Dict[str, Any]) -> Tuple[str, ...]: return tuple(row[key] for key in self.key_names) - datasets = {keytuple(row): row['filepath'] for row in result} + datasets = {keytuple(row): row['handle'] for row in result} return datasets @trace('get_metadata') @@ -366,7 +366,7 @@ def insert( .where(*[datasets_table.c.get(column) == value for column, value in keys.items()]) ) self.connection.execute( - datasets_table.insert().values(**keys, filepath=handle) + datasets_table.insert().values(**keys, handle=handle) ) if metadata is not None: From f89052e294ff46ba18caff5cf63426d0fc7436da Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 1 Feb 2022 17:25:24 +0100 Subject: [PATCH 082/107] Don't print anything --- terracotta/drivers/relational_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_base.py index 1b2ef8cf..28de6d43 100644 --- a/terracotta/drivers/relational_base.py +++ b/terracotta/drivers/relational_base.py @@ -392,7 +392,6 @@ def delete(self, keys: KeysType) -> None: datasets_table = sqla.Table('datasets', self.sqla_metadata, autoload_with=self.sqla_engine) metadata_table = sqla.Table('metadata', self.sqla_metadata, autoload_with=self.sqla_engine) - print(keys) self.connection.execute( datasets_table .delete() From 
ca15c4ad882f9810365e8784b05a7239bd45a53c Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 1 Feb 2022 17:31:20 +0100 Subject: [PATCH 083/107] Rename *_stores --- terracotta/drivers/__init__.py | 8 ++++---- .../{raster_base.py => geotiff_raster_store.py} | 0 terracotta/drivers/{mysql.py => mysql_meta_store.py} | 2 +- .../{relational_base.py => relational_meta_store.py} | 0 terracotta/drivers/{sqlite.py => sqlite_meta_store.py} | 2 +- .../{sqlite_remote.py => sqlite_remote_meta_store.py} | 2 +- terracotta/scripts/optimize_rasters.py | 2 +- tests/drivers/test_raster_drivers.py | 10 +++++----- tests/test_raster.py | 4 ++-- 9 files changed, 15 insertions(+), 15 deletions(-) rename terracotta/drivers/{raster_base.py => geotiff_raster_store.py} (100%) rename terracotta/drivers/{mysql.py => mysql_meta_store.py} (97%) rename terracotta/drivers/{relational_base.py => relational_meta_store.py} (100%) rename terracotta/drivers/{sqlite.py => sqlite_meta_store.py} (96%) rename terracotta/drivers/{sqlite_remote.py => sqlite_remote_meta_store.py} (98%) diff --git a/terracotta/drivers/__init__.py b/terracotta/drivers/__init__.py index ebebfb0f..aa7f7e12 100644 --- a/terracotta/drivers/__init__.py +++ b/terracotta/drivers/__init__.py @@ -10,22 +10,22 @@ from terracotta.drivers.base_classes import MetaStore from terracotta.drivers.terracotta_driver import TerracottaDriver -from terracotta.drivers.raster_base import GeoTiffRasterStore +from terracotta.drivers.geotiff_raster_store import GeoTiffRasterStore URLOrPathType = Union[str, Path] def load_driver(provider: str) -> Type[MetaStore]: if provider == 'sqlite-remote': - from terracotta.drivers.sqlite_remote import RemoteSQLiteDriver + from terracotta.drivers.sqlite_remote_meta_store import RemoteSQLiteDriver return RemoteSQLiteDriver if provider == 'mysql': - from terracotta.drivers.mysql import MySQLDriver + from terracotta.drivers.mysql_meta_store import MySQLDriver return MySQLDriver if provider == 'sqlite': - from 
terracotta.drivers.sqlite import SQLiteDriver + from terracotta.drivers.sqlite_meta_store import SQLiteDriver return SQLiteDriver raise ValueError(f'Unknown database provider {provider}') diff --git a/terracotta/drivers/raster_base.py b/terracotta/drivers/geotiff_raster_store.py similarity index 100% rename from terracotta/drivers/raster_base.py rename to terracotta/drivers/geotiff_raster_store.py diff --git a/terracotta/drivers/mysql.py b/terracotta/drivers/mysql_meta_store.py similarity index 97% rename from terracotta/drivers/mysql.py rename to terracotta/drivers/mysql_meta_store.py index 597e8075..81a24194 100644 --- a/terracotta/drivers/mysql.py +++ b/terracotta/drivers/mysql_meta_store.py @@ -9,7 +9,7 @@ import sqlalchemy as sqla from sqlalchemy.dialects.mysql import TEXT, VARCHAR -from terracotta.drivers.relational_base import RelationalMetaStore +from terracotta.drivers.relational_meta_store import RelationalMetaStore class MySQLDriver(RelationalMetaStore): diff --git a/terracotta/drivers/relational_base.py b/terracotta/drivers/relational_meta_store.py similarity index 100% rename from terracotta/drivers/relational_base.py rename to terracotta/drivers/relational_meta_store.py diff --git a/terracotta/drivers/sqlite.py b/terracotta/drivers/sqlite_meta_store.py similarity index 96% rename from terracotta/drivers/sqlite.py rename to terracotta/drivers/sqlite_meta_store.py index dea07437..d2d24423 100644 --- a/terracotta/drivers/sqlite.py +++ b/terracotta/drivers/sqlite_meta_store.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import Union -from terracotta.drivers.relational_base import RelationalMetaStore +from terracotta.drivers.relational_meta_store import RelationalMetaStore class SQLiteDriver(RelationalMetaStore): diff --git a/terracotta/drivers/sqlite_remote.py b/terracotta/drivers/sqlite_remote_meta_store.py similarity index 98% rename from terracotta/drivers/sqlite_remote.py rename to terracotta/drivers/sqlite_remote_meta_store.py index 
e680673e..dec5e868 100644 --- a/terracotta/drivers/sqlite_remote.py +++ b/terracotta/drivers/sqlite_remote_meta_store.py @@ -15,7 +15,7 @@ from typing import Any, Iterator, Union from terracotta import exceptions, get_settings -from terracotta.drivers.sqlite import SQLiteDriver +from terracotta.drivers.sqlite_meta_store import SQLiteDriver from terracotta.profile import trace logger = logging.getLogger(__name__) diff --git a/terracotta/scripts/optimize_rasters.py b/terracotta/scripts/optimize_rasters.py index 693c14dc..85361af2 100644 --- a/terracotta/scripts/optimize_rasters.py +++ b/terracotta/scripts/optimize_rasters.py @@ -85,7 +85,7 @@ def _prefered_compression_method() -> str: def _get_vrt(src: DatasetReader, rs_method: int) -> WarpedVRT: - from terracotta.drivers.raster_base import GeoTiffRasterStore + from terracotta.drivers.geotiff_raster_store import GeoTiffRasterStore target_crs = GeoTiffRasterStore._TARGET_CRS vrt_transform, vrt_width, vrt_height = calculate_default_transform( src.crs, target_crs, src.width, src.height, *src.bounds diff --git a/tests/drivers/test_raster_drivers.py b/tests/drivers/test_raster_drivers.py index 398378df..6905a3be 100644 --- a/tests/drivers/test_raster_drivers.py +++ b/tests/drivers/test_raster_drivers.py @@ -380,7 +380,7 @@ def test_multiprocessing_fallback(driver_path, provider, raster_file, monkeypatc import concurrent.futures from importlib import reload from terracotta import drivers - import terracotta.drivers.raster_base + import terracotta.drivers.geotiff_raster_store def dummy(*args, **kwargs): raise OSError('monkeypatched') @@ -389,7 +389,7 @@ def dummy(*args, **kwargs): with monkeypatch.context() as m, pytest.warns(UserWarning): m.setattr(concurrent.futures, 'ProcessPoolExecutor', dummy) - reload(terracotta.drivers.raster_base) + reload(terracotta.drivers.geotiff_raster_store) db = drivers.get_driver(driver_path, provider=provider) keys = ('some', 'keynames') @@ -405,7 +405,7 @@ def dummy(*args, **kwargs): 
np.testing.assert_array_equal(data1, data2) finally: - reload(terracotta.drivers.raster_base) + reload(terracotta.drivers.geotiff_raster_store) @pytest.mark.parametrize('provider', DRIVERS) @@ -480,7 +480,7 @@ def test_nodata_consistency(driver_path, provider, big_raster_file_mask, big_ras def test_broken_process_pool(driver_path, provider, raster_file): import concurrent.futures from terracotta import drivers - from terracotta.drivers.raster_base import context + from terracotta.drivers.geotiff_raster_store import context class BrokenPool: def submit(self, *args, **kwargs): @@ -507,7 +507,7 @@ def submit(self, *args, **kwargs): def test_no_multiprocessing(): import concurrent.futures from terracotta import update_settings - from terracotta.drivers.raster_base import create_executor + from terracotta.drivers.geotiff_raster_store import create_executor update_settings(USE_MULTIPROCESSING=False) diff --git a/tests/test_raster.py b/tests/test_raster.py index e907cf13..2e24f4f8 100644 --- a/tests/test_raster.py +++ b/tests/test_raster.py @@ -156,13 +156,13 @@ def test_compute_metadata_nocrick(big_raster_file_nodata, monkeypatch): convex_hull = convex_hull_exact(src) from terracotta import exceptions - import terracotta.drivers.raster_base + import terracotta.drivers.geotiff_raster_store with monkeypatch.context() as m: m.setattr(terracotta.raster, 'has_crick', False) with pytest.warns(exceptions.PerformanceWarning): - mtd = terracotta.drivers.raster_base.raster.compute_metadata( + mtd = terracotta.drivers.geotiff_raster_store.raster.compute_metadata( str(big_raster_file_nodata), use_chunks=True ) From 401728af9c3f960dd81281f8ef2df8a5a8297c70 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Tue, 1 Feb 2022 17:44:39 +0100 Subject: [PATCH 084/107] Re-rename keys to where --- terracotta/drivers/relational_meta_store.py | 2 +- terracotta/drivers/terracotta_driver.py | 8 ++++---- terracotta/handlers/datasets.py | 2 +- tests/drivers/test_raster_drivers.py | 18 
+++++++++--------- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/terracotta/drivers/relational_meta_store.py b/terracotta/drivers/relational_meta_store.py index 28de6d43..cffbf1c5 100644 --- a/terracotta/drivers/relational_meta_store.py +++ b/terracotta/drivers/relational_meta_store.py @@ -295,7 +295,7 @@ def get_datasets( ) -> Dict[Tuple[str, ...], str]: if where is None: where = {} - + where = { key: value if isinstance(value, list) else [value] for key, value in where.items() diff --git a/terracotta/drivers/terracotta_driver.py b/terracotta/drivers/terracotta_driver.py index 25539b6e..485b5e72 100644 --- a/terracotta/drivers/terracotta_driver.py +++ b/terracotta/drivers/terracotta_driver.py @@ -5,7 +5,7 @@ import contextlib from collections import OrderedDict -from typing import (Any, Collection, Dict, Mapping, Sequence, Tuple, TypeVar, +from typing import (Any, Collection, Dict, Mapping, Optional, Sequence, Tuple, TypeVar, Union) import terracotta @@ -54,10 +54,10 @@ def get_keys(self) -> OrderedDict: return self.meta_store.get_keys() @requires_connection - def get_datasets(self, keys: MultiValueKeysType = None, + def get_datasets(self, where: MultiValueKeysType = None, page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], Any]: return self.meta_store.get_datasets( - where=self._standardize_keys(keys, requires_all_keys=False), + where=self._standardize_keys(where, requires_all_keys=False), page=page, limit=limit ) @@ -135,7 +135,7 @@ def compute_metadata(self, handle: str, *, def _standardize_keys( self, - keys: Union[ExtendedKeysType, MultiValueKeysType], + keys: Union[ExtendedKeysType, Optional[MultiValueKeysType]], requires_all_keys: bool = True ) -> KeysType: if requires_all_keys and (keys is None or len(keys) != len(self.key_names)): diff --git a/terracotta/handlers/datasets.py b/terracotta/handlers/datasets.py index 4d458937..75012bc5 100644 --- a/terracotta/handlers/datasets.py +++ b/terracotta/handlers/datasets.py @@ -19,7 +19,7 
@@ def datasets(some_keys: Mapping[str, Union[str, List[str]]] = None, with driver.connect(): dataset_keys = driver.get_datasets( - keys=some_keys, page=page, limit=limit + where=some_keys, page=page, limit=limit ).keys() key_names = driver.key_names diff --git a/tests/drivers/test_raster_drivers.py b/tests/drivers/test_raster_drivers.py index 6905a3be..7954cf07 100644 --- a/tests/drivers/test_raster_drivers.py +++ b/tests/drivers/test_raster_drivers.py @@ -58,18 +58,18 @@ def test_where(driver_path, provider, raster_file): data = db.get_datasets() assert len(data) == 3 - data = db.get_datasets(keys=dict(some='some')) + data = db.get_datasets(where=dict(some='some')) assert len(data) == 2 - data = db.get_datasets(keys=dict(some='some', keynames='value')) + data = db.get_datasets(where=dict(some='some', keynames='value')) assert list(data.keys()) == [('some', 'value')] assert data[('some', 'value')] == str(raster_file) - data = db.get_datasets(keys=dict(some='unknown')) + data = db.get_datasets(where=dict(some='unknown')) assert data == {} with pytest.raises(exceptions.InvalidKeyError) as exc: - db.get_datasets(keys=dict(unknown='foo')) + db.get_datasets(where=dict(unknown='foo')) assert 'unrecognized keys' in str(exc.value) @@ -87,17 +87,17 @@ def test_where_with_multiquery(driver_path, provider, raster_file): data = db.get_datasets() assert len(data) == 3 - data = db.get_datasets(keys=dict(some=['some'])) + data = db.get_datasets(where=dict(some=['some'])) assert len(data) == 2 - data = db.get_datasets(keys=dict(keynames=['value', 'other_value'])) + data = db.get_datasets(where=dict(keynames=['value', 'other_value'])) assert len(data) == 2 - data = db.get_datasets(keys=dict(some='some', keynames=['value', 'third_value'])) + data = db.get_datasets(where=dict(some='some', keynames=['value', 'third_value'])) assert list(data.keys()) == [('some', 'value')] assert data[('some', 'value')] == str(raster_file) - data = db.get_datasets(keys=dict(some=['unknown'])) + data = 
db.get_datasets(where=dict(some=['unknown'])) assert data == {} @@ -121,7 +121,7 @@ def test_pagination(driver_path, provider, raster_file): data = db.get_datasets(limit=2, page=1) assert len(data) == 1 - data = db.get_datasets(keys=dict(some='some'), limit=1, page=0) + data = db.get_datasets(where=dict(some='some'), limit=1, page=0) assert len(data) == 1 From 62be08db20e485d7e4c721d0f5a81e7ec5185236 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 4 Feb 2022 15:25:53 +0100 Subject: [PATCH 085/107] Check for missing dataset in get_metadata, not in squeeze --- terracotta/drivers/terracotta_driver.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/terracotta/drivers/terracotta_driver.py b/terracotta/drivers/terracotta_driver.py index 485b5e72..6e97a36e 100644 --- a/terracotta/drivers/terracotta_driver.py +++ b/terracotta/drivers/terracotta_driver.py @@ -19,8 +19,6 @@ def squeeze(iterable: Collection[T]) -> T: - if not iterable: - raise exceptions.DatasetNotFoundError('No dataset found') assert len(iterable) == 1 return next(iter(iterable)) @@ -70,7 +68,11 @@ def get_metadata(self, keys: ExtendedKeysType) -> Dict[str, Any]: if metadata is None: # metadata is not computed yet, trigger lazy loading - handle = squeeze(self.get_datasets(keys).values()) + dataset = self.get_datasets(keys) + if not dataset: + raise exceptions.DatasetNotFoundError('No dataset found') + + handle = squeeze(dataset.values()) metadata = self.compute_metadata(handle, max_shape=self.LAZY_LOADING_MAX_SHAPE) self.insert(keys, handle, metadata=metadata) From 236f677f97dae1a25885c05d9b362d3a94424085 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 4 Feb 2022 15:54:48 +0100 Subject: [PATCH 086/107] Define keystype explicitly --- terracotta/drivers/terracotta_driver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terracotta/drivers/terracotta_driver.py b/terracotta/drivers/terracotta_driver.py index 6e97a36e..f33e43d1 100644 --- 
a/terracotta/drivers/terracotta_driver.py +++ b/terracotta/drivers/terracotta_driver.py @@ -14,7 +14,7 @@ MultiValueKeysType, RasterStore, requires_connection) -ExtendedKeysType = Union[Sequence[str], KeysType] +ExtendedKeysType = Union[Sequence[str], Mapping[str, str]] T = TypeVar('T') From c8d93ee7034d31b40c1dfdd0e083596c3f3fcecb Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 4 Feb 2022 16:11:24 +0100 Subject: [PATCH 087/107] Make keys standardization type check --- terracotta/drivers/terracotta_driver.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/terracotta/drivers/terracotta_driver.py b/terracotta/drivers/terracotta_driver.py index f33e43d1..1d84d491 100644 --- a/terracotta/drivers/terracotta_driver.py +++ b/terracotta/drivers/terracotta_driver.py @@ -5,7 +5,7 @@ import contextlib from collections import OrderedDict -from typing import (Any, Collection, Dict, Mapping, Optional, Sequence, Tuple, TypeVar, +from typing import (Any, Collection, Dict, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union) import terracotta @@ -15,6 +15,7 @@ requires_connection) ExtendedKeysType = Union[Sequence[str], Mapping[str, str]] +ExtendedMultiValueKeysType = Union[Sequence[str], Mapping[str, Union[str, List[str]]]] T = TypeVar('T') @@ -55,7 +56,7 @@ def get_keys(self) -> OrderedDict: def get_datasets(self, where: MultiValueKeysType = None, page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], Any]: return self.meta_store.get_datasets( - where=self._standardize_keys(where, requires_all_keys=False), + where=self._standardize_multi_value_keys(where, requires_all_keys=False), page=page, limit=limit ) @@ -136,10 +137,20 @@ def compute_metadata(self, handle: str, *, ) def _standardize_keys( + self, keys: ExtendedKeysType, requires_all_keys: bool = True + ) -> KeysType: + return self._ensure_keys_as_dict(keys, requires_all_keys) + + def _standardize_multi_value_keys( + self, keys: Optional[ExtendedMultiValueKeysType], 
requires_all_keys: bool = True + ) -> MultiValueKeysType: + return self._ensure_keys_as_dict(keys, requires_all_keys) + + def _ensure_keys_as_dict( self, keys: Union[ExtendedKeysType, Optional[MultiValueKeysType]], requires_all_keys: bool = True - ) -> KeysType: + ) -> Dict[str, Any]: if requires_all_keys and (keys is None or len(keys) != len(self.key_names)): raise exceptions.InvalidKeyError( f'Got wrong number of keys (available keys: {self.key_names})' From f523ebe44133636bfb3356cb03b66f3e5a33ce31 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 4 Feb 2022 16:19:53 +0100 Subject: [PATCH 088/107] Improve descriptiveness of metadata reload comment --- terracotta/drivers/terracotta_driver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terracotta/drivers/terracotta_driver.py b/terracotta/drivers/terracotta_driver.py index 1d84d491..3d487fdf 100644 --- a/terracotta/drivers/terracotta_driver.py +++ b/terracotta/drivers/terracotta_driver.py @@ -77,7 +77,7 @@ def get_metadata(self, keys: ExtendedKeysType) -> Dict[str, Any]: metadata = self.compute_metadata(handle, max_shape=self.LAZY_LOADING_MAX_SHAPE) self.insert(keys, handle, metadata=metadata) - # this is necessary to make the lazy loading tests pass... 
+ # ensure standardized/consistent output (types and floating point precision) metadata = self.meta_store.get_metadata(keys) assert metadata is not None From 0cce1b7e02d8b925f7ca7fdfb60e9cf76a2f134f Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Fri, 4 Feb 2022 16:31:20 +0100 Subject: [PATCH 089/107] Re-rename handle to path --- terracotta/drivers/base_classes.py | 14 +++++++------- terracotta/drivers/geotiff_raster_store.py | 8 ++++---- terracotta/drivers/relational_meta_store.py | 8 ++++---- terracotta/drivers/terracotta_driver.py | 20 ++++++++++---------- terracotta/raster.py | 14 +++++++------- 5 files changed, 32 insertions(+), 32 deletions(-) diff --git a/terracotta/drivers/base_classes.py b/terracotta/drivers/base_classes.py index 5bd1134a..e5f11c80 100644 --- a/terracotta/drivers/base_classes.py +++ b/terracotta/drivers/base_classes.py @@ -111,7 +111,7 @@ def get_keys(self) -> OrderedDict: def get_datasets(self, where: MultiValueKeysType = None, page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], Any]: # Get all known dataset key combinations matching the given constraints, - # and a handle to retrieve the data (driver dependent) + # and a path to retrieve the data (driver dependent) pass @abstractmethod @@ -141,14 +141,14 @@ def get_metadata(self, keys: KeysType) -> Optional[Dict[str, Any]]: @abstractmethod def insert(self, keys: KeysType, - handle: Any, **kwargs: Any) -> None: + path: Any, **kwargs: Any) -> None: """Register a new dataset. Used to populate metadata database. Arguments: keys: Keys of the dataset. Can either be given as a sequence of key values, or as a mapping ``{key_name: key_value}``. - handle: Handle to access dataset (driver dependent). + path: Path to access dataset (driver dependent). 
""" pass @@ -173,16 +173,16 @@ class RasterStore(ABC): @abstractmethod # TODO: add accurate signature if mypy ever supports conditional return types - def get_raster_tile(self, handle: str, *, + def get_raster_tile(self, path: str, *, tile_bounds: Sequence[float] = None, tile_size: Sequence[int] = (256, 256), preserve_values: bool = False, asynchronous: bool = False) -> Any: - """Load a raster tile with given handle and bounds. + """Load a raster tile with given path and bounds. Arguments: - handle: Handle of the requested dataset. + path: Path of the requested dataset. tile_bounds: Physical bounds of the tile to read, in Web Mercator projection (EPSG3857). Reads the whole dataset if not given. tile_size: Shape of the output array to return. Must be two-dimensional. @@ -203,7 +203,7 @@ def get_raster_tile(self, handle: str, *, pass @abstractmethod - def compute_metadata(self, handle: str, *, + def compute_metadata(self, path: str, *, extra_metadata: Any = None, use_chunks: bool = None, max_shape: Sequence[int] = None) -> Dict[str, Any]: diff --git a/terracotta/drivers/geotiff_raster_store.py b/terracotta/drivers/geotiff_raster_store.py index 7d30461d..122f7fd2 100644 --- a/terracotta/drivers/geotiff_raster_store.py +++ b/terracotta/drivers/geotiff_raster_store.py @@ -87,18 +87,18 @@ def __init__(self) -> None: from rasterio import Env self._rio_env = Env(**self._RIO_ENV_OPTIONS) - def compute_metadata(self, handle: str, *, + def compute_metadata(self, path: str, *, extra_metadata: Any = None, use_chunks: bool = None, max_shape: Sequence[int] = None) -> Dict[str, Any]: - return raster.compute_metadata(handle, extra_metadata=extra_metadata, + return raster.compute_metadata(path, extra_metadata=extra_metadata, use_chunks=use_chunks, max_shape=max_shape, large_raster_threshold=self._LARGE_RASTER_THRESHOLD, rio_env=self._rio_env) # return type has to be Any until mypy supports conditional return types def get_raster_tile(self, - handle: str, *, + path: str, *, 
tile_bounds: Sequence[float] = None, tile_size: Sequence[int] = None, preserve_values: bool = False, @@ -113,7 +113,7 @@ def get_raster_tile(self, # make sure all arguments are hashable kwargs = dict( - path=handle, + path=path, tile_bounds=tuple(tile_bounds) if tile_bounds else None, tile_size=tuple(tile_size), preserve_values=preserve_values, diff --git a/terracotta/drivers/relational_meta_store.py b/terracotta/drivers/relational_meta_store.py index cffbf1c5..97af46a9 100644 --- a/terracotta/drivers/relational_meta_store.py +++ b/terracotta/drivers/relational_meta_store.py @@ -240,7 +240,7 @@ def _initialize_database( sqla.Column(key, self.SQLA_STRING(self.SQL_KEY_SIZE), primary_key=True) for key in keys ], - sqla.Column('handle', self.SQLA_STRING(8000)) + sqla.Column('path', self.SQLA_STRING(8000)) ) _ = sqla.Table( 'metadata', self.sqla_metadata, @@ -321,7 +321,7 @@ def get_datasets( def keytuple(row: Dict[str, Any]) -> Tuple[str, ...]: return tuple(row[key] for key in self.key_names) - datasets = {keytuple(row): row['handle'] for row in result} + datasets = {keytuple(row): row['path'] for row in result} return datasets @trace('get_metadata') @@ -354,7 +354,7 @@ def get_metadata(self, keys: KeysType) -> Optional[Dict[str, Any]]: def insert( self, keys: KeysType, - handle: str, *, + path: str, *, metadata: Mapping[str, Any] = None ) -> None: datasets_table = sqla.Table('datasets', self.sqla_metadata, autoload_with=self.sqla_engine) @@ -366,7 +366,7 @@ def insert( .where(*[datasets_table.c.get(column) == value for column, value in keys.items()]) ) self.connection.execute( - datasets_table.insert().values(**keys, handle=handle) + datasets_table.insert().values(**keys, path=path) ) if metadata is not None: diff --git a/terracotta/drivers/terracotta_driver.py b/terracotta/drivers/terracotta_driver.py index 3d487fdf..11a476df 100644 --- a/terracotta/drivers/terracotta_driver.py +++ b/terracotta/drivers/terracotta_driver.py @@ -73,9 +73,9 @@ def get_metadata(self, 
keys: ExtendedKeysType) -> Dict[str, Any]: if not dataset: raise exceptions.DatasetNotFoundError('No dataset found') - handle = squeeze(dataset.values()) - metadata = self.compute_metadata(handle, max_shape=self.LAZY_LOADING_MAX_SHAPE) - self.insert(keys, handle, metadata=metadata) + path = squeeze(dataset.values()) + metadata = self.compute_metadata(path, max_shape=self.LAZY_LOADING_MAX_SHAPE) + self.insert(keys, path, metadata=metadata) # ensure standardized/consistent output (types and floating point precision) metadata = self.meta_store.get_metadata(keys) @@ -86,7 +86,7 @@ def get_metadata(self, keys: ExtendedKeysType) -> Dict[str, Any]: @requires_connection def insert( self, keys: ExtendedKeysType, - handle: Any, *, + path: Any, *, override_path: str = None, metadata: Mapping[str, Any] = None, skip_metadata: bool = False, @@ -95,11 +95,11 @@ def insert( keys = self._standardize_keys(keys) if metadata is None and not skip_metadata: - metadata = self.compute_metadata(handle) + metadata = self.compute_metadata(path) self.meta_store.insert( keys=keys, - handle=override_path or handle, + path=override_path or path, metadata=metadata, **kwargs ) @@ -115,22 +115,22 @@ def get_raster_tile(self, keys: ExtendedKeysType, *, tile_size: Sequence[int] = (256, 256), preserve_values: bool = False, asynchronous: bool = False) -> Any: - handle = squeeze(self.get_datasets(keys).values()) + path = squeeze(self.get_datasets(keys).values()) return self.raster_store.get_raster_tile( - handle=handle, + path=path, tile_bounds=tile_bounds, tile_size=tile_size, preserve_values=preserve_values, asynchronous=asynchronous, ) - def compute_metadata(self, handle: str, *, + def compute_metadata(self, path: str, *, extra_metadata: Any = None, use_chunks: bool = None, max_shape: Sequence[int] = None) -> Dict[str, Any]: return self.raster_store.compute_metadata( - handle=handle, + path=path, extra_metadata=extra_metadata, use_chunks=use_chunks, max_shape=max_shape, diff --git 
a/terracotta/raster.py b/terracotta/raster.py index 6d887db0..c3cc21cb 100644 --- a/terracotta/raster.py +++ b/terracotta/raster.py @@ -176,7 +176,7 @@ def compute_image_stats(dataset: 'DatasetReader', @trace('compute_metadata') -def compute_metadata(handle: str, *, +def compute_metadata(path: str, *, extra_metadata: Any = None, use_chunks: bool = None, max_shape: Sequence[int] = None, @@ -199,18 +199,18 @@ def compute_metadata(handle: str, *, rio_env = rasterio.Env() with rio_env: - if not validate(handle): + if not validate(path): warnings.warn( - f'Raster file {handle} is not a valid cloud-optimized GeoTIFF. ' + f'Raster file {path} is not a valid cloud-optimized GeoTIFF. ' 'Any interaction with it will be significantly slower. Consider optimizing ' 'it through `terracotta optimize-rasters` before ingestion.', exceptions.PerformanceWarning, stacklevel=3 ) - with rasterio.open(handle) as src: + with rasterio.open(path) as src: if src.nodata is None and not has_alpha_band(src): warnings.warn( - f'Raster file {handle} does not have a valid nodata value, ' + f'Raster file {path} does not have a valid nodata value, ' 'and does not contain an alpha band. No data will be masked.' 
) @@ -223,7 +223,7 @@ def compute_metadata(handle: str, *, if use_chunks: logger.debug( - f'Computing metadata for file {handle} using more than ' + f'Computing metadata for file {path} using more than ' f'{large_raster_threshold // 10**6}M pixels, iterating ' 'over chunks' ) @@ -241,7 +241,7 @@ def compute_metadata(handle: str, *, raster_stats = compute_image_stats(src, max_shape) if raster_stats is None: - raise ValueError(f'Raster file {handle} does not contain any valid data') + raise ValueError(f'Raster file {path} does not contain any valid data') row_data.update(raster_stats) From 8aad626bf072a89e571c99aeb5b2e052a8d71d77 Mon Sep 17 00:00:00 2001 From: Philip Graae Date: Sun, 20 Feb 2022 19:07:18 +0100 Subject: [PATCH 090/107] update docstrings --- terracotta/drivers/__init__.py | 6 +- terracotta/drivers/base_classes.py | 7 +- terracotta/drivers/terracotta_driver.py | 158 ++++++++++++++++++++++++ 3 files changed, 166 insertions(+), 5 deletions(-) diff --git a/terracotta/drivers/__init__.py b/terracotta/drivers/__init__.py index aa7f7e12..255dc7d9 100644 --- a/terracotta/drivers/__init__.py +++ b/terracotta/drivers/__init__.py @@ -68,12 +68,12 @@ def get_driver(url_or_path: URLOrPathType, provider: str = None) -> TerracottaDr >>> import terracotta as tc >>> tc.get_driver('tc.sqlite') - SQLiteDriver('/home/terracotta/tc.sqlite') + TerracottaDriver(meta_store=SQLiteDriver('/home/terracotta/tc.sqlite')) >>> tc.get_driver('mysql://root@localhost/tc') - MySQLDriver('mysql://root@localhost:3306/tc') + TerracottaDriver(meta_store=MySQLDriver('mysql://root@localhost:3306/tc')) >>> # pass provider if path is given in a non-standard way >>> tc.get_driver('root@localhost/tc', provider='mysql') - MySQLDriver('mysql://root@localhost:3306/tc') + TerracottaDriver(meta_store=MySQLDriver('mysql://root@localhost:3306/tc')) """ if provider is None: # try and auto-detect diff --git a/terracotta/drivers/base_classes.py b/terracotta/drivers/base_classes.py index e5f11c80..b0eb4933 
100644 --- a/terracotta/drivers/base_classes.py +++ b/terracotta/drivers/base_classes.py @@ -33,9 +33,9 @@ def inner(self: MetaStore, *args: Any, **kwargs: Any) -> T: class MetaStore(ABC): - """Abstract base class for all Terracotta data backends. + """Abstract base class for all Terracotta metadata backends. - Defines a common interface for all drivers. + Defines a common interface for all metadata backends. """ _RESERVED_KEYS = ('limit', 'page') @@ -170,6 +170,9 @@ def __repr__(self) -> str: class RasterStore(ABC): + """Abstract base class for all Terracotta raster backends. + + Defines a common interface for all raster backends.""" @abstractmethod # TODO: add accurate signature if mypy ever supports conditional return types diff --git a/terracotta/drivers/terracotta_driver.py b/terracotta/drivers/terracotta_driver.py index 11a476df..02eba934 100644 --- a/terracotta/drivers/terracotta_driver.py +++ b/terracotta/drivers/terracotta_driver.py @@ -35,26 +35,93 @@ def __init__(self, meta_store: MetaStore, raster_store: RasterStore) -> None: @property def db_version(self) -> str: + """Terracotta version used to create the meta store. + + Returns + + A str specifying the version of Terracotta that was used to create the meta store. + """ return self.meta_store.db_version @property def key_names(self) -> Tuple[str, ...]: + """Get names of all keys defined by the meta store. + + Returns: + + A tuple defining the key names and order. + """ return self.meta_store.key_names def create(self, keys: Sequence[str], *, key_descriptions: Mapping[str, str] = None) -> None: + """Create a new, empty metastore. + + Arguments: + + keys: A sequence defining the key names and order. + key_descriptions: A mapping from key name to a human-readable + description of what the key encodes. 
+ """ self.meta_store.create(keys=keys, key_descriptions=key_descriptions) def connect(self, verify: bool = True) -> contextlib.AbstractContextManager: + """Context manager to connect to the metastore and clean up on exit. + + This allows you to pool interactions with the metastore to prevent possibly + expensive reconnects, or to roll back several interactions if one of them fails. + + Arguments: + + verify: Whether to verify the metastore (primarily its version) when connecting. + Should be `true` unless absolutely necessary, such as when instantiating the + metastore during creation of it. + + Note: + + Make sure to call :meth:`create` on a fresh metastore before using this method. + + Example: + + >>> import terracotta as tc + >>> driver = tc.get_driver('tc.sqlite') + >>> with driver.connect(): + ... for keys, dataset in datasets.items(): + ... # connection will be kept open between insert operations + ... driver.insert(keys, dataset) + + """ return self.meta_store.connect(verify=verify) @requires_connection def get_keys(self) -> OrderedDict: + """Get all known keys and their fulltext descriptions. + + Returns: + + An :class:`~collections.OrderedDict` in the form + ``{key_name: key_description}`` + + """ return self.meta_store.get_keys() @requires_connection def get_datasets(self, where: MultiValueKeysType = None, page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], Any]: + """Get all known dataset key combinations matching the given constraints, + and a path to retrieve the data (dependent on the raster store). + + Arguments: + + where: A mapping from key name to key value constraint(s) + page: A pagination parameter, skips first page * limit results + limit: A pagination parameter, max number of results to return + + Returns: + + A :class:`dict` mapping from key sequence tuple to dataset path. 
+ + """ return self.meta_store.get_datasets( where=self._standardize_multi_value_keys(where, requires_all_keys=False), page=page, @@ -63,6 +130,27 @@ def get_datasets(self, where: MultiValueKeysType = None, @requires_connection def get_metadata(self, keys: ExtendedKeysType) -> Dict[str, Any]: + """Return all stored metadata for given keys. + + Arguments: + + keys: Keys of the requested dataset. Can either be given as a sequence of key values, + or as a mapping ``{key_name: key_value}``. + + Returns: + + A :class:`dict` with the values + + - ``range``: global minimum and maximum value in dataset + - ``bounds``: physical bounds covered by dataset in latitude-longitude projection + - ``convex_hull``: GeoJSON shape specifying total data coverage in latitude-longitude + projection + - ``percentiles``: array of pre-computed percentiles from 1% through 99% + - ``mean``: global mean + - ``stdev``: global standard deviation + - ``metadata``: any additional client-relevant metadata + + """ keys = self._standardize_keys(keys) metadata = self.meta_store.get_metadata(keys) @@ -92,6 +180,19 @@ def insert( skip_metadata: bool = False, **kwargs: Any ) -> None: + """Register a new dataset. Used to populate meta store. + + Arguments: + + keys: Keys of the dataset. Can either be given as a sequence of key values, or + as a mapping ``{key_name: key_value}``. + path: Path to access dataset (driver dependent). + override_path: If given, this path will be inserted into the meta store + instead of the one used to load the dataset. + metadata: Metadata for the dataset. If not given, metadata will be computed. + skip_metadata: If True, will skip metadata computation, even if metadata is not given. + + """ keys = self._standardize_keys(keys) if metadata is None and not skip_metadata: @@ -106,6 +207,14 @@ def insert( @requires_connection def delete(self, keys: ExtendedKeysType) -> None: + """Remove a dataset from the meta store. + + Arguments: + + keys: Keys of the dataset. 
Can either be given as a sequence of key values, or + as a mapping ``{key_name: key_value}``. + + """ keys = self._standardize_keys(keys) self.meta_store.delete(keys) @@ -115,6 +224,28 @@ def get_raster_tile(self, keys: ExtendedKeysType, *, tile_size: Sequence[int] = (256, 256), preserve_values: bool = False, asynchronous: bool = False) -> Any: + """Load a raster tile with given keys and bounds. + + Arguments: + + keys: Key sequence identifying the dataset to load tile from. + tile_bounds: Physical bounds of the tile to read, in Web Mercator projection (EPSG3857). + Reads the whole dataset if not given. + tile_size: Shape of the output array to return. Must be two-dimensional. + Defaults to :attr:`~terracotta.config.TerracottaSettings.DEFAULT_TILE_SIZE`. + preserve_values: Whether to preserve exact numerical values (e.g. when reading + categorical data). Sets all interpolation to nearest neighbor. + asynchronous: If given, the tile will be read asynchronously in a separate thread. + This function will return immediately with a :class:`~concurrent.futures.Future` + that can be used to retrieve the result. + + Returns: + + Requested tile as :class:`~numpy.ma.MaskedArray` of shape ``tile_size`` if + ``asynchronous=False``, otherwise a :class:`~concurrent.futures.Future` containing + the result. + + """ path = squeeze(self.get_datasets(keys).values()) return self.raster_store.get_raster_tile( @@ -129,6 +260,33 @@ def compute_metadata(self, path: str, *, extra_metadata: Any = None, use_chunks: bool = None, max_shape: Sequence[int] = None) -> Dict[str, Any]: + """Compute metadata for a dataset. + + Arguments: + + path: Path identifing the dataset. + extra_metadata: Any additional metadata that will be returned as is + in the result, under the `metadata` key. + use_chunks: Whether to load the dataset in chunks, when computing. + Useful if the dataset is too large to fit in memory. + Mutually exclusive with `max_shape`. 
+ max_shape: If dataset is larger than this shape, it will be downsampled + while loading. Useful if the dataset is too large to fit in memory. + Mutually exclusive with `use_chunks`. + + Returns: + + A :class:`dict` with the values + + - ``range``: global minimum and maximum value in dataset + - ``bounds``: physical bounds covered by dataset in latitude-longitude projection + - ``convex_hull``: GeoJSON shape specifying total data coverage in latitude-longitude + projection + - ``percentiles``: array of pre-computed percentiles from 1% through 99% + - ``mean``: global mean + - ``stdev``: global standard deviation + - ``metadata``: any additional client-relevant metadata + """ return self.raster_store.compute_metadata( path=path, extra_metadata=extra_metadata, From 0c1c94ca0db31a77a3df89c593aebe2f73c8be20 Mon Sep 17 00:00:00 2001 From: Philip Graae Date: Sun, 20 Feb 2022 19:43:42 +0100 Subject: [PATCH 091/107] pin pytest<7.0 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c18ddb15..b8c27254 100644 --- a/setup.py +++ b/setup.py @@ -78,7 +78,7 @@ ], extras_require={ 'test': [ - 'pytest', + 'pytest<7.0', 'pytest-cov', 'pytest-mypy', 'pytest-flake8', From 06a6d1a4b26fa66cc529ea68cf91d505a7efa77b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Mon, 21 Feb 2022 11:17:53 +0100 Subject: [PATCH 092/107] do not assemble rio env in driver --- terracotta/drivers/geotiff_raster_store.py | 22 ++++++++++++++-------- terracotta/raster.py | 17 ++++++++--------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/terracotta/drivers/geotiff_raster_store.py b/terracotta/drivers/geotiff_raster_store.py index 122f7fd2..e0af4a5a 100644 --- a/terracotta/drivers/geotiff_raster_store.py +++ b/terracotta/drivers/geotiff_raster_store.py @@ -64,6 +64,16 @@ def submit_to_executor(task: Callable[..., Any]) -> Future: return future +def ensure_hashable(val: Any) -> Any: + if isinstance(val, list): + return 
tuple(val) + + if isinstance(val, dict): + return tuple((k, ensure_hashable(v)) for k, v in val.items()) + + return val + + class GeoTiffRasterStore(RasterStore): """Mixin that implements methods to load raster data from disk. @@ -84,9 +94,6 @@ def __init__(self) -> None: ) self._cache_lock = threading.RLock() - from rasterio import Env - self._rio_env = Env(**self._RIO_ENV_OPTIONS) - def compute_metadata(self, path: str, *, extra_metadata: Any = None, use_chunks: bool = None, @@ -94,7 +101,7 @@ def compute_metadata(self, path: str, *, return raster.compute_metadata(path, extra_metadata=extra_metadata, use_chunks=use_chunks, max_shape=max_shape, large_raster_threshold=self._LARGE_RASTER_THRESHOLD, - rio_env=self._rio_env) + rio_env_options=self._RIO_ENV_OPTIONS) # return type has to be Any until mypy supports conditional return types def get_raster_tile(self, @@ -111,19 +118,18 @@ def get_raster_tile(self, if tile_size is None: tile_size = settings.DEFAULT_TILE_SIZE - # make sure all arguments are hashable kwargs = dict( path=path, - tile_bounds=tuple(tile_bounds) if tile_bounds else None, + tile_bounds=tile_bounds, tile_size=tuple(tile_size), preserve_values=preserve_values, reprojection_method=settings.REPROJECTION_METHOD, resampling_method=settings.RESAMPLING_METHOD, target_crs=self._TARGET_CRS, - rio_env=self._rio_env, + rio_env_options=self._RIO_ENV_OPTIONS, ) - cache_key = hash(tuple(kwargs.items())) + cache_key = hash(ensure_hashable(kwargs)) try: with self._cache_lock: diff --git a/terracotta/raster.py b/terracotta/raster.py index c3cc21cb..428a65e8 100644 --- a/terracotta/raster.py +++ b/terracotta/raster.py @@ -12,7 +12,6 @@ if TYPE_CHECKING: # pragma: no cover from rasterio.io import DatasetReader # noqa: F401 - from rasterio import Env try: from crick import TDigest, SummaryStats @@ -181,7 +180,7 @@ def compute_metadata(path: str, *, use_chunks: bool = None, max_shape: Sequence[int] = None, large_raster_threshold: int = None, - rio_env: 'Env' = None) -> 
Dict[str, Any]: + rio_env_options: Dict[str, Any] = None) -> Dict[str, Any]: import rasterio from rasterio import warp from terracotta.cog import validate @@ -195,10 +194,10 @@ def compute_metadata(path: str, *, if use_chunks and max_shape is not None: raise ValueError('Cannot use both use_chunks and max_shape arguments') - if rio_env is None: - rio_env = rasterio.Env() + if rio_env_options is None: + rio_env_options = {} - with rio_env: + with rasterio.Env(**rio_env_options): if not validate(path): warnings.warn( f'Raster file {path} is not a valid cloud-optimized GeoTIFF. ' @@ -284,7 +283,7 @@ def get_raster_tile(path: str, *, tile_size: Tuple[int, int] = (256, 256), preserve_values: bool = False, target_crs: str = 'epsg:3857', - rio_env: 'Env' = None) -> np.ma.MaskedArray: + rio_env_options: Dict[str, Any] = None) -> np.ma.MaskedArray: """Load a raster dataset from a file through rasterio. Heavily inspired by mapbox/rio-tiler @@ -296,8 +295,8 @@ def get_raster_tile(path: str, *, dst_bounds: Tuple[float, float, float, float] - if rio_env is None: - rio_env = rasterio.Env() + if rio_env_options is None: + rio_env_options = {} if preserve_values: reproject_enum = resampling_enum = get_resampling_enum('nearest') @@ -306,7 +305,7 @@ def get_raster_tile(path: str, *, resampling_enum = get_resampling_enum(resampling_method) with contextlib.ExitStack() as es: - es.enter_context(rio_env) + es.enter_context(rasterio.Env(**rio_env_options)) try: with trace('open_dataset'): src = es.enter_context(rasterio.open(path)) From da9f20f797b252186576062ad1da84a63f4f1168 Mon Sep 17 00:00:00 2001 From: Philip Graae Date: Mon, 21 Feb 2022 11:28:06 +0100 Subject: [PATCH 093/107] Update filename in module docstring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Dion Häfner --- terracotta/drivers/geotiff_raster_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terracotta/drivers/geotiff_raster_store.py 
b/terracotta/drivers/geotiff_raster_store.py index e0af4a5a..cc8093ed 100644 --- a/terracotta/drivers/geotiff_raster_store.py +++ b/terracotta/drivers/geotiff_raster_store.py @@ -1,4 +1,4 @@ -"""drivers/raster_base.py +"""drivers/geotiff_raster_store.py Base class for drivers operating on physical raster files. """ From 891185a2e942975b5f80d697221bd445584d6999 Mon Sep 17 00:00:00 2001 From: Philip Graae Date: Mon, 21 Feb 2022 11:35:33 +0100 Subject: [PATCH 094/107] docstring polish :memo: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Dion Häfner --- terracotta/drivers/geotiff_raster_store.py | 6 +++--- terracotta/drivers/terracotta_driver.py | 19 ++++++++++++++----- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/terracotta/drivers/geotiff_raster_store.py b/terracotta/drivers/geotiff_raster_store.py index cc8093ed..42afadf0 100644 --- a/terracotta/drivers/geotiff_raster_store.py +++ b/terracotta/drivers/geotiff_raster_store.py @@ -75,9 +75,9 @@ def ensure_hashable(val: Any) -> Any: class GeoTiffRasterStore(RasterStore): - """Mixin that implements methods to load raster data from disk. - - get_datasets has to return path to raster file as sole dict value. + """Raster store that operates on GeoTiff raster files from disk. + + Path arguments are expected to be file paths. """ _TARGET_CRS: str = 'epsg:3857' _LARGE_RASTER_THRESHOLD: int = 10980 * 10980 diff --git a/terracotta/drivers/terracotta_driver.py b/terracotta/drivers/terracotta_driver.py index 02eba934..30361273 100644 --- a/terracotta/drivers/terracotta_driver.py +++ b/terracotta/drivers/terracotta_driver.py @@ -25,7 +25,10 @@ def squeeze(iterable: Collection[T]) -> T: class TerracottaDriver: - + """Terracotta driver object used to retrieve raster tiles and metadata. + + Do not instantiate directly, use :func:`terracotta.get_driver` instead. 
+ """ def __init__(self, meta_store: MetaStore, raster_store: RasterStore) -> None: self.meta_store = meta_store self.raster_store = raster_store @@ -37,9 +40,10 @@ def __init__(self, meta_store: MetaStore, raster_store: RasterStore) -> None: def db_version(self) -> str: """Terracotta version used to create the meta store. - Returns + Returns: A str specifying the version of Terracotta that was used to create the meta store. + """ return self.meta_store.db_version @@ -50,18 +54,20 @@ def key_names(self) -> Tuple[str, ...]: Returns: A tuple defining the key names and order. + """ return self.meta_store.key_names def create(self, keys: Sequence[str], *, key_descriptions: Mapping[str, str] = None) -> None: - """Create a new, empty metastore. + """Create a new, empty metadata store. Arguments: keys: A sequence defining the key names and order. key_descriptions: A mapping from key name to a human-readable description of what the key encodes. + """ self.meta_store.create(keys=keys, key_descriptions=key_descriptions) @@ -189,8 +195,10 @@ def insert( path: Path to access dataset (driver dependent). override_path: If given, this path will be inserted into the meta store instead of the one used to load the dataset. - metadata: Metadata for the dataset. If not given, metadata will be computed. - skip_metadata: If True, will skip metadata computation, even if metadata is not given. + metadata: Metadata dict for the dataset. If not given, metadata will be computed + via :meth:`compute_metadata`. + skip_metadata: If True, will skip metadata computation (will be computed + during first request instead). Has no effect if ``metadata`` argument is given. 
""" keys = self._standardize_keys(keys) @@ -286,6 +294,7 @@ def compute_metadata(self, path: str, *, - ``mean``: global mean - ``stdev``: global standard deviation - ``metadata``: any additional client-relevant metadata + """ return self.raster_store.compute_metadata( path=path, From dc835a6a2885ac45ac03851cb334bd155eff9282 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 21 Feb 2022 13:19:28 +0100 Subject: [PATCH 095/107] Improve reprs and satisfy flake8 --- terracotta/drivers/base_classes.py | 3 +++ terracotta/drivers/geotiff_raster_store.py | 2 +- terracotta/drivers/terracotta_driver.py | 4 ++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/terracotta/drivers/base_classes.py b/terracotta/drivers/base_classes.py index b0eb4933..ec77d657 100644 --- a/terracotta/drivers/base_classes.py +++ b/terracotta/drivers/base_classes.py @@ -212,3 +212,6 @@ def compute_metadata(self, path: str, *, max_shape: Sequence[int] = None) -> Dict[str, Any]: # Compute metadata for a given input file (driver dependent) pass + + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' diff --git a/terracotta/drivers/geotiff_raster_store.py b/terracotta/drivers/geotiff_raster_store.py index 42afadf0..d187383b 100644 --- a/terracotta/drivers/geotiff_raster_store.py +++ b/terracotta/drivers/geotiff_raster_store.py @@ -76,7 +76,7 @@ def ensure_hashable(val: Any) -> Any: class GeoTiffRasterStore(RasterStore): """Raster store that operates on GeoTiff raster files from disk. - + Path arguments are expected to be file paths. 
""" _TARGET_CRS: str = 'epsg:3857' diff --git a/terracotta/drivers/terracotta_driver.py b/terracotta/drivers/terracotta_driver.py index 30361273..ee21ecd3 100644 --- a/terracotta/drivers/terracotta_driver.py +++ b/terracotta/drivers/terracotta_driver.py @@ -345,7 +345,7 @@ def _ensure_keys_as_dict( def __repr__(self) -> str: return ( f'{self.__class__.__name__}(\n' - f'\tmeta_store={self.meta_store.__class__.__name__}(path="{self.meta_store.path}"),\n' - f'\traster_store={self.raster_store.__class__.__name__}()\n' + f' meta_store={self.meta_store!r},\n' + f' raster_store={self.raster_store!r}\n' ')' ) From 7e75ff48fba551815761021370ca5b943577e893 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 21 Feb 2022 13:33:32 +0100 Subject: [PATCH 096/107] Improve normalised path from sqlite metastores and update relevant docs --- terracotta/drivers/__init__.py | 15 ++++++++++++--- terracotta/drivers/sqlite_meta_store.py | 3 ++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/terracotta/drivers/__init__.py b/terracotta/drivers/__init__.py index 255dc7d9..c03f113b 100644 --- a/terracotta/drivers/__init__.py +++ b/terracotta/drivers/__init__.py @@ -68,12 +68,21 @@ def get_driver(url_or_path: URLOrPathType, provider: str = None) -> TerracottaDr >>> import terracotta as tc >>> tc.get_driver('tc.sqlite') - TerracottaDriver(meta_store=SQLiteDriver('/home/terracotta/tc.sqlite')) + TerracottaDriver( + meta_store=SQLiteDriver('/home/terracotta/tc.sqlite'), + raster_store=GeoTiffRasterStore() + ) >>> tc.get_driver('mysql://root@localhost/tc') - TerracottaDriver(meta_store=MySQLDriver('mysql://root@localhost:3306/tc')) + TerracottaDriver( + meta_store=MySQLDriver('mysql+pymysql://localhost:3306/tc'), + raster_store=GeoTiffRasterStore() + ) >>> # pass provider if path is given in a non-standard way >>> tc.get_driver('root@localhost/tc', provider='mysql') - TerracottaDriver(meta_store=MySQLDriver('mysql://root@localhost:3306/tc')) + TerracottaDriver( + 
meta_store=MySQLDriver('mysql+pymysql://localhost:3306/tc'), + raster_store=GeoTiffRasterStore() + ) """ if provider is None: # try and auto-detect diff --git a/terracotta/drivers/sqlite_meta_store.py b/terracotta/drivers/sqlite_meta_store.py index d2d24423..9da51e6d 100644 --- a/terracotta/drivers/sqlite_meta_store.py +++ b/terracotta/drivers/sqlite_meta_store.py @@ -61,7 +61,8 @@ def __init__(self, path: Union[str, Path]) -> None: @classmethod def _normalize_path(cls, path: str) -> str: - return os.path.normpath(os.path.realpath(path)) + url = cls._parse_path(path) + return os.path.normpath(os.path.realpath(url.database)) def _create_database(self) -> None: """The database is automatically created by the sqlite driver on connection, From 69b787660b053afa12ec76eeea12d1e163816a40 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 21 Feb 2022 13:36:01 +0100 Subject: [PATCH 097/107] Update filenames in first line of files to reflect their actual filenames --- terracotta/drivers/mysql_meta_store.py | 2 +- terracotta/drivers/relational_meta_store.py | 2 +- terracotta/drivers/sqlite_meta_store.py | 2 +- terracotta/drivers/sqlite_remote_meta_store.py | 2 +- terracotta/drivers/terracotta_driver.py | 6 +++--- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/terracotta/drivers/mysql_meta_store.py b/terracotta/drivers/mysql_meta_store.py index 81a24194..0d935668 100644 --- a/terracotta/drivers/mysql_meta_store.py +++ b/terracotta/drivers/mysql_meta_store.py @@ -1,4 +1,4 @@ -"""drivers/mysql.py +"""drivers/mysql_meta_store.py MySQL-backed raster driver. Metadata is stored in a MySQL database, raster data is assumed to be present on disk. 
diff --git a/terracotta/drivers/relational_meta_store.py b/terracotta/drivers/relational_meta_store.py index 97af46a9..65eecfe6 100644 --- a/terracotta/drivers/relational_meta_store.py +++ b/terracotta/drivers/relational_meta_store.py @@ -1,4 +1,4 @@ -"""drivers/relational_base.py +"""drivers/relational_meta_store.py Base class for relational database drivers, using SQLAlchemy """ diff --git a/terracotta/drivers/sqlite_meta_store.py b/terracotta/drivers/sqlite_meta_store.py index 9da51e6d..3b69e309 100644 --- a/terracotta/drivers/sqlite_meta_store.py +++ b/terracotta/drivers/sqlite_meta_store.py @@ -1,4 +1,4 @@ -"""drivers/sqlite.py +"""drivers/sqlite_meta_store.py SQLite-backed raster driver. Metadata is stored in an SQLite database, raster data is assumed to be present on disk. diff --git a/terracotta/drivers/sqlite_remote_meta_store.py b/terracotta/drivers/sqlite_remote_meta_store.py index dec5e868..456d373d 100644 --- a/terracotta/drivers/sqlite_remote_meta_store.py +++ b/terracotta/drivers/sqlite_remote_meta_store.py @@ -1,4 +1,4 @@ -"""drivers/sqlite_remote.py +"""drivers/sqlite_remote_meta_store.py SQLite-backed raster driver. Metadata is stored in an SQLite database, raster data is assumed to be present on disk. diff --git a/terracotta/drivers/terracotta_driver.py b/terracotta/drivers/terracotta_driver.py index ee21ecd3..fadc56dc 100644 --- a/terracotta/drivers/terracotta_driver.py +++ b/terracotta/drivers/terracotta_driver.py @@ -26,7 +26,7 @@ def squeeze(iterable: Collection[T]) -> T: class TerracottaDriver: """Terracotta driver object used to retrieve raster tiles and metadata. - + Do not instantiate directly, use :func:`terracotta.get_driver` instead. """ def __init__(self, meta_store: MetaStore, raster_store: RasterStore) -> None: @@ -195,9 +195,9 @@ def insert( path: Path to access dataset (driver dependent). override_path: If given, this path will be inserted into the meta store instead of the one used to load the dataset. 
- metadata: Metadata dict for the dataset. If not given, metadata will be computed + metadata: Metadata dict for the dataset. If not given, metadata will be computed via :meth:`compute_metadata`. - skip_metadata: If True, will skip metadata computation (will be computed + skip_metadata: If True, will skip metadata computation (will be computed during first request instead). Has no effect if ``metadata`` argument is given. """ From 06d6d490a70f86c44733ede60b048c504b1f57df Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 21 Feb 2022 13:53:16 +0100 Subject: [PATCH 098/107] Always stringify url_or_path Apparently pytest creates some LocalPath objects, which are not caught in isinstance(..., Path) --- terracotta/drivers/__init__.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/terracotta/drivers/__init__.py b/terracotta/drivers/__init__.py index c03f113b..5f788ae6 100644 --- a/terracotta/drivers/__init__.py +++ b/terracotta/drivers/__init__.py @@ -31,8 +31,8 @@ def load_driver(provider: str) -> Type[MetaStore]: raise ValueError(f'Unknown database provider {provider}') -def auto_detect_provider(url_or_path: Union[str, Path]) -> str: - parsed_path = urlparse.urlparse(str(url_or_path)) +def auto_detect_provider(url_or_path: str) -> str: + parsed_path = urlparse.urlparse(url_or_path) scheme = parsed_path.scheme if scheme == 's3': @@ -85,12 +85,11 @@ def get_driver(url_or_path: URLOrPathType, provider: str = None) -> TerracottaDr ) """ + url_or_path = str(url_or_path) + if provider is None: # try and auto-detect provider = auto_detect_provider(url_or_path) - if isinstance(url_or_path, Path): - url_or_path = str(url_or_path) - DriverClass = load_driver(provider) normalized_path = DriverClass._normalize_path(url_or_path) cache_key = (normalized_path, provider, os.getpid()) From 722e7dfb2d6633b9634e3e5cf962c844ebef4361 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 21 Feb 2022 14:00:17 +0100 Subject: [PATCH 099/107] Rename 
*Driver classes to *MetaStore --- terracotta/drivers/__init__.py | 12 ++++++------ terracotta/drivers/mysql_meta_store.py | 2 +- terracotta/drivers/sqlite_meta_store.py | 2 +- terracotta/drivers/sqlite_remote_meta_store.py | 4 ++-- tests/drivers/test_drivers.py | 6 +++--- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/terracotta/drivers/__init__.py b/terracotta/drivers/__init__.py index 5f788ae6..26f38988 100644 --- a/terracotta/drivers/__init__.py +++ b/terracotta/drivers/__init__.py @@ -17,16 +17,16 @@ def load_driver(provider: str) -> Type[MetaStore]: if provider == 'sqlite-remote': - from terracotta.drivers.sqlite_remote_meta_store import RemoteSQLiteDriver - return RemoteSQLiteDriver + from terracotta.drivers.sqlite_remote_meta_store import RemoteSQLiteMetaStore + return RemoteSQLiteMetaStore if provider == 'mysql': - from terracotta.drivers.mysql_meta_store import MySQLDriver - return MySQLDriver + from terracotta.drivers.mysql_meta_store import MySQLMetaStore + return MySQLMetaStore if provider == 'sqlite': - from terracotta.drivers.sqlite_meta_store import SQLiteDriver - return SQLiteDriver + from terracotta.drivers.sqlite_meta_store import SQLiteMetaStore + return SQLiteMetaStore raise ValueError(f'Unknown database provider {provider}') diff --git a/terracotta/drivers/mysql_meta_store.py b/terracotta/drivers/mysql_meta_store.py index 0d935668..9d264ac9 100644 --- a/terracotta/drivers/mysql_meta_store.py +++ b/terracotta/drivers/mysql_meta_store.py @@ -12,7 +12,7 @@ from terracotta.drivers.relational_meta_store import RelationalMetaStore -class MySQLDriver(RelationalMetaStore): +class MySQLMetaStore(RelationalMetaStore): """A MySQL-backed raster driver. Assumes raster data to be present in separate GDAL-readable files on disk or remotely. 
diff --git a/terracotta/drivers/sqlite_meta_store.py b/terracotta/drivers/sqlite_meta_store.py index 3b69e309..660d1b1d 100644 --- a/terracotta/drivers/sqlite_meta_store.py +++ b/terracotta/drivers/sqlite_meta_store.py @@ -11,7 +11,7 @@ from terracotta.drivers.relational_meta_store import RelationalMetaStore -class SQLiteDriver(RelationalMetaStore): +class SQLiteMetaStore(RelationalMetaStore): """An SQLite-backed raster driver. Assumes raster data to be present in separate GDAL-readable files on disk or remotely. diff --git a/terracotta/drivers/sqlite_remote_meta_store.py b/terracotta/drivers/sqlite_remote_meta_store.py index 456d373d..fb77f378 100644 --- a/terracotta/drivers/sqlite_remote_meta_store.py +++ b/terracotta/drivers/sqlite_remote_meta_store.py @@ -15,7 +15,7 @@ from typing import Any, Iterator, Union from terracotta import exceptions, get_settings -from terracotta.drivers.sqlite_meta_store import SQLiteDriver +from terracotta.drivers.sqlite_meta_store import SQLiteMetaStore from terracotta.profile import trace logger = logging.getLogger(__name__) @@ -50,7 +50,7 @@ def _update_from_s3(remote_path: str, local_path: str) -> None: shutil.copyfileobj(obj_bytes, f) -class RemoteSQLiteDriver(SQLiteDriver): +class RemoteSQLiteMetaStore(SQLiteMetaStore): """An SQLite-backed raster driver, where the database file is stored remotely on S3. Assumes raster data to be present in separate GDAL-readable files on disk or remotely. 
diff --git a/tests/drivers/test_drivers.py b/tests/drivers/test_drivers.py index 43ebe0cf..d3881839 100644 --- a/tests/drivers/test_drivers.py +++ b/tests/drivers/test_drivers.py @@ -2,9 +2,9 @@ TESTABLE_DRIVERS = ['sqlite', 'mysql'] DRIVER_CLASSES = { - 'sqlite': 'SQLiteDriver', - 'sqlite-remote': 'SQLiteRemoteDriver', - 'mysql': 'MySQLDriver' + 'sqlite': 'SQLiteMetaStore', + 'sqlite-remote': 'SQLiteRemoteMetaStore', + 'mysql': 'MySQLMetaStore' } From 41a26e77dcaa0ed16a23a1567edf82fded25f8f4 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 21 Feb 2022 14:06:21 +0100 Subject: [PATCH 100/107] Remove references to rasters in meta stores's documentation --- terracotta/drivers/mysql_meta_store.py | 8 +++----- terracotta/drivers/relational_meta_store.py | 2 +- terracotta/drivers/sqlite_meta_store.py | 10 ++++------ terracotta/drivers/sqlite_remote_meta_store.py | 8 +++----- 4 files changed, 11 insertions(+), 17 deletions(-) diff --git a/terracotta/drivers/mysql_meta_store.py b/terracotta/drivers/mysql_meta_store.py index 9d264ac9..bfcf8236 100644 --- a/terracotta/drivers/mysql_meta_store.py +++ b/terracotta/drivers/mysql_meta_store.py @@ -1,7 +1,6 @@ """drivers/mysql_meta_store.py -MySQL-backed raster driver. Metadata is stored in a MySQL database, raster data is assumed -to be present on disk. +MySQL-backed metadata driver. Metadata is stored in a MySQL database. """ import functools @@ -13,9 +12,8 @@ class MySQLMetaStore(RelationalMetaStore): - """A MySQL-backed raster driver. + """A MySQL-backed metadata driver. - Assumes raster data to be present in separate GDAL-readable files on disk or remotely. Stores metadata and paths to raster files in MySQL. Requires a running MySQL server. @@ -27,7 +25,7 @@ class MySQLMetaStore(RelationalMetaStore): - ``datasets``: Maps key values to physical raster path. - ``metadata``: Contains actual metadata as separate columns. Indexed via key values. - This driver caches raster data and key names, but not metadata. 
+ This driver caches key names. """ SQL_URL_SCHEME = 'mysql' SQL_DRIVER_TYPE = 'pymysql' diff --git a/terracotta/drivers/relational_meta_store.py b/terracotta/drivers/relational_meta_store.py index 65eecfe6..f0cc0e9a 100644 --- a/terracotta/drivers/relational_meta_store.py +++ b/terracotta/drivers/relational_meta_store.py @@ -1,6 +1,6 @@ """drivers/relational_meta_store.py -Base class for relational database drivers, using SQLAlchemy +Base class for relational database drivers, using SQLAlchemy. """ import contextlib diff --git a/terracotta/drivers/sqlite_meta_store.py b/terracotta/drivers/sqlite_meta_store.py index 660d1b1d..4333dd08 100644 --- a/terracotta/drivers/sqlite_meta_store.py +++ b/terracotta/drivers/sqlite_meta_store.py @@ -1,7 +1,6 @@ """drivers/sqlite_meta_store.py -SQLite-backed raster driver. Metadata is stored in an SQLite database, raster data is assumed -to be present on disk. +SQLite-backed metadata driver. Metadata is stored in an SQLite database. """ import os @@ -12,13 +11,12 @@ class SQLiteMetaStore(RelationalMetaStore): - """An SQLite-backed raster driver. + """An SQLite-backed metadata driver. - Assumes raster data to be present in separate GDAL-readable files on disk or remotely. Stores metadata and paths to raster files in SQLite. This is the simplest Terracotta driver, as it requires no additional infrastructure. - The SQLite database is simply a file that can be stored together with the actual + The SQLite database is simply a file that can e.g. be stored together with the actual raster files. Note: @@ -34,7 +32,7 @@ class SQLiteMetaStore(RelationalMetaStore): - ``datasets``: Maps key values to physical raster path. - ``metadata``: Contains actual metadata as separate columns. Indexed via key values. - This driver caches raster data and key names, but not metadata. + This driver caches key names, but not metadata. 
Warning: diff --git a/terracotta/drivers/sqlite_remote_meta_store.py b/terracotta/drivers/sqlite_remote_meta_store.py index fb77f378..e40968ad 100644 --- a/terracotta/drivers/sqlite_remote_meta_store.py +++ b/terracotta/drivers/sqlite_remote_meta_store.py @@ -1,7 +1,6 @@ """drivers/sqlite_remote_meta_store.py -SQLite-backed raster driver. Metadata is stored in an SQLite database, raster data is assumed -to be present on disk. +SQLite-backed metadata driver. Metadata is stored in an SQLite database. """ import contextlib @@ -51,9 +50,8 @@ def _update_from_s3(remote_path: str, local_path: str) -> None: class RemoteSQLiteMetaStore(SQLiteMetaStore): - """An SQLite-backed raster driver, where the database file is stored remotely on S3. + """An SQLite-backed metadata driver, where the database file is stored remotely on S3. - Assumes raster data to be present in separate GDAL-readable files on disk or remotely. Stores metadata and paths to raster files in SQLite. See also: @@ -61,7 +59,7 @@ class RemoteSQLiteMetaStore(SQLiteMetaStore): :class:`~terracotta.drivers.sqlite.SQLiteDriver` for the local version of this driver. - The SQLite database is simply a file that can be stored together with the actual + The SQLite database is simply a file that can be stored e.g. together with the actual raster files on S3. Before handling the first request, this driver will download a temporary copy of the remote database file. It is thus not feasible for large databases. 
From 3b65c8fa12c3661d90c7412dd70903d530a6d8b2 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 21 Feb 2022 14:19:25 +0100 Subject: [PATCH 101/107] Simplify docstrings in internal base_classes.py --- terracotta/drivers/base_classes.py | 102 +++-------------------------- 1 file changed, 10 insertions(+), 92 deletions(-) diff --git a/terracotta/drivers/base_classes.py b/terracotta/drivers/base_classes.py index ec77d657..736877f5 100644 --- a/terracotta/drivers/base_classes.py +++ b/terracotta/drivers/base_classes.py @@ -63,7 +63,7 @@ def _normalize_path(cls, path: str) -> str: @abstractmethod def create(self, keys: Sequence[str], *, key_descriptions: Mapping[str, str] = None) -> None: - # Create a new, empty database (driver dependent) + """Create a new, empty database""" pass @abstractmethod @@ -72,97 +72,36 @@ def connect(self, verify: bool = True) -> contextlib.AbstractContextManager: This allows you to pool interactions with the database to prevent possibly expensive reconnects, or to roll back several interactions if one of them fails. - - Arguments: - - verify: Whether to verify the database (primarily its version) when connecting. - Should be `true` unless absolutely necessary, such as when instantiating the - database during creation of it. - - Note: - - Make sure to call :meth:`create` on a fresh database before using this method. - - Example: - - >>> import terracotta as tc - >>> driver = tc.get_driver('tc.sqlite') - >>> with driver.connect(): - ... for keys, dataset in datasets.items(): - ... # connection will be kept open between insert operations - ... driver.insert(keys, dataset) - """ pass @abstractmethod def get_keys(self) -> OrderedDict: - """Get all known keys and their fulltext descriptions. 
- - Returns: - - An :class:`~collections.OrderedDict` in the form - ``{key_name: key_description}`` - - """ + """Get all known keys and their fulltext descriptions.""" pass @abstractmethod def get_datasets(self, where: MultiValueKeysType = None, page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], Any]: - # Get all known dataset key combinations matching the given constraints, - # and a path to retrieve the data (driver dependent) + """Get all known dataset key combinations matching the given constraints, + and a path to retrieve the data + """ pass @abstractmethod def get_metadata(self, keys: KeysType) -> Optional[Dict[str, Any]]: - """Return all stored metadata for given keys. - - Arguments: - - keys: Keys of the requested dataset. Can either be given as a sequence of key values, - or as a mapping ``{key_name: key_value}``. - - Returns: - - A :class:`dict` with the values - - - ``range``: global minimum and maximum value in dataset - - ``bounds``: physical bounds covered by dataset in latitude-longitude projection - - ``convex_hull``: GeoJSON shape specifying total data coverage in latitude-longitude - projection - - ``percentiles``: array of pre-computed percentiles from 1% through 99% - - ``mean``: global mean - - ``stdev``: global standard deviation - - ``metadata``: any additional client-relevant metadata - - """ + """Return all stored metadata for given keys.""" pass @abstractmethod def insert(self, keys: KeysType, path: Any, **kwargs: Any) -> None: - """Register a new dataset. Used to populate metadata database. - - Arguments: - - keys: Keys of the dataset. Can either be given as a sequence of key values, or - as a mapping ``{key_name: key_value}``. - path: Path to access dataset (driver dependent). - - """ + """Register a new dataset. This also populates the metadata database.""" pass @abstractmethod def delete(self, keys: KeysType) -> None: - """Remove a dataset from the metadata database. - - Arguments: - - keys: Keys of the dataset. 
Can either be given as a sequence of key values, or - as a mapping ``{key_name: key_value}``. - - """ + """Remove a dataset, including information from the metadata database.""" pass def __repr__(self) -> str: @@ -181,28 +120,7 @@ def get_raster_tile(self, path: str, *, tile_size: Sequence[int] = (256, 256), preserve_values: bool = False, asynchronous: bool = False) -> Any: - """Load a raster tile with given path and bounds. - - Arguments: - - path: Path of the requested dataset. - tile_bounds: Physical bounds of the tile to read, in Web Mercator projection (EPSG3857). - Reads the whole dataset if not given. - tile_size: Shape of the output array to return. Must be two-dimensional. - Defaults to :attr:`~terracotta.config.TerracottaSettings.DEFAULT_TILE_SIZE`. - preserve_values: Whether to preserve exact numerical values (e.g. when reading - categorical data). Sets all interpolation to nearest neighbor. - asynchronous: If given, the tile will be read asynchronously in a separate thread. - This function will return immediately with a :class:`~concurrent.futures.Future` - that can be used to retrieve the result. - - Returns: - - Requested tile as :class:`~numpy.ma.MaskedArray` of shape ``tile_size`` if - ``asynchronous=False``, otherwise a :class:`~concurrent.futures.Future` containing - the result. 
- - """ + """Load a raster tile with given path and bounds.""" pass @abstractmethod @@ -210,7 +128,7 @@ def compute_metadata(self, path: str, *, extra_metadata: Any = None, use_chunks: bool = None, max_shape: Sequence[int] = None) -> Dict[str, Any]: - # Compute metadata for a given input file (driver dependent) + """Compute metadata for a given input file""" pass def __repr__(self) -> str: From 5a9969eb8316e382acf89d70b3760bf866771179 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 21 Feb 2022 14:38:56 +0100 Subject: [PATCH 102/107] Fix bug (on Windows paths) in sqlite metastore _normalize_path --- terracotta/drivers/sqlite_meta_store.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/terracotta/drivers/sqlite_meta_store.py b/terracotta/drivers/sqlite_meta_store.py index 4333dd08..910c3211 100644 --- a/terracotta/drivers/sqlite_meta_store.py +++ b/terracotta/drivers/sqlite_meta_store.py @@ -59,8 +59,10 @@ def __init__(self, path: Union[str, Path]) -> None: @classmethod def _normalize_path(cls, path: str) -> str: - url = cls._parse_path(path) - return os.path.normpath(os.path.realpath(url.database)) + if path.startswith(f'{cls.SQL_URL_SCHEME}:///'): + path = path.replace(f'{cls.SQL_URL_SCHEME}:///', '') + + return os.path.normpath(os.path.realpath(path)) def _create_database(self) -> None: """The database is automatically created by the sqlite driver on connection, From d8b1ea2d349392d903e8d8410783650a093b2c35 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 21 Feb 2022 15:17:06 +0100 Subject: [PATCH 103/107] Specify arguments to MetaStore.insert --- terracotta/drivers/base_classes.py | 6 +++--- terracotta/drivers/terracotta_driver.py | 6 ++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/terracotta/drivers/base_classes.py b/terracotta/drivers/base_classes.py index 736877f5..75153b6c 100644 --- a/terracotta/drivers/base_classes.py +++ b/terracotta/drivers/base_classes.py @@ -94,9 +94,9 @@ def 
get_metadata(self, keys: KeysType) -> Optional[Dict[str, Any]]: pass @abstractmethod - def insert(self, keys: KeysType, - path: Any, **kwargs: Any) -> None: - """Register a new dataset. This also populates the metadata database.""" + def insert(self, keys: KeysType, path: Any, *, metadata: Mapping[str, Any] = None) -> None: + """Register a new dataset. This also populates the metadata database, + if metadata is specified and not `None`.""" pass @abstractmethod diff --git a/terracotta/drivers/terracotta_driver.py b/terracotta/drivers/terracotta_driver.py index fadc56dc..4cc6b8eb 100644 --- a/terracotta/drivers/terracotta_driver.py +++ b/terracotta/drivers/terracotta_driver.py @@ -183,8 +183,7 @@ def insert( path: Any, *, override_path: str = None, metadata: Mapping[str, Any] = None, - skip_metadata: bool = False, - **kwargs: Any + skip_metadata: bool = False ) -> None: """Register a new dataset. Used to populate meta store. @@ -209,8 +208,7 @@ def insert( self.meta_store.insert( keys=keys, path=override_path or path, - metadata=metadata, - **kwargs + metadata=metadata ) @requires_connection From cd0fe1d4b87dac2a0adfe1903fb3408ba947a232 Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 21 Feb 2022 15:20:01 +0100 Subject: [PATCH 104/107] Specify path in meta stores to be of type str --- terracotta/drivers/base_classes.py | 2 +- terracotta/drivers/terracotta_driver.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/terracotta/drivers/base_classes.py b/terracotta/drivers/base_classes.py index 75153b6c..bae3f206 100644 --- a/terracotta/drivers/base_classes.py +++ b/terracotta/drivers/base_classes.py @@ -94,7 +94,7 @@ def get_metadata(self, keys: KeysType) -> Optional[Dict[str, Any]]: pass @abstractmethod - def insert(self, keys: KeysType, path: Any, *, metadata: Mapping[str, Any] = None) -> None: + def insert(self, keys: KeysType, path: str, *, metadata: Mapping[str, Any] = None) -> None: """Register a new dataset. 
This also populates the metadata database, if metadata is specified and not `None`.""" pass diff --git a/terracotta/drivers/terracotta_driver.py b/terracotta/drivers/terracotta_driver.py index 4cc6b8eb..e0ac77d1 100644 --- a/terracotta/drivers/terracotta_driver.py +++ b/terracotta/drivers/terracotta_driver.py @@ -180,7 +180,7 @@ def get_metadata(self, keys: ExtendedKeysType) -> Dict[str, Any]: @requires_connection def insert( self, keys: ExtendedKeysType, - path: Any, *, + path: str, *, override_path: str = None, metadata: Mapping[str, Any] = None, skip_metadata: bool = False From 1556dcef84f279f3dcfb24bdf21485afebe1015b Mon Sep 17 00:00:00 2001 From: Nicklas Boserup Date: Mon, 21 Feb 2022 15:31:32 +0100 Subject: [PATCH 105/107] Use SQLAlchemy dialect+driver terminology --- terracotta/drivers/mysql_meta_store.py | 4 ++-- terracotta/drivers/relational_meta_store.py | 10 +++++----- terracotta/drivers/sqlite_meta_store.py | 10 +++++----- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/terracotta/drivers/mysql_meta_store.py b/terracotta/drivers/mysql_meta_store.py index bfcf8236..7d6fa5e9 100644 --- a/terracotta/drivers/mysql_meta_store.py +++ b/terracotta/drivers/mysql_meta_store.py @@ -27,8 +27,8 @@ class MySQLMetaStore(RelationalMetaStore): This driver caches key names. 
""" - SQL_URL_SCHEME = 'mysql' - SQL_DRIVER_TYPE = 'pymysql' + SQL_DIALECT = 'mysql' + SQL_DRIVER = 'pymysql' SQL_TIMEOUT_KEY = 'connect_timeout' _CHARSET = 'utf8mb4' diff --git a/terracotta/drivers/relational_meta_store.py b/terracotta/drivers/relational_meta_store.py index f0cc0e9a..2ccdf45f 100644 --- a/terracotta/drivers/relational_meta_store.py +++ b/terracotta/drivers/relational_meta_store.py @@ -51,8 +51,8 @@ def convert_exceptions( class RelationalMetaStore(MetaStore, ABC): - SQL_URL_SCHEME: str # The database flavour, eg mysql, sqlite, etc - SQL_DRIVER_TYPE: str # The actual database driver, eg pymysql, sqlite3, etc + SQL_DIALECT: str # The database flavour, eg mysql, sqlite, etc + SQL_DRIVER: str # The actual database driver, eg pymysql, sqlite3, etc SQL_KEY_SIZE: int SQL_TIMEOUT_KEY: str @@ -108,13 +108,13 @@ def _parse_path(cls, connection_string: str) -> URL: con_params = urlparse.urlparse(connection_string) if not con_params.scheme: - con_params = urlparse.urlparse(f'{cls.SQL_URL_SCHEME}:{connection_string}') + con_params = urlparse.urlparse(f'{cls.SQL_DIALECT}:{connection_string}') - if con_params.scheme != cls.SQL_URL_SCHEME: + if con_params.scheme != cls.SQL_DIALECT: raise ValueError(f'unsupported URL scheme "{con_params.scheme}"') url = URL.create( - drivername=f'{cls.SQL_URL_SCHEME}+{cls.SQL_DRIVER_TYPE}', + drivername=f'{cls.SQL_DIALECT}+{cls.SQL_DRIVER}', username=con_params.username, password=con_params.password, host=con_params.hostname, diff --git a/terracotta/drivers/sqlite_meta_store.py b/terracotta/drivers/sqlite_meta_store.py index 910c3211..1d504d2e 100644 --- a/terracotta/drivers/sqlite_meta_store.py +++ b/terracotta/drivers/sqlite_meta_store.py @@ -40,8 +40,8 @@ class SQLiteMetaStore(RelationalMetaStore): outside the main thread. 
""" - SQL_URL_SCHEME = 'sqlite' - SQL_DRIVER_TYPE = 'pysqlite' + SQL_DIALECT = 'sqlite' + SQL_DRIVER = 'pysqlite' SQL_KEY_SIZE = 256 SQL_TIMEOUT_KEY = 'timeout' @@ -55,12 +55,12 @@ def __init__(self, path: Union[str, Path]) -> None: path: File path to target SQLite database (may or may not exist yet) """ - super().__init__(f'{self.SQL_URL_SCHEME}:///{path}') + super().__init__(f'{self.SQL_DIALECT}:///{path}') @classmethod def _normalize_path(cls, path: str) -> str: - if path.startswith(f'{cls.SQL_URL_SCHEME}:///'): - path = path.replace(f'{cls.SQL_URL_SCHEME}:///', '') + if path.startswith(f'{cls.SQL_DIALECT}:///'): + path = path.replace(f'{cls.SQL_DIALECT}:///', '') return os.path.normpath(os.path.realpath(path)) From b2ebcbab585f325ed1d19ec10c6de29f99ce6127 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Mon, 21 Feb 2022 15:57:20 +0100 Subject: [PATCH 106/107] fix API docs --- docs/api.rst | 39 ++++++++++++++++++--------------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 1845dd43..58e70650 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -15,30 +15,27 @@ Get a driver instance .. autofunction:: terracotta.get_driver -SQLite driver -------------- +TerracottaDriver +---------------- -.. autoclass:: terracotta.drivers.sqlite.SQLiteDriver +.. autoclass:: terracotta.drivers.TerracottaDriver :members: - :undoc-members: - :special-members: __init__ - :inherited-members: -Remote SQLite driver --------------------- -.. autoclass:: terracotta.drivers.sqlite_remote.RemoteSQLiteDriver - :members: - :undoc-members: - :special-members: __init__ - :inherited-members: - :exclude-members: delete, insert, create +Supported metadata stores +------------------------- -MySQL driver ------------- +SQLite metadata store ++++++++++++++++++++++ -.. autoclass:: terracotta.drivers.mysql.MySQLDriver - :members: - :undoc-members: - :special-members: __init__ - :inherited-members: +.. 
autoclass:: terracotta.drivers.sqlite_meta_store.SQLiteMetaStore + +Remote SQLite metadata store +++++++++++++++++++++++++++++ + +.. autoclass:: terracotta.drivers.sqlite_remote_meta_store.RemoteSQLiteMetaStore + +MySQL metadata store +++++++++++++++++++++ + +.. autoclass:: terracotta.drivers.mysql_meta_store.MySQLMetaStore From 2f7f9bb128f3647976e85151ae84d31ad8fb2972 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Mon, 21 Feb 2022 16:09:02 +0100 Subject: [PATCH 107/107] unpin pytest --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index cd2a9e4f..bf34bf66 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', 'Framework :: Flask', 'Operating System :: Microsoft :: Windows :: Windows 10', 'Operating System :: MacOS :: MacOS X', @@ -78,7 +79,7 @@ ], extras_require={ 'test': [ - 'pytest<7.0', + 'pytest', 'pytest-cov', 'pytest-mypy', 'pytest-benchmark',