Skip to content

Commit

Permalink
✨ NEW: Add ProcessLauncher.process_cache
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisjsewell committed Mar 2, 2021
1 parent db0bf60 commit 72505e9
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 3 deletions.
4 changes: 4 additions & 0 deletions plumpy/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,7 @@ class PersistenceError(Exception):

class ClosedError(Exception):
    """Raised when a mutable operation is attempted on a closed process."""


class DuplicateProcess(Exception):
    """Raised when a ProcessLauncher is asked to launch a process it is already running."""
32 changes: 30 additions & 2 deletions plumpy/process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import copy
import logging
from typing import Any, cast, Dict, Optional, Sequence, TYPE_CHECKING, Union
from weakref import WeakValueDictionary

import kiwipy

from . import loaders
from . import exceptions
from . import communications
from . import futures
from . import loaders
from . import persistence
from .utils import PID_TYPE

Expand All @@ -27,6 +29,7 @@

if TYPE_CHECKING:
from .processes import Process # pylint: disable=cyclic-import
ProcessCacheType = WeakValueDictionary[PID_TYPE, Process] # pylint: disable=unsubscriptable-object

ProcessResult = Any
ProcessStatus = Any
Expand Down Expand Up @@ -527,6 +530,20 @@ def __init__(
else:
self._loader = loaders.get_object_loader()

# using a weak reference ensures the processes can be garbage cleaned on completion
self._process_cache: 'ProcessCacheType' = WeakValueDictionary()

@property
def process_cache(self) -> 'ProcessCacheType':
    """Snapshot of the PID -> process mapping for launched processes still in memory.

    The underlying cache is a `WeakValueDictionary`, so an entry vanishes as soon
    as its process is no longer referenced anywhere else.  Consequently the mapping
    always holds every process that is still running, and may additionally hold
    terminated processes that have not yet been garbage collected.

    :return: a shallow copy of the internal weak-value cache
    """
    # Hand back a copy so callers cannot mutate the launcher's internal cache.
    return self._process_cache.copy()

async def __call__(self, communicator: kiwipy.Communicator, task: Dict[str, Any]) -> Union[PID_TYPE, ProcessResult]:
"""
Receive a task.
Expand Down Expand Up @@ -571,10 +588,16 @@ async def _launch(
init_kwargs = {}

proc_class = self._loader.load_object(process_class)
proc = proc_class(*init_args, **init_kwargs)
proc: Process = proc_class(*init_args, **init_kwargs)

if proc.pid in self._process_cache and not self._process_cache[proc.pid].has_terminated():
raise exceptions.DuplicateProcess(f'Process<{proc.pid}> is already running')

if persist and self._persister is not None:
self._persister.save_checkpoint(proc)

self._process_cache[proc.pid] = proc

if nowait:
asyncio.ensure_future(proc.step_until_terminated())
return proc.pid
Expand Down Expand Up @@ -602,10 +625,15 @@ async def _continue(
LOGGER.warning('rejecting task: cannot continue process<%d> because no persister is available', pid)
raise communications.TaskRejected('Cannot continue process, no persister')

if pid in self._process_cache and not self._process_cache[pid].has_terminated():
raise exceptions.DuplicateProcess(f'Process<{pid}> is already running')

# Do not catch exceptions here, because if these operations fail, the continue task should except and bubble up
saved_state = self._persister.load_checkpoint(pid, tag)
proc = cast('Process', saved_state.unbundle(self._load_context))

self._process_cache[proc.pid] = proc

if nowait:
asyncio.ensure_future(proc.step_until_terminated())
return proc.pid
Expand Down
14 changes: 13 additions & 1 deletion test/rmq/test_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import shortuuid

import pytest
from kiwipy import rmq
from kiwipy import RemoteException, rmq

import plumpy
from plumpy import communications, process_comms
Expand Down Expand Up @@ -177,3 +177,15 @@ async def test_continue(self, loop_communicator, async_controller, persister):
# Let the process run to the end
result = await async_controller.continue_process(pid)
assert result, utils.DummyProcessWithOutput.EXPECTED_OUTPUTS

@pytest.mark.asyncio
async def test_duplicate_process(self, loop_communicator, async_controller, persister):
    """Continuing a process whose PID is already in the launcher's cache is rejected.

    The launcher raises ``DuplicateProcess`` server-side, which is delivered to the
    remote caller wrapped in a kiwipy ``RemoteException``.
    """
    loop = asyncio.get_event_loop()
    launcher = plumpy.ProcessLauncher(loop, persister=persister)
    loop_communicator.add_task_subscriber(launcher)
    process = utils.DummyProcessWithOutput()
    # A checkpoint must exist for the continue task to be accepted at all.
    persister.save_checkpoint(process)
    # Seed the private cache directly so the launcher believes this PID is already running.
    launcher._process_cache[process.pid] = process
    assert process.pid in launcher.process_cache
    with pytest.raises(RemoteException, match='already running'):
        await async_controller.continue_process(process.pid)

0 comments on commit 72505e9

Please sign in to comment.