Skip to content

Commit

Permalink
Merge pull request #206 from DHI-GRAS/respawn-broken-pool
Browse files Browse the repository at this point in the history
Respawn broken process pool (#205)
  • Loading branch information
dionhaefner authored Apr 27, 2021
2 parents 6323df7 + fd255f2 commit 253e02c
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 23 deletions.
5 changes: 5 additions & 0 deletions terracotta/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ class TerracottaSettings(NamedTuple):
#: MySQL database password (if not given in driver path)
MYSQL_PASSWORD: Optional[str] = None

#: Use a process pool for band retrieval in parallel
USE_MULTIPROCESSING: bool = True


AVAILABLE_SETTINGS: Tuple[str, ...] = tuple(TerracottaSettings._fields)

Expand Down Expand Up @@ -123,6 +126,8 @@ class SettingSchema(Schema):
MYSQL_USER = fields.String()
MYSQL_PASSWORD = fields.String()

USE_MULTIPROCESSING = fields.Boolean()

@pre_load
def decode_lists(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
for var in ('DEFAULT_TILE_SIZE', 'LAZY_LOADING_MAX_SHAPE',
Expand Down
50 changes: 41 additions & 9 deletions terracotta/drivers/raster_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
Base class for drivers operating on physical raster files.
"""

from typing import (Any, Union, Mapping, Sequence, Dict, List, Tuple,
from typing import (Any, Callable, Union, Mapping, Sequence, Dict, List, Tuple,
TypeVar, Optional, cast, TYPE_CHECKING)
from abc import abstractmethod
from concurrent.futures import Future, Executor, ProcessPoolExecutor, ThreadPoolExecutor
from concurrent.futures.process import BrokenProcessPool

import contextlib
import functools
Expand Down Expand Up @@ -36,14 +37,45 @@

logger = logging.getLogger(__name__)

executor: Executor
context = threading.local()
context.executor = None

try:
# this fails on architectures without /dev/shm
executor = ProcessPoolExecutor(max_workers=3)
except OSError:
# fall back to serial evaluation
executor = ThreadPoolExecutor(max_workers=1)

def create_executor() -> Executor:
settings = get_settings()

if not settings.USE_MULTIPROCESSING:
return ThreadPoolExecutor(max_workers=1)

executor: Executor

try:
# this fails on architectures without /dev/shm
executor = ProcessPoolExecutor(max_workers=3)
except OSError:
# fall back to serial evaluation
warnings.warn(
'Multiprocessing is not available on this system. '
'Falling back to serial execution.'
)
executor = ThreadPoolExecutor(max_workers=1)

return executor


def submit_to_executor(task: Callable[..., Any]) -> Future:
if context.executor is None:
context.executor = create_executor()

try:
future = context.executor.submit(task)
except BrokenProcessPool:
# re-create executor and try again
logger.warn('Re-creating broken process pool')
context.executor = create_executor()
future = context.executor.submit(task)

return future


class RasterDriver(Driver):
Expand Down Expand Up @@ -561,7 +593,7 @@ def get_raster_tile(self,

retrieve_tile = functools.partial(self._get_raster_tile, **kwargs)

future = executor.submit(retrieve_tile)
future = submit_to_executor(retrieve_tile)

def cache_callback(future: Future) -> None:
# insert result into global cache if execution was successful
Expand Down
70 changes: 56 additions & 14 deletions tests/drivers/test_raster_drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,29 +349,32 @@ def test_multiprocessing_fallback(driver_path, provider, raster_file, monkeypatc
import concurrent.futures
from importlib import reload
from terracotta import drivers
import terracotta.drivers.raster_base

def dummy(*args, **kwargs):
raise OSError('monkeypatched')

with monkeypatch.context() as m:
m.setattr(concurrent.futures, 'ProcessPoolExecutor', dummy)
try:
with monkeypatch.context() as m, pytest.warns(UserWarning):
m.setattr(concurrent.futures, 'ProcessPoolExecutor', dummy)

import terracotta.drivers.raster_base
reload(terracotta.drivers.raster_base)
db = drivers.get_driver(driver_path, provider=provider)
keys = ('some', 'keynames')
reload(terracotta.drivers.raster_base)
db = drivers.get_driver(driver_path, provider=provider)
keys = ('some', 'keynames')

db.create(keys)
db.insert(['some', 'value'], str(raster_file))
db.insert(['some', 'other_value'], str(raster_file))
db.create(keys)
db.insert(['some', 'value'], str(raster_file))
db.insert(['some', 'other_value'], str(raster_file))

data1 = db.get_raster_tile(['some', 'value'], tile_size=(256, 256))
assert data1.shape == (256, 256)
data1 = db.get_raster_tile(['some', 'value'], tile_size=(256, 256))
assert data1.shape == (256, 256)

data2 = db.get_raster_tile(['some', 'other_value'], tile_size=(256, 256))
assert data2.shape == (256, 256)
data2 = db.get_raster_tile(['some', 'other_value'], tile_size=(256, 256))
assert data2.shape == (256, 256)

np.testing.assert_array_equal(data1, data2)
np.testing.assert_array_equal(data1, data2)
finally:
reload(terracotta.drivers.raster_base)


@pytest.mark.parametrize('provider', DRIVERS)
Expand Down Expand Up @@ -635,3 +638,42 @@ def test_compute_metadata_unoptimized(unoptimized_raster_file):
)

assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 1e-6


@pytest.mark.parametrize('provider', DRIVERS)
def test_broken_process_pool(driver_path, provider, raster_file):
import concurrent.futures
from terracotta import drivers
from terracotta.drivers.raster_base import context

class BrokenPool:
def submit(self, *args, **kwargs):
raise concurrent.futures.process.BrokenProcessPool('monkeypatched')

context.executor = BrokenPool()

db = drivers.get_driver(driver_path, provider=provider)
keys = ('some', 'keynames')

db.create(keys)
db.insert(['some', 'value'], str(raster_file))
db.insert(['some', 'other_value'], str(raster_file))

data1 = db.get_raster_tile(['some', 'value'], tile_size=(256, 256))
assert data1.shape == (256, 256)

data2 = db.get_raster_tile(['some', 'other_value'], tile_size=(256, 256))
assert data2.shape == (256, 256)

np.testing.assert_array_equal(data1, data2)


def test_no_multiprocessing():
import concurrent.futures
from terracotta import update_settings
from terracotta.drivers.raster_base import create_executor

update_settings(USE_MULTIPROCESSING=False)

executor = create_executor()
assert isinstance(executor, concurrent.futures.ThreadPoolExecutor)

0 comments on commit 253e02c

Please sign in to comment.