diff --git a/damnit/ctxsupport/ctxrunner.py b/damnit/ctxsupport/ctxrunner.py index 9ede9a05..2de36ddf 100644 --- a/damnit/ctxsupport/ctxrunner.py +++ b/damnit/ctxsupport/ctxrunner.py @@ -7,6 +7,7 @@ import argparse import functools +import inspect import logging import os import time @@ -212,16 +213,28 @@ def __init__(self, data, ctx): def create(cls, ctx_file: ContextFile, inputs, run_number, proposal): res = {'start_time': np.asarray(get_start_time(inputs['run_data']))} + def get_dep_or_default(var, arg_name, dep_name): + """ + Helper function to get either the value returned from the dependency + `dep_name` of `var`, if any, or the default value of the argument in + the function signature. + """ + value = res.get(dep_name) + if value is None: + value = inspect.signature(var.func).parameters[arg_name].default + + return value + for name in ctx_file.ordered_vars(): var = ctx_file.vars[name] try: # Add all variable dependencies - kwargs = { arg_name: res.get(dep_name) + kwargs = { arg_name: get_dep_or_default(var, arg_name, dep_name) for arg_name, dep_name in var.arg_dependencies().items() } - # If any are None, skip this variable since we're missing a dependency - missing_deps = [key for key, value in kwargs.items() if value is None] + # Check for missing dependencies with no default value + missing_deps = [key for key, value in kwargs.items() if value is inspect.Parameter.empty] if len(missing_deps) > 0: log.warning(f"Skipping {name} because of missing dependencies: {', '.join(missing_deps)}") continue diff --git a/docs/backend.md b/docs/backend.md index 3d4be5d1..30fe4c2d 100644 --- a/docs/backend.md +++ b/docs/backend.md @@ -99,6 +99,16 @@ def bar(run, value: "var#foo"): return value * 2 ``` +Dependencies with default values are also allowed, the default value will be +passed to the function if the dependency did not complete execution for some +reason: +```python +@Variable(title="baz") +def baz(run, bar: "var#bar"=42): + # This will return the result of foo() if foo() succeeded, otherwise 42 + return value +``` + ## Reprocessing The context file is loaded each time a run is received, so if you edit the context file the changes will only take effect for the runs coming later. But, diff --git a/tests/test_backend.py b/tests/test_backend.py index 6ae755e9..280125a1 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -225,6 +225,19 @@ def bar(run, foo: "var#foo"): # There should be no computed variables since we treat None as a missing dependency assert tuple(results.data.keys()) == ("start_time",) + default_value_code = """ + from damnit_ctx import Variable + + @Variable(title="foo") + def foo(run): return None + + @Variable(title="bar") + def bar(run, foo: "var#foo"=1): return 41 + foo + """ + default_value_ctx = mkcontext(default_value_code) + results = results_create(default_value_ctx) + assert results.reduced["bar"].item() == 42 + # Test that the backend completely updates all datasets belonging to a # variable during reprocessing. e.g. if it had a trainId dataset but now # doesn't, the trainId dataset should be deleted from the HDF5 file.