fix: don't pass process stack via context (#4699)
This PR fixes a memory leak: when running `CalcJob`s over an SSH connection,
the first CalcJob that was run remained in memory indefinitely.

`plumpy` uses the `contextvars` module to provide a reference to the
`current_process` anywhere in a task launched by a process.  When using any of
`asyncio`'s `call_soon`, `call_later` or `call_at` methods, each individual
function execution gets its own copy of this context.  This means that as long
as a handle to these scheduled executions remains in memory, the copy of the
`'process stack'` context variable (and thus the process itself) remains in
memory.
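
To illustrate the mechanism, here is a minimal sketch (not part of this
commit; the context variable is a stand-in for plumpy's internal one):

```python
import asyncio
import contextvars

# Stand-in for plumpy's 'process stack' context variable.
process_stack = contextvars.ContextVar('process stack')


async def main():
    loop = asyncio.get_running_loop()
    process_stack.set('<Process pk=1234>')  # imagine a heavyweight Process object here

    # `call_later` snapshots the *current* context into the returned TimerHandle.
    # As long as `handle` stays referenced somewhere, the snapshot -- and through
    # it the value of `process_stack` -- cannot be garbage collected.
    handle = loop.call_later(3600, lambda: None)
    handle.cancel()


asyncio.run(main())
```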

In this particular case, a handle to such a task (`do_open`, which opens a
`transport`) remained in memory and caused the whole process to remain in
memory as well via the 'process stack' context variable.  This is fixed by
explicitly passing an empty context to the execution of `do_open` (which anyhow
does not need access to the `current_process`).  An explicit test is added to
make sure that no references to processes are leaked after running a process
via the interpreter, as well as in the daemon tests.
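
In isolation, the fix pattern looks roughly like this (a runnable sketch, not
the diff itself; `do_open` here is a placeholder for the real
transport-opening callback — see the changes to `aiida/engine/transports.py`
below):

```python
import asyncio
import contextvars


def do_open():  # placeholder for the real transport-opening callback
    print('opening transport')


async def main():
    loop = asyncio.get_running_loop()
    safe_open_interval = 0.1
    # Passing a fresh, empty context: the returned handle no longer pins the
    # caller's context variables (and whatever objects they reference).
    loop.call_later(safe_open_interval, do_open, context=contextvars.Context())
    await asyncio.sleep(0.2)  # give the callback time to fire


asyncio.run(main())
```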

This PR adds the empty context to two other invocations of `call_later`, but
there are more places in the code where these methods are used, so this fix is
a bit of a workaround.  Eventually, this problem should likely be addressed by
converting any functions that use `call_soon`, `call_later` or `call_at` and
all their parents in the call stack to coroutines.
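
For reference, a coroutine-based version would express the delay as an `await`
rather than a stored handle (a sketch under assumed names, not part of this
commit):

```python
import asyncio


async def open_transport_later(do_open, safe_open_interval):
    """Hypothetical coroutine replacement for scheduling `do_open` via `call_later`.

    The delay is awaited directly, so no TimerHandle has to be stored for later
    cancellation; once the coroutine completes, any context it captured is
    released along with it.
    """
    await asyncio.sleep(safe_open_interval)
    do_open()

# Usage sketch: wrap in a task and cancel the task instead of a handle, e.g.
#   task = asyncio.get_running_loop().create_task(open_transport_later(do_open, 30))
#   task.cancel()
```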

Co-authored-by: Chris Sewell <[email protected]>
ltalirz and chrisjsewell authored Feb 9, 2021
1 parent e7223ae commit b07841a
Showing 10 changed files with 154 additions and 36 deletions.
39 changes: 39 additions & 0 deletions .github/system_tests/pytest/test_memory_leaks.py
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Utilities for testing memory leakage."""
from tests.utils import processes as test_processes # pylint: disable=no-name-in-module,import-error
from tests.utils.memory import get_instances # pylint: disable=no-name-in-module,import-error
from aiida.engine import processes, run
from aiida.plugins import CalculationFactory
from aiida import orm

ArithmeticAddCalculation = CalculationFactory('arithmetic.add')


def test_leak_run_process():
    """Test whether running a dummy process leaks memory."""
    inputs = {'a': orm.Int(2), 'b': orm.Str('test')}
    run(test_processes.DummyProcess, **inputs)

    # check that no reference to the process is left in memory
    # some delay is necessary in order to allow for all callbacks to finish
    process_instances = get_instances(processes.Process, delay=0.2)
    assert not process_instances, f'Memory leak: process instances remain in memory: {process_instances}'


def test_leak_local_calcjob(aiida_local_code_factory):
    """Test whether running a local CalcJob leaks memory."""
    inputs = {'x': orm.Int(1), 'y': orm.Int(2), 'code': aiida_local_code_factory('arithmetic.add', '/usr/bin/diff')}
    run(ArithmeticAddCalculation, **inputs)

    # check that no reference to the process is left in memory
    # some delay is necessary in order to allow for all callbacks to finish
    process_instances = get_instances(processes.Process, delay=0.2)
    assert not process_instances, f'Memory leak: process instances remain in memory: {process_instances}'
95 changes: 63 additions & 32 deletions .github/system_tests/test_daemon.py
@@ -18,6 +18,7 @@
from aiida.engine.daemon.client import get_daemon_client
from aiida.engine.persistence import ObjectLoader
from aiida.manage.caching import enable_caching
from aiida.engine.processes import Process
from aiida.orm import CalcJobNode, load_node, Int, Str, List, Dict, load_code
from aiida.plugins import CalculationFactory, WorkflowFactory
from aiida.workflows.arithmetic.add_multiply import add_multiply, add
@@ -26,6 +27,8 @@
    WorkFunctionRunnerWorkChain, NestedInputNamespace, SerializeWorkChain, ArithmeticAddBaseWorkChain
)

from tests.utils.memory import get_instances # pylint: disable=import-error

CODENAME_ADD = 'add@localhost'
CODENAME_DOUBLER = 'doubler@localhost'
TIMEOUTSECS = 4 * 60 # 4 minutes
@@ -389,9 +392,12 @@ def run_multiply_add_workchain():
    assert results['result'].value == 5


def main():
    """Launch a bunch of calculation jobs and workchains."""
    # pylint: disable=too-many-locals,too-many-statements,too-many-branches
def launch_all():
    """Launch a bunch of calculation jobs and workchains.

    :returns: dictionary with expected results and pks of all launched calculations and workchains
    """
    # pylint: disable=too-many-locals,too-many-statements
    expected_results_process_functions = {}
    expected_results_calculations = {}
    expected_results_workchains = {}
@@ -437,8 +443,8 @@ def main():
    builder = NestedWorkChain.get_builder()
    input_val = 4
    builder.inp = Int(input_val)
    proc = submit(builder)
    expected_results_workchains[proc.pk] = input_val
    pk = submit(builder).pk
    expected_results_workchains[pk] = input_val

    print('Submitting a workchain with a nested input namespace.')
    value = Int(-12)
@@ -483,9 +489,46 @@ def main():
    calculation_pks = sorted(expected_results_calculations.keys())
    workchains_pks = sorted(expected_results_workchains.keys())
    process_functions_pks = sorted(expected_results_process_functions.keys())
    pks = calculation_pks + workchains_pks + process_functions_pks

    print('Wating for end of execution...')
    return {
        'pks': calculation_pks + workchains_pks + process_functions_pks,
        'calculations': expected_results_calculations,
        'process_functions': expected_results_process_functions,
        'workchains': expected_results_workchains,
    }


def relaunch_cached(results):
    """Launch the same calculations but with caching enabled -- these should be FINISHED immediately."""
    code_doubler = load_code(CODENAME_DOUBLER)
    cached_calcs = []
    with enable_caching(identifier='aiida.calculations:templatereplacer'):
        for counter in range(1, NUMBER_CALCULATIONS + 1):
            inputval = counter
            calc, expected_result = run_calculation(code=code_doubler, counter=counter, inputval=inputval)
            cached_calcs.append(calc)
            results['calculations'][calc.pk] = expected_result

    if not (
        validate_calculations(results['calculations']) and validate_workchains(results['workchains']) and
        validate_cached(cached_calcs) and validate_process_functions(results['process_functions'])
    ):
        print_daemon_log()
        print('')
        print('ERROR! Some return values are different from the expected value')
        sys.exit(3)

    print_daemon_log()
    print('')
    print('OK, all calculations have the expected parsed result')


def main():
    """Launch a bunch of calculation jobs and workchains."""

    results = launch_all()

    print('Waiting for end of execution...')
    start_time = time.time()
    exited_with_timeout = True
    while time.time() - start_time < TIMEOUTSECS:
@@ -515,7 +558,7 @@ def main():
        except subprocess.CalledProcessError as exception:
            print(f'Note: the command failed, message: {exception}')

        if jobs_have_finished(pks):
        if jobs_have_finished(results['pks']):
            print('Calculation terminated its execution')
            exited_with_timeout = False
            break
@@ -525,30 +568,18 @@ def main():
        print('')
        print(f'Timeout!! Calculation did not complete after {TIMEOUTSECS} seconds')
        sys.exit(2)
    else:
        # Launch the same calculations but with caching enabled -- these should be FINISHED immediately
        cached_calcs = []
        with enable_caching(identifier='aiida.calculations:templatereplacer'):
            for counter in range(1, NUMBER_CALCULATIONS + 1):
                inputval = counter
                calc, expected_result = run_calculation(code=code_doubler, counter=counter, inputval=inputval)
                cached_calcs.append(calc)
                expected_results_calculations[calc.pk] = expected_result

        if (
            validate_calculations(expected_results_calculations) and
            validate_workchains(expected_results_workchains) and validate_cached(cached_calcs) and
            validate_process_functions(expected_results_process_functions)
        ):
            print_daemon_log()
            print('')
            print('OK, all calculations have the expected parsed result')
            sys.exit(0)
        else:
            print_daemon_log()
            print('')
            print('ERROR! Some return values are different from the expected value')
            sys.exit(3)

    relaunch_cached(results)

    # Check that no references to processes remain in memory
    # Note: This tests only processes that were `run` in the same interpreter, not those that were `submitted`
    del results
    processes = get_instances(Process, delay=1.0)
    if processes:
        print(f'Memory leak! Process instances remained in memory: {processes}')
        sys.exit(4)

    sys.exit(0)


if __name__ == '__main__':
11 changes: 9 additions & 2 deletions aiida/engine/processes/calcjobs/manager.py
@@ -10,6 +10,7 @@
"""Module containing utilities and classes relating to job calculations running on systems that require transport."""
import asyncio
import contextlib
import contextvars
import logging
import time
from typing import Any, Dict, Hashable, Iterator, List, Optional, TYPE_CHECKING
@@ -180,15 +181,21 @@ async def updating():
            # Any outstanding requests?
            if self._update_requests_outstanding():
                self._update_handle = self._loop.call_later(
                    self._get_next_update_delay(), asyncio.ensure_future, updating()
                    self._get_next_update_delay(),
                    asyncio.ensure_future,
                    updating(),
                    context=contextvars.Context(),  # type: ignore[call-arg]
                )
            else:
                self._update_handle = None

        # Check if we're already updating
        if self._update_handle is None:
            self._update_handle = self._loop.call_later(
                self._get_next_update_delay(), asyncio.ensure_future, updating()
                self._get_next_update_delay(),
                asyncio.ensure_future,
                updating(),
                context=contextvars.Context(),  # type: ignore[call-arg]
            )
    @staticmethod
9 changes: 8 additions & 1 deletion aiida/engine/transports.py
Expand Up @@ -13,6 +13,7 @@
import traceback
from typing import Awaitable, Dict, Hashable, Iterator, Optional
import asyncio
import contextvars

from aiida.orm import AuthInfo
from aiida.transports import Transport
@@ -96,7 +97,13 @@ def do_open():
                        transport_request.future.set_result(transport)

            # Save the handle so that we can cancel the callback if the user no longer wants it
            open_callback_handle = self._loop.call_later(safe_open_interval, do_open)
            # Note: Don't pass the Process context, since (a) it is not needed by `do_open` and (b) the transport is
            #       passed around to many places, including outside aiida-core (e.g. paramiko). Anyone keeping a reference
            #       to this handle would otherwise keep the Process context (and thus the process itself) in memory.
            #       See https://github.com/aiidateam/aiida-core/issues/4698
            open_callback_handle = self._loop.call_later(
                safe_open_interval, do_open, context=contextvars.Context()
            )  # type: ignore[call-arg]

        try:
            transport_request.count += 1
1 change: 1 addition & 0 deletions requirements/requirements-py-3.7.txt
@@ -101,6 +101,7 @@ pycparser==2.20
pydata-sphinx-theme==0.4.3
Pygments==2.7.4
pymatgen==2020.12.31
pympler==0.9
PyMySQL==0.9.3
PyNaCl==1.4.0
pyparsing==2.4.7
1 change: 1 addition & 0 deletions requirements/requirements-py-3.8.txt
@@ -100,6 +100,7 @@ pycparser==2.20
pydata-sphinx-theme==0.4.3
Pygments==2.7.4
pymatgen==2020.12.31
pympler==0.9
PyMySQL==0.9.3
PyNaCl==1.4.0
pyparsing==2.4.7
1 change: 1 addition & 0 deletions requirements/requirements-py-3.9.txt
@@ -100,6 +100,7 @@ pycparser==2.20
pydata-sphinx-theme==0.4.3
Pygments==2.7.4
pymatgen==2020.12.31
pympler==0.9
PyMySQL==0.9.3
PyNaCl==1.4.0
pyparsing==2.4.7
1 change: 1 addition & 0 deletions setup.json
@@ -109,6 +109,7 @@
"pytest-cov~=2.7",
"pytest-rerunfailures~=9.1,>=9.1.1",
"pytest-benchmark~=3.2",
"pympler~=0.9",
"coverage<5.0",
"sqlalchemy-diff~=0.1.3"
],
1 change: 0 additions & 1 deletion tests/engine/test_run.py
@@ -8,7 +8,6 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Tests for the `run` functions."""

from aiida.backends.testbase import AiidaTestCase
from aiida.engine import run, run_get_node
from aiida.orm import Int, Str, ProcessNode
31 changes: 31 additions & 0 deletions tests/utils/memory.py
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Utilities for testing memory leakage."""
import asyncio
from pympler import muppy


def get_instances(classes, delay=0.0):
    """Return all instances of provided classes that are in memory.

    Useful for investigating memory leaks.

    :param classes: A class or tuple of classes to check (passed to `isinstance`).
    :param delay: How long to sleep (seconds) before collecting the memory dump.
        This is a convenience function for tests involving Processes. For example, :py:func:`~aiida.engine.run` returns
        before all futures are resolved/cleaned up. Dumping memory too early would catch those and the references they
        carry, although they may not actually be leaking memory.
    """
    if delay > 0:
        loop = asyncio.get_event_loop()
        loop.run_until_complete(asyncio.sleep(delay))

    all_objects = muppy.get_objects()  # this also calls gc.collect()
    return [o for o in all_objects if hasattr(o, '__class__') and isinstance(o, classes)]
