diff --git a/janis_core/ingestion/fromwdl.py b/janis_core/ingestion/fromwdl.py new file mode 100755 index 000000000..55ab50054 --- /dev/null +++ b/janis_core/ingestion/fromwdl.py @@ -0,0 +1,447 @@ +#!/usr/bin/env python3 +import functools +import os +import re +from types import LambdaType +from typing import List, Union, Optional, Callable +import WDL + +import janis_core as j + + +def error_boundary(return_value=None): + def try_catch_translate_inner(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not WdlParser.allow_errors: + return func(*args, **kwargs) + else: + try: + return func(*args, **kwargs) + except Exception as e: + j.Logger.log_ex(e) + return return_value + + return wrapper + + return try_catch_translate_inner + +class WdlParser: + + allow_errors = False + + @staticmethod + def from_doc(doc: str, base_uri=None): + abs_path = os.path.relpath(doc) + d = WDL.load(abs_path) + + parser = WdlParser() + + if d.workflow: + return parser.from_loaded_object(d.workflow) + + tasks = [] + for t in d.tasks: + tasks.append(parser.from_loaded_object(t)) + + return tasks[0] + + def from_loaded_object(self, obj: WDL.SourceNode): + if isinstance(obj, WDL.Task): + return self.from_loaded_task(obj) + elif isinstance(obj, WDL.Workflow): + return self.from_loaded_workflow(obj) + + def from_loaded_workflow(self, obj: WDL.Workflow): + wf = j.WorkflowBuilder(identifier=obj.name) + + for inp in obj.inputs: + self.add_decl_to_wf_input(wf, inp) + + for call in obj.body: + self.add_call_to_wf(wf, call) + + return wf + + def workflow_selector_getter(self, wf, exp: str): + if "." in exp: + node, *tag = exp.split(".") + if len(tag) > 1: + raise Exception(f"Couldn't parse source ID: {exp} - too many '.'") + return wf[node][tag[0]] + + return wf[exp] + + @error_boundary() + def add_call_to_wf( + self, + wf: j.WorkflowBase, + call: WDL.WorkflowNode, + condition=None, + foreach=None, + expr_alias: str = None, + ): + def selector_getter(exp): + if exp == expr_alias: + return j.ForEachSelector() + + return self.workflow_selector_getter(wf, exp) + + if isinstance(call, WDL.Call): + task = self.from_loaded_object(call.callee) + inp_map = {} + for k, v in call.inputs.items(): + new_expr = self.translate_expr(v, input_selector_getter=selector_getter) + + inp_map[k] = new_expr + + return wf.step(call.name, task(**inp_map), when=condition, _foreach=foreach) + + elif isinstance(call, WDL.Conditional): + # if len(call.body) > 1: + # raise NotImplementedError( + # f"Janis can't currently support more than one call inside the conditional: {', '.join(str(c) for c in call.body)}") + for inner_call in call.body: + # inner_call = call.body[0] + self.add_call_to_wf( + wf, + inner_call, + condition=self.translate_expr( + call.expr, input_selector_getter=selector_getter + ), + expr_alias=expr_alias, + foreach=foreach + ) + elif isinstance(call, WDL.Scatter): + # for scatter, we want to take the call.expr, and pass it to a step.foreach + + foreach = self.translate_expr(call.expr) + + scar_var_type = self.parse_wdl_type(call.expr.type) + if isinstance(scar_var_type, WDL.Type.Array): + scar_var_type = scar_var_type.item_type + + # when we unwrap each step-input to the workflow, we want to replace 'call.variable' with + # lambda el: + # if call.variable not in wf.input_nodes: + # wf.input(call.variable, scar_var_type) + for inner_call in call.body: + self.add_call_to_wf( + wf, inner_call, foreach=foreach, expr_alias=call.variable + ) + + elif isinstance(call, WDL.Decl): + self.add_decl_to_wf_input(wf, call) + else: + raise NotImplementedError(f"body type: {type(call)}") + + def add_decl_to_wf_input(self, wf: j.WorkflowBase, inp: WDL.Decl): + default = None + if inp.expr: + + def selector_getter(exp): + return self.workflow_selector_getter(wf, exp) + + default = self.translate_expr( + inp.expr, input_selector_getter=selector_getter + ) + + return wf.input(inp.name, self.parse_wdl_type(inp.type), default=default) + + @classmethod + def container_from_runtime(cls, runtime, inputs: List[WDL.Decl]): + container = runtime.get("container", runtime.get("docker")) + if isinstance(container, WDL.Expr.Get): + # relevant input + inp = [i.expr for i in inputs if i.name == str(container.expr)] + if len(inp) > 0: + container = inp[0] + else: + j.Logger.warn( + f"Expression for determining containers was '{container}' " + f"but couldn't find input called {str(container.expr)}" + ) + if isinstance(container, WDL.Expr.String): + container = container.literal + if isinstance(container, WDL.Value.String): + container = container.value + if container is None: + container = "ubuntu:latest" + if not isinstance(container, str): + j.Logger.warn( + f"Expression for determining containers ({container}) are not supported in Janis, using ubuntu:latest" + ) + container = "ubuntu:latest" + return container + + def parse_memory_requirement(self, value): + s = self.translate_expr(value) + if isinstance(s, str): + if s.lower().endswith("g"): + return float(s[:-1].strip()) + if s.lower().endswith("gb"): + return float(s[:-2].strip()) + elif s.lower().endswith("gib"): + return float(s[:-3].strip()) * 0.931323 + elif s.lower().endswith("mb"): + return float(s[:-2].strip()) / 1000 + elif s.lower().endswith("mib"): + return float(s[:-3].strip()) / 1024 + raise Exception(f"Memory type {s}") + elif isinstance(s, (float, int)): + # in bytes? + return s / (1024 ** 3) + elif isinstance(s, j.Selector): + return s + raise Exception(f"Couldn't recognise memory requirement '{value}'") + + def parse_disk_requirement(self, value): + s = self.translate_expr(value) + if isinstance(s, str): + try: + return int(s) + except ValueError: + pass + pattern_matcher = re.match(r"local-disk (\d+) .*", s) + if not pattern_matcher: + raise Exception(f"Couldn't recognise disk type '{value}'") + s = pattern_matcher.groups()[0] + try: + return int(s) + except ValueError: + pass + if s.lower().endswith("gb"): + return float(s[:-2].strip()) + elif s.lower().endswith("gib"): + return float(s[:-3].strip()) * 0.931323 + elif s.lower().endswith("mb"): + return float(s[:-2].strip()) / 1000 + elif s.lower().endswith("mib"): + return float(s[:-3].strip()) / 1024 + raise Exception(f"Disk type type {s}") + elif isinstance(s, (float, int)): + # in bytes? + return s / (1024 ** 3) + elif isinstance(s, j.Selector): + return s + raise Exception(f"Couldn't recognise memory requirement '{value}'") + + def from_loaded_task(self, obj: WDL.Task): + rt = obj.runtime + translated_script = self.translate_expr(obj.command) + inputs = obj.inputs + + cpus = self.translate_expr(rt.get("cpu")) + if cpus is not None and not isinstance(cpus, (int, float)): + cpus = int(cpus) + + c = j.CommandToolBuilder( + tool=obj.name, + base_command=["sh", "script.sh"], + container=self.container_from_runtime(rt, inputs=inputs), + version="DEV", + inputs=[ + self.parse_command_tool_input(i) + for i in obj.inputs + if not i.name.startswith("runtime_") + ], + outputs=[self.parse_command_tool_output(o) for o in obj.outputs], + files_to_create={"script.sh": translated_script}, + memory=self.parse_memory_requirement(rt.get("memory")), + cpus=cpus, + disk=self.parse_disk_requirement(rt.get("disks")), + ) + + return c + + def translate_expr( + self, expr: WDL.Expr.Base, input_selector_getter: Callable[[str], any] = None + ) -> Optional[Union[j.Selector, List[j.Selector], int, str, float, bool]]: + if expr is None: + return None + + tp = lambda exp: self.translate_expr( + exp, input_selector_getter=input_selector_getter + ) + + if isinstance(expr, WDL.Expr.Array): + # a literal array + return [self.translate_expr(e) for e in expr.items] + if isinstance(expr, WDL.Expr.String): + return self.translate_wdl_string(expr) + elif isinstance(expr, (WDL.Expr.Int, WDL.Expr.Boolean, WDL.Expr.Float)): + return expr.literal.value + if isinstance(expr, WDL.Expr.Placeholder): + return self.translate_expr(expr.expr) + if isinstance(expr, WDL.Expr.IfThenElse): + return j.If(tp(expr.condition), tp(expr.consequent), tp(expr.alternative)) + elif isinstance(expr, WDL.Expr.Get): + n = str(expr.expr) + if input_selector_getter: + return input_selector_getter(n) + return j.InputSelector(n) + elif isinstance(expr, WDL.Expr.Apply): + return self.translate_apply( + expr, input_selector_getter=input_selector_getter + ) + + raise Exception(f"Unsupported WDL expression type: {expr} ({type(expr)})") + + def translate_wdl_string(self, s: WDL.Expr.String): + if s.literal is not None: + return str(s.literal).lstrip('"').rstrip('"') + + elements = {} + counter = 1 + _format = str(s).lstrip('"').rstrip('"') + + for placeholder in s.children: + if isinstance(placeholder, (str, bool, int, float)): + continue + + token = f"JANIS_WDL_TOKEN_{counter}" + if str(placeholder) not in _format: + # if the placeholder came up again + continue + + _format = _format.replace(str(placeholder), f"{{{token}}}") + elements[token] = self.translate_expr(placeholder) + counter += 1 + + if len(elements) == 0: + return str(s) + + _format.replace("\\n", "\n") + + return j.StringFormatter(_format, **elements) + + def file_size_operator(self, src, *args): + multiplier = None + if len(args) > 1: + f = args[1].lower() + multiplier_heirarchy = [ + ("ki" in f, 1024), + ("k" in f, 1000), + ("mi" in f, 1.024), + ("gi" in f, 0.001024), + ("g" in f, 0.001), + ] + if not any(m[0] for m in multiplier_heirarchy): + j.Logger.warn( + f"Couldn't determine prefix {f} for FileSizeOperator, defaulting to MB" + ) + else: + multiplier = [m[1] for m in multiplier_heirarchy if m[0] is True][0] + + if isinstance(src, list): + return multiplier * sum(j.FileSizeOperator(s) for s in src) + + base = j.FileSizeOperator(src, *args) + if multiplier is not None and multiplier != 1: + return multiplier * base + return base + + def basename_operator(self, src, *args): + retval = j.BasenameOperator(src) + if len(args) > 0: + retval = retval.replace(args[0], "") + + return retval + + def translate_apply( + self, expr: WDL.Expr.Apply, **expr_kwargs + ) -> Union[j.Selector, List[j.Selector]]: + + # special case for select_first of array with one element + if expr.function_name == "select_first" and len(expr.arguments) > 0: + inner = expr.arguments[0] + if isinstance(inner, WDL.Expr.Array) and len(inner.items) == 1: + return self.translate_expr(inner.items[0]).assert_not_null() + + args = [self.translate_expr(e, **expr_kwargs) for e in expr.arguments] + + fn_map = { + "_land": j.AndOperator, + "defined": j.IsDefined, + "select_first": j.FilterNullOperator, + "basename": self.basename_operator, + "length": j.LengthOperator, + "_gt": j.GtOperator, + "_gte": j.GteOperator, + "_lt": j.LtOperator, + "_lte": j.LteOperator, + "sep": j.JoinOperator, + "_add": j.AddOperator, + "_interpolation_add": j.AddOperator, + "stdout": j.Stdout, + "_mul": j.MultiplyOperator, + "_div": j.DivideOperator, + "glob": j.WildcardSelector, + "range": j.RangeOperator, + "_at": j.IndexOperator, + "_negate": j.NotOperator, + "_sub": j.SubtractOperator, + "size": self.file_size_operator, + "ceil": j.CeilOperator, + "select_all": j.FilterNullOperator, + "sub": j.ReplaceOperator, + "round": j.RoundOperator, + "write_lines": lambda exp: f"JANIS: write_lines({exp})", + "read_tsv": lambda exp: f'JANIS: j.read_tsv({exp})', + "read_boolean": lambda exp: f'JANIS: j.read_boolean({exp})', + 'read_lines': lambda exp: f'JANIS: j.read_lines({exp})', + + } + fn = fn_map.get(expr.function_name) + if fn is None: + raise Exception(f"Unhandled WDL apply function_name: {expr.function_name}") + if isinstance(fn, LambdaType): + return fn(args) + return fn(*args) + + def parse_wdl_type(self, t: WDL.Type.Base): + optional = t.optional + if isinstance(t, WDL.Type.Int): + return j.Int(optional=optional) + elif isinstance(t, WDL.Type.String): + return j.String(optional=optional) + elif isinstance(t, WDL.Type.Float): + return j.Float(optional=optional) + elif isinstance(t, WDL.Type.Boolean): + return j.Boolean(optional=optional) + elif isinstance(t, WDL.Type.File): + return j.File(optional=optional) + elif isinstance(t, WDL.Type.Directory): + return j.Directory(optional=optional) + elif isinstance(t, WDL.Type.Array): + return j.Array(self.parse_wdl_type(t.item_type), optional=optional) + + raise Exception(f"Didn't handle WDL type conversion for '{t}' ({type(t)})") + + def parse_command_tool_input(self, inp: WDL.Decl): + default = None + if inp.expr: + default = self.translate_expr(inp.expr) + + # explicitly skip "runtime_*" inputs because they're from janis + if inp.name.startswith("runtime_"): + return None + + return j.ToolInput(inp.name, self.parse_wdl_type(inp.type), default=default) + + def parse_command_tool_output(self, outp: WDL.Decl): + sel = self.translate_expr(outp.expr) + + return j.ToolOutput(outp.name, self.parse_wdl_type(outp.type), selector=sel) + + +if __name__ == "__main__": + import sys + if len(sys.argv) != 2: + raise Exception("Expected 1 argument, the name of a CWL tool.") + + toolname = sys.argv[1] + + tool = WdlParser.from_doc(toolname) + + tool.translate("janis") diff --git a/janis_core/operators/logical.py b/janis_core/operators/logical.py index 9e01aec9c..c6223d680 100644 --- a/janis_core/operators/logical.py +++ b/janis_core/operators/logical.py @@ -53,6 +53,10 @@ def __repr__(self): def evaluate(self, inputs): return self.evaluate_arg(self.args[0], inputs) is not None + def to_python(self, unwrap_operator, *args): + arg = unwrap_operator(self.args[0]) + return f"{arg} is not None" + def to_cwl(self, unwrap_operator, *args): arg = unwrap_operator(self.args[0]) # 2 equals (!=) in javascript will coerce undefined to equal null @@ -103,6 +107,10 @@ def to_cwl(self, unwrap_operator, *args): cond, v1, v2 = [unwrap_operator(a) for a in self.args] return f"{cond} ? {v1} : {v2}" + def to_python(self, unwrap_operator, *args): + condition, iftrue, iffalse = [unwrap_operator(a) for a in self.args] + return f"({iftrue} if {condition} else {iffalse})" + class AssertNotNull(Operator): @staticmethod @@ -117,6 +125,9 @@ def evaluate(self, inputs): assert result is not None return result + def to_python(self, unwrap_operator, *args): + return unwrap_operator(unwrap_operator(args[0])) + def to_wdl(self, unwrap_operator, *args): arg = unwrap_operator(self.args[0]) return f"select_first([{arg}])" @@ -168,6 +179,10 @@ def returntype(self): def apply_to(value): return not value + def to_python(self, unwrap_operator, *args): + arg = unwrap_operator(self.args[0]) + return f"not {arg}" + # Two value operators @@ -178,7 +193,7 @@ def friendly_signature(): @staticmethod def symbol(): - return "&&" + return "and" @staticmethod def wdl_symbol(): @@ -206,7 +221,7 @@ def friendly_signature(): @staticmethod def symbol(): - return "||" + return "or" @staticmethod def wdl_symbol(): @@ -559,6 +574,10 @@ def __str__(self): def __repr__(self): return str(self) + def to_python(self, unwrap_operator, *args): + arg = unwrap_operator(self.args[0]) + return f"math.floor({arg})" + def to_wdl(self, unwrap_operator, *args): arg = unwrap_operator(self.args[0]) return f"floor({arg})" @@ -577,7 +596,7 @@ def evaluate(self, inputs): class CeilOperator(Operator): @staticmethod def friendly_signature(): - return "Numeric, NumericType -> Int" + return "Numeric -> Int" def argtypes(self) -> List[DataType]: return [NumericType] @@ -592,6 +611,10 @@ def __str__(self): def __repr__(self): return str(self) + def to_python(self, unwrap_operator, *args): + arg = unwrap_operator(self.args[0]) + return f"math.ceil({arg})" + def to_wdl(self, unwrap_operator, *args): arg = unwrap_operator(self.args[0]) return f"ceil({arg})" @@ -625,6 +648,10 @@ def __str__(self): def __repr__(self): return str(self) + def to_python(self, unwrap_operator, *args): + arg = unwrap_operator(self.args[0]) + return f"math.round({arg})" + def to_wdl(self, unwrap_operator, *args): arg = unwrap_operator(self.args[0]) return f"round({arg})" diff --git a/janis_core/operators/operator.py b/janis_core/operators/operator.py index 7c019abc1..4d0ac67a1 100644 --- a/janis_core/operators/operator.py +++ b/janis_core/operators/operator.py @@ -134,6 +134,10 @@ def to_wdl(self, unwrap_operator, *args): def to_cwl(self, unwrap_operator, *args): pass + @abstractmethod + def to_python(self, unwrap_operator, *args): + pass + def to_string_formatter(self): import re from janis_core.operators.stringformatter import StringFormatter @@ -158,7 +162,10 @@ def argtypes(self): return [Array(AnyType), Int] def returntype(self): - return self.args[0].returntype().subtype() + inner = get_instantiated_type(self.args[0].returntype()) + if isinstance(inner, Array): + return inner.subtype() + return inner def __str__(self): base, index = self.args @@ -172,6 +179,10 @@ def evaluate(self, inputs): return iterable[idx] + def to_python(self, unwrap_operator, *args): + base, index = [unwrap_operator(a) for a in self.args] + return f"{base}[{index}]" + def to_wdl(self, unwrap_operator, *args): base, index = [unwrap_operator(a) for a in self.args] return f"{base}[{index}]" @@ -219,6 +230,9 @@ def to_wdl(self, unwrap_operator, *args): def to_cwl(self, unwrap_operator, *args): return f"{self.cwl_symbol()}({unwrap_operator(*args)})" + def to_python(self, unwrap_operator, *args): + return f"{self.symbol()}({unwrap_operator(*args)})" + class TwoValueOperator(Operator, ABC): @staticmethod @@ -253,6 +267,10 @@ def to_cwl(self, unwrap_operator, *args): arg1, arg2 = [unwrap_operator(a) for a in self.args] return f"({arg1} {self.cwl_symbol()} {arg2})" + def to_python(self, unwrap_operator, *args): + arg1, arg2 = [unwrap_operator(a) for a in self.args] + return f"({arg1} {self.symbol()} {arg2})" + def __str__(self): args = self.args return f"({args[0]} {self.symbol()} {args[1]})" diff --git a/janis_core/operators/selectors.py b/janis_core/operators/selectors.py index b8d59b96f..f9186686b 100644 --- a/janis_core/operators/selectors.py +++ b/janis_core/operators/selectors.py @@ -245,6 +245,10 @@ def basename(self): return BasenameOperator(self) + def replace(self, pattern, replacement): + from .standard import ReplaceOperator + return ReplaceOperator(self, pattern, replacement) + def file_size(self): from .standard import FileSizeOperator @@ -308,6 +312,14 @@ def to_string_formatter(self): return StringFormatter(f"{{{self.input_to_select}}}", **kwarg) + def init_dictionary(self): + d = {"input_to_select": self.input_to_select} + if self.remove_file_extension is not None: + d["remove_file_extension"] = self.remove_file_extension + if not isinstance(self.type_hint, File): + d["type_hint"] = self.type_hint + return d + def __str__(self): return "inputs." + self.input_to_select @@ -332,7 +344,7 @@ def id(self): def returntype(self): out = first_value(self.input_node.outputs()).outtype - if self.input_node is not None: + if self.input_node is not None and self.input_node.default is not None: import copy out = copy.copy(out) @@ -367,8 +379,14 @@ def __init__(self, node, tag): def returntype(self): retval = self.node.outputs()[self.tag].outtype - if hasattr(self.node, "scatter") and self.node.scatter: + if self.node.node_type != NodeType.STEP: + return retval + + if hasattr(self.node, "scatter") and self.node.scatter is not None: retval = Array(retval) + elif hasattr(self.node, "foreach") and self.node.foreach is not None: + retval = Array(retval) + return retval @staticmethod @@ -412,9 +430,24 @@ def __init__(self, inner: Selector, dt: ParseableType): def returntype(self) -> DataType: return self.data_type - def to_string_formatter(self): + def __repr__(self): return f"({self.inner_selector} as {self.data_type})" + def to_string_formatter(self): + from janis_core.operators.stringformatter import StringFormatter + + return StringFormatter("{value}", value=self.inner_selector) + + +class ForEachSelector(Selector): + def returntype(self) -> DataType: + return File() + + def to_string_formatter(self): + from janis_core.operators.stringformatter import StringFormatter + + return StringFormatter("{inp}", inp=self) + class ResourceSelector(InputSelector): def __init__( diff --git a/janis_core/operators/standard.py b/janis_core/operators/standard.py index 941284778..9ee84c31c 100644 --- a/janis_core/operators/standard.py +++ b/janis_core/operators/standard.py @@ -1,5 +1,7 @@ from copy import copy from typing import List + +from janis_core.utils.logger import Logger from janis_core.types import ( DataType, UnionType, @@ -11,7 +13,7 @@ ) from janis_core.types.common_data_types import String, Array, AnyType -from janis_core.operators.operator import Operator +from janis_core.operators.operator import Operator, InputSelector class ReadContents(Operator): @@ -22,6 +24,9 @@ def friendly_signature(): def argtypes(self) -> List[DataType]: return [File()] + def to_python(self, unwrap_operator, *args): + raise NotImplementedError("Determine _safe_ one line solution for ReadContents") + def to_wdl(self, unwrap_operator, *args): arg = unwrap_operator(args[0]) return f"read_string({arg})" @@ -56,6 +61,9 @@ def evaluate(self, inputs): with open(file) as f: return load(f) + def to_python(self, unwrap_operator, *args): + raise NotImplementedError("Determine _safe_ one line solution for ReadContents") + def to_wdl(self, unwrap_operator, *args): f = unwrap_operator(self.args[0]) return f"read_json({f})" @@ -91,6 +99,10 @@ def argtypes(self): def returntype(self): return String() + def to_python(self, unwrap_operator, *args): + iterable, separator = [unwrap_operator(a) for a in self.args] + return f"{separator}.join({iterable})" + def to_wdl(self, unwrap_operator, *args): iterable, separator = [unwrap_operator(a) for a in self.args] iterable_arg = self.args[0] @@ -124,9 +136,13 @@ class BasenameOperator(Operator): def friendly_signature(): return "Union[File, Directory] -> String" + def to_python(self, unwrap_operator, *args): + arg = unwrap_operator(args[0]) + return f"os.path.basename({arg})" + def to_wdl(self, unwrap_operator, *args): - arg = args[0] - return f"basename({unwrap_operator(arg)})" + arg = unwrap_operator(args[0]) + return f"basename({arg})" def to_cwl(self, unwrap_operator, *args): arg = unwrap_operator( @@ -169,6 +185,10 @@ def __str__(self): def __repr__(self): return str(self) + def to_python(self, unwrap_operator, *args): + iterable = unwrap_operator(self.args[0]) + return f"[[{iterable}[j][i] for j in range(len({iterable}))] for i in range(len({iterable}[0]))]" + def to_wdl(self, unwrap_operator, *args): return f"transform({unwrap_operator(args[0])})" @@ -200,6 +220,10 @@ def __str__(self): def __repr__(self): return str(self) + def to_python(self, unwrap_operator, *args): + arg = unwrap_operator(self.args[0]) + return f"len({arg})" + def to_wdl(self, unwrap_operator, *args): arg = unwrap_operator(self.args[0]) return f"length({arg})" @@ -213,6 +237,41 @@ def evaluate(self, inputs): return len(ar) +class RangeOperator(Operator): + @staticmethod + def friendly_signature(): + return "Int -> Array[Int]" + + def argtypes(self): + return [Int] + + def returntype(self): + return Array(Int()) + + def __str__(self): + return f"0...{self.args[0]}" + + def __repr__(self): + return str(self) + + def to_python(self, unwrap_operator, *args): + arg = unwrap_operator(self.args[0]) + return f"range({arg})" + + def to_wdl(self, unwrap_operator, *args): + arg = unwrap_operator(self.args[0]) + return f"range({arg})" + + def to_cwl(self, unwrap_operator, *args): + arg = unwrap_operator(self.args[0]) + return f"Array.from({{ length: {arg} + 1 }}, (_, i) => i)" + # return f"{arg}.length" + + def evaluate(self, inputs): + ar = self.evaluate_arg(self.args[0], inputs) + return list(range(ar)) + + class FlattenOperator(Operator): @staticmethod def friendly_signature(): @@ -230,6 +289,10 @@ def __str__(self): def __repr__(self): return str(self) + def to_python(self, unwrap_operator, *args): + arg = unwrap_operator(self.args[0]) + return f"[el for sublist in {arg} for el in sublist]" + def to_wdl(self, unwrap_operator, *args): arg = unwrap_operator(self.args[0]) return f"flatten({arg})" @@ -261,6 +324,10 @@ def __str__(self): def __repr__(self): return str(self) + def to_python(self, unwrap_operator, *args): + prefix, iterable = [unwrap_operator(a) for a in self.args] + return f"[{prefix} + i for i in {iterable}]" + def to_wdl(self, unwrap_operator, *args): prefix, iterable = [unwrap_operator(a) for a in self.args] return f"prefix({prefix}, {iterable})" @@ -279,6 +346,33 @@ class FileSizeOperator(Operator): Returned in MB: Note that this does NOT include the reference files (yet) """ + def __new__(cls, *args, **kwargs): + multiplier = None + src, *otherargs = args + + if len(otherargs) == 1: + f = otherargs[0].lower() + multiplier_heirarchy = [ + ("ki" in f, 1024), + ("k" in f, 1000), + ("mi" in f, 1.024), + ("gi" in f, 0.001024), + ("g" in f, 0.001), + ] + if not any(m[0] for m in multiplier_heirarchy): + Logger.warn( + f"Couldn't determine prefix {f} for FileSizeOperator, defaulting to MB" + ) + else: + multiplier = [m[1] for m in multiplier_heirarchy if m[0] is True][0] + + instance = super(FileSizeOperator, cls).__new__(cls) + instance.__init__(args[0]) + + if multiplier is not None and multiplier != 1: + return instance * multiplier + return instance + @staticmethod def friendly_signature(): return "File -> Float" @@ -296,6 +390,10 @@ def __str__(self): def __repr__(self): return str(self) + def to_python(self, unwrap_operator, *args): + f = unwrap_operator(self.args[0]) + return f"os.stat({f}).st_size / 1000" + def to_wdl(self, unwrap_operator, *args): f = unwrap_operator(self.args[0]) return f'size({f}, "MB")' @@ -338,6 +436,10 @@ def __str__(self): def __repr__(self): return str(self) + def to_python(self, unwrap_operator, *args): + iterable = unwrap_operator(self.args[0]) + return f"[a for a in {iterable} if a is not None][0]" + def to_wdl(self, unwrap_operator, *args): iterable = unwrap_operator(self.args[0]) return f"select_first({iterable})" @@ -363,7 +465,17 @@ def returntype(self): if isinstance(self.args[0], list): rettype = self.args[0][0].returntype() else: - rettype = self.args[0].returntype().subtype() + outer_rettype = get_instantiated_type(self.args[0].returntype()) + if not isinstance(outer_rettype, Array): + # hmmm, this could be a bad input selector + rettype = outer_rettype + if not isinstance(self.args[0], InputSelector): + Logger.warn( + f'Expected return type of "{self.args[0]}" to be an array, ' + f'but found {outer_rettype}, will return this as a returntype.' + ) + else: + rettype = outer_rettype.subtype() rettype = copy(get_instantiated_type(rettype)) rettype.optional = False @@ -376,6 +488,10 @@ def __str__(self): def __repr__(self): return str(self) + def to_python(self, unwrap_operator, *args): + iterable = unwrap_operator(self.args[0]) + return f"[a for a in {iterable} if a is not None]" + def to_wdl(self, unwrap_operator, *args): iterable = unwrap_operator(self.args[0]) return f"select_all({iterable})" @@ -389,6 +505,36 @@ def evaluate(self, inputs): return [i for i in iterable if i is not None] +class ReplaceOperator(Operator): + + @staticmethod + def friendly_signature(): + return "Base: String, Pattern: String, Replacement: String -> String" + + def argtypes(self) -> List[DataType]: + return [String(), String(), String()] + + def evaluate(self, inputs): + base, pattern, replacement = [self.evaluate_arg(a, inputs) for a in self.args] + import re + return re.sub(pattern, replacement, base) + + def to_wdl(self, unwrap_operator, *args): + base, pattern, replacement = [unwrap_operator(a) for a in self.args] + return f"sub({base}, {pattern}, {replacement})" + + def to_cwl(self, unwrap_operator, *args): + base, pattern, replacement = [unwrap_operator(a) for a in self.args] + return f"{base}.replace(new RegExp({pattern}), {replacement})" + + def to_python(self, unwrap_operator, *args): + base, pattern, replacement = [unwrap_operator(a) for a in self.args] + return f"re.sub({pattern}, {replacement}, {base})" + + def returntype(self) -> DataType: + return String() + + # class Stdout(Operator): # @staticmethod # def friendly_signature(): diff --git a/janis_core/operators/stringformatter.py b/janis_core/operators/stringformatter.py index d4d5893e9..d703244bd 100644 --- a/janis_core/operators/stringformatter.py +++ b/janis_core/operators/stringformatter.py @@ -65,6 +65,12 @@ def to_cwl(self, unwrap_operator, *args): def to_wdl(self, unwrap_operator, *args): raise Exception("Don't use this method") + def to_python(self, unwrap_operator, *args): + f = self._format + for k, v in self.kwargs.items(): + f = f.replace(f"{{{str(k)}}}", unwrap_operator(v)) + return f + def evaluate(self, inputs): resolvedvalues = { k: self.evaluate_arg(v, inputs) for k, v in self.kwargs.items() diff --git a/janis_core/tests/test_operators.py b/janis_core/tests/test_operators.py index cc1c67111..cebcbf9a9 100644 --- a/janis_core/tests/test_operators.py +++ b/janis_core/tests/test_operators.py @@ -22,15 +22,15 @@ def test_not_operator(self): class TestAndOperator(unittest.TestCase): def test_add_operator(self): op = AndOperator("cond1", "cond2") - self.assertEqual("(cond1 && cond2)", str(op)) + self.assertEqual("(cond1 and cond2)", str(op)) def test_nested_add_operator(self): op = AndOperator("cond1", AndOperator("cond2", "cond3")) - self.assertEqual("(cond1 && (cond2 && cond3))", str(op)) + self.assertEqual("(cond1 and (cond2 and cond3))", str(op)) - def test_and_to_operator(self): + def test_and_two_operator(self): op = AndOperator("cond1", "cond2").op_and("cond3") - self.assertEqual("((cond1 && cond2) && cond3)", str(op)) + self.assertEqual("((cond1 and cond2) and cond3)", str(op)) class TestAddOperator(unittest.TestCase): diff --git a/janis_core/tests/test_translation_cwl.py b/janis_core/tests/test_translation_cwl.py index 36f3d2359..7885cf470 100644 --- a/janis_core/tests/test_translation_cwl.py +++ b/janis_core/tests/test_translation_cwl.py @@ -16,6 +16,7 @@ EchoTestTool, FilenameGeneratedTool, OperatorResourcesTestTool, + TestForEach, ) from janis_core.deps import cwlgen @@ -1021,6 +1022,17 @@ def test_basic(self): self.assertEqual("__when_inp", extra_input.id) +class TestForEachSelectors(unittest.TestCase): + def test_minimal(self): + tool = TestForEach() + # tool.translate("cwl", export_path="~/Desktop/tmp", to_disk=True) + w, _ = CwlTranslator.translate_workflow(tool) + + stp = w.steps[0] + self.assertEqual("inp", stp.in_[0].source) + self.assertEqual('$((inputs._idx + "-hello"))', stp.in_[1].valueFrom) + + cwl_testtool = """\ #!/usr/bin/env cwl-runner class: CommandLineTool diff --git a/janis_core/tests/test_translation_wdl.py b/janis_core/tests/test_translation_wdl.py index c16a6a848..0a8aa7e5c 100644 --- a/janis_core/tests/test_translation_wdl.py +++ b/janis_core/tests/test_translation_wdl.py @@ -39,6 +39,7 @@ ArrayTestTool, OperatorResourcesTestTool, TestWorkflowThatOutputsArraysOfSecondaryFiles, + TestForEach, ) from janis_core.translations import WdlTranslator from janis_core.utils.scatter import ScatterDescription, ScatterMethod @@ -1597,6 +1598,32 @@ def test_file_int_fail(self): self.assertRaises(Exception, uniontype.wdl) +class TestForEachSelectors(unittest.TestCase): + def test_minimal(self): + TestForEach().translate("wdl", to_disk=True, export_path="~/Desktop/tmp") + w, _ = WdlTranslator.translate_workflow(TestForEach()) + expected = """\ +version development + +import "tools/EchoTestTool_TEST.wdl" as E + +workflow TestForEach { + input { + Array[String] inp + } + scatter (idx in inp) { + call E.EchoTestTool as print { + input: + inp=(idx + "-hello") + } + } + output { + Array[File] out = print.out + } +}""" + self.assertEqual(expected.strip(), w.get_string().strip()) + + t = CommandToolBuilder( tool="test_readcontents", base_command=["echo", "1"], diff --git a/janis_core/tests/testtools.py b/janis_core/tests/testtools.py index 933efaf51..b3ead5972 100644 --- a/janis_core/tests/testtools.py +++ b/janis_core/tests/testtools.py @@ -18,6 +18,7 @@ InputDocumentation, InputQualityType, Workflow, + ForEachSelector, ) @@ -312,3 +313,18 @@ def constructor(self): ) self.output("out", source=self.stp.out) + + +class TestForEach(Workflow): + def constructor(self): + self.input("inp", Array(str)) + self.step( + "print", EchoTestTool(inp=ForEachSelector() + "-hello"), _foreach=self.inp + ) + self.output("out", source=self.print.out) + + def friendly_name(self): + return self.id() + + def id(self) -> str: + return "TestForEach" diff --git a/janis_core/translations/cwl.py b/janis_core/translations/cwl.py index 459ec848e..55c14c178 100644 --- a/janis_core/translations/cwl.py +++ b/janis_core/translations/cwl.py @@ -45,6 +45,7 @@ DiskSelector, ResourceSelector, AliasSelector, + ForEachSelector, ) from janis_core.operators.logical import IsDefined, If, RoundOperator from janis_core.operators.standard import FirstOperator @@ -829,6 +830,8 @@ def unwrap_selector_for_reference(cls, value): elif isinstance(value, InputSelector): return value.input_to_select + elif isinstance(value, ForEachSelector): + return "inputs._idx" elif isinstance(value, AliasSelector): return cls.unwrap_selector_for_reference(value.inner_selector) else: @@ -889,6 +892,8 @@ def unwrap_expression( return CwlTranslator.quote_values_if_code_environment( value.generated_filename(), code_environment ) + elif isinstance(value, ForEachSelector): + return "inputs._idx" if code_environment else "$(inputs._idx)" elif isinstance(value, AliasSelector): return cls.unwrap_expression( value.inner_selector, @@ -1591,14 +1596,41 @@ def translate_step_node( in_=[], out=[], ) + extra_steps: List[cwlgen.WorkflowStep] = [] ## SCATTER + scatter_fields = set() if step.scatter: if len(step.scatter.fields) > 1: cwlstep.scatterMethod = step.scatter.method.cwl() cwlstep.scatter = step.scatter.fields - scatter_fields = set(cwlstep.scatter or []) + scatter_fields = set(cwlstep.scatter or []) + + elif step.foreach is not None: + new_source = CwlTranslator.unwrap_selector_for_reference( + step.foreach, + ) + if isinstance(step.foreach, Operator): + additional_step_id = f"_evaluate_preforeach-{step.id()}" + + tool = CwlTranslator.convert_operator_to_commandtool( + step_id=additional_step_id, + operators=[step.foreach], + tool=tool, + select_first_element=True, + ) + extra_steps.append(tool) + new_source = f"{additional_step_id}/out" + + d = cwlgen.WorkflowStepInput( + id="_idx", + source=new_source, + ) + + cwlstep.in_.append(d) + cwlstep.scatter = "_idx" + scatter_fields = {"_idx"} ## OUTPUTS @@ -1608,8 +1640,6 @@ def translate_step_node( ## INPUTS - extra_steps: List[cwlgen.WorkflowStep] = [] - for k, inp in step.inputs().items(): if k not in step.sources: if inp.intype.optional or inp.default is not None: @@ -1654,7 +1684,10 @@ def translate_step_node( link_merge = None default = None - if not has_operator: + if hasattr(src, "source") and isinstance(src.source, ForEachSelector): + valuefrom = "$(_idx)" + + elif not has_operator: unwrapped_sources: List[str] = [] for stepinput in ar_source: src = stepinput.source @@ -1717,6 +1750,9 @@ def translate_step_node( if not isinstance(leaf, Selector): # probably a python literal continue + if isinstance(leaf, ForEachSelector): + continue + sel = CwlTranslator.unwrap_selector_for_reference(leaf) alias = prepare_alias(sel) param_aliasing[sel] = "inputs." + alias @@ -1946,7 +1982,10 @@ def translate_to_cwl_glob(glob, inputsdict, tool, **debugkwargs): ) elif isinstance(glob, StringFormatter): - return translate_string_formatter(glob, None, tool=tool), None + return ( + translate_string_formatter(glob, None, tool=tool, inputs_dict=inputsdict), + None, + ) elif isinstance(glob, WildcardSelector): return ( diff --git a/janis_core/translations/translationbase.py b/janis_core/translations/translationbase.py index 7eb87fa0b..ee9646975 100644 --- a/janis_core/translations/translationbase.py +++ b/janis_core/translations/translationbase.py @@ -433,13 +433,14 @@ def build_inputs_file( values_provided_from_tool = { i.id(): i.value or i.default for i in tool.input_nodes.values() - if i.value or (i.default and not isinstance(i.default, Selector)) + if i.value is not None + or (i.default is not None and not isinstance(i.default, Selector)) } inp = { i.id(): ad.get(i.id(), values_provided_from_tool.get(i.id())) for i in tool.tool_inputs() - if i.default is not None + if (i.default is not None and not isinstance(i.default, Selector)) or not i.intype.optional or i.id() in ad or i.id() in values_provided_from_tool @@ -479,20 +480,20 @@ def build_resources_input( disk = inputs.get(f"{prefix}runtime_disks", 20) seconds = inputs.get(f"{prefix}runtime_seconds", 86400) - if max_cores and cpus > max_cores: + if max_cores is not None and cpus > max_cores: Logger.info( f"Tool '{tool.id()}' exceeded ({cpus}) max number of cores ({max_cores}), " "this was dropped to the new maximum" ) cpus = max_cores - if mem and max_mem and mem > max_mem: + if mem is not None and max_mem and mem > max_mem: Logger.info( f"Tool '{tool.id()}' exceeded ({mem} GB) max amount of memory ({max_mem} GB), " "this was dropped to the new maximum" ) mem = max_mem - if seconds and max_duration and seconds > max_duration: + if seconds is not None and max_duration and seconds > max_duration: Logger.info( f"Tool '{tool.id()}' exceeded ({seconds} secs) max duration in seconds ({max_duration} secs), " "this was dropped to the new maximum" @@ -500,10 +501,18 @@ def build_resources_input( seconds = max_duration return { - prefix + "runtime_memory": mem, - prefix + "runtime_cpu": cpus, - prefix + "runtime_disks": disk, - prefix + "runtime_seconds": seconds, + prefix + "runtime_memory": mem + if not isinstance(mem, Selector) + else None, + prefix + "runtime_cpu": cpus + if not isinstance(cpus, Selector) + else None, + prefix + "runtime_disks": disk + if not isinstance(disk, Selector) + else None, + prefix + "runtime_seconds": seconds + if not isinstance(seconds, Selector) + else None, } new_inputs = {} diff --git a/janis_core/translations/wdl.py b/janis_core/translations/wdl.py index 299afb18a..663590e6f 100644 --- a/janis_core/translations/wdl.py +++ b/janis_core/translations/wdl.py @@ -20,24 +20,9 @@ from inspect import isclass from typing import List, Dict, Any, Set, Tuple, Optional -from janis_core.deps import wdlgen as wdl - -from janis_core.translationdeps.supportedtranslations import SupportedTranslation -from janis_core.operators.logical import If, IsDefined -from janis_core.operators.standard import FirstOperator -from janis_core.types import get_instantiated_type, DataType - -from janis_core.types.data_types import is_python_primitive - from janis_core.code.codetool import CodeTool +from janis_core.deps import wdlgen as wdl from janis_core.graph.steptaginput import Edge, StepTagInput -from janis_core.tool.commandtool import CommandTool, ToolInput, ToolArgument, ToolOutput -from janis_core.tool.tool import Tool, TOutput, ToolType -from janis_core.translations.translationbase import ( - TranslatorBase, - TranslatorMeta, - try_catch_translate, -) from janis_core.operators import ( InputSelector, WildcardSelector, @@ -52,7 +37,17 @@ DiskSelector, ResourceSelector, AliasSelector, + ForEachSelector, +) +from janis_core.tool.commandtool import CommandTool, ToolInput, ToolArgument, ToolOutput +from janis_core.tool.tool import Tool, ToolType +from janis_core.translationdeps.supportedtranslations import SupportedTranslation +from janis_core.translations.translationbase import ( + TranslatorBase, + TranslatorMeta, + try_catch_translate, ) +from janis_core.types import get_instantiated_type, DataType from janis_core.types.common_data_types import ( Stdout, Stderr, @@ -61,11 +56,9 @@ Filename, File, Directory, - Int, - Float, - Double, String, ) +from janis_core.types.data_types import is_python_primitive from janis_core.utils import ( first_value, recursive_2param_wrap, @@ -75,17 +68,16 @@ from janis_core.utils.generators import generate_new_id_from from janis_core.utils.logger import Logger from janis_core.utils.scatter import ScatterDescription, ScatterMethod -from janis_core.utils.validators import Validators from janis_core.utils.secondary import ( split_secondary_file_carats, apply_secondary_file_format_to_filename, ) - -# from janis_core.tool.step import StepNode - +from janis_core.utils.validators import Validators ## PRIMARY TRANSLATION METHODS -from janis_core.workflow.workflow import InputNode, StepNode +from janis_core.workflow.workflow import StepNode + +# from janis_core.tool.step import StepNode SED_REMOVE_EXTENSION = "| sed 's/\\.[^.]*$//'" REMOVE_EXTENSION = ( @@ -568,7 +560,8 @@ def unwrap_expression( } ) return cls.wrap_if_string_environment(gen_filename, string_environment) - + elif isinstance(expression, ForEachSelector): + return wrap_in_code_block("idx") elif isinstance(expression, AliasSelector): return cls.unwrap_expression( expression.inner_selector, @@ -905,8 +898,12 @@ def prepare_secondary_tool_outputs( for ex in ar_exp: inner_exp = ex for ext in potential_extensions: - inner_exp = 'sub({inp}, "\\\\{old_ext}$", "{new_ext}")'.format( - inp=inner_exp, old_ext=ext, new_ext=s.replace("^", "") + inner_exp = ( + 'sub({inp}, "\\\\{old_ext}$", "{new_ext}")'.format( + inp=inner_exp, + old_ext=ext, + new_ext=s.replace("^", ""), + ) ) exp.append(inner_exp) @@ -1051,7 +1048,6 @@ def build_inputs_file( :param tool: :return: """ - from janis_core.workflow.workflow import Workflow inp = {} values_provided_from_tool = {} @@ -1121,7 +1117,6 @@ def build_resources_input( prefix=None, is_root=False, ): - from janis_core.workflow.workflow import Workflow is_workflow = tool.type() == ToolType.Workflow d = super().build_resources_input( @@ -1449,7 +1444,7 @@ def translate_step_node( :param resource_overrides: :return: """ - from janis_core.workflow.workflow import StepNode, InputNode + from janis_core.workflow.workflow import StepNode node: StepNode = node2 step_alias: str = node.id() @@ -1636,6 +1631,14 @@ def translate_step_node( call = wrap_scatter_call( call, node.scatter, scatterable, scattered_old_to_new_identifier ) + if node2.foreach is not None: + expr = WdlTranslator.unwrap_expression( + node2.foreach, + inputsdict=inputsdict, + string_environment=False, + stepid=step_identifier, + ) + call = wdl.WorkflowScatter("idx", expr, [call]) if node.when is not None: condition = WdlTranslator.unwrap_expression( diff --git a/janis_core/utils/bracketmatching.py b/janis_core/utils/bracketmatching.py index 8083ef757..0d6fb2093 100644 --- a/janis_core/utils/bracketmatching.py +++ b/janis_core/utils/bracketmatching.py @@ -22,7 +22,7 @@ def get_keywords_between_braces( for i in range(len(text)): char = text[i] - if char == "{": + if char == "{" and (i < 0 or text[i-1] != "$"): counter += 1 highest_level = max(highest_level, counter) if start_idx is None: diff --git a/janis_core/workflow/workflow.py b/janis_core/workflow/workflow.py index 9c876792f..834235a34 100644 --- a/janis_core/workflow/workflow.py +++ b/janis_core/workflow/workflow.py @@ -14,6 +14,7 @@ InputNodeSelector, Selector, AliasSelector, + ForEachSelector, ) from janis_core.operators.logical import AndOperator, NotOperator, or_prev_conds from janis_core.operators.standard import FirstOperator @@ -54,6 +55,8 @@ def verify_or_try_get_source( return source elif isinstance(source, AliasSelector): return source + elif isinstance(source, ForEachSelector): + return source elif isinstance(source, list): return [verify_or_try_get_source(s) for s in source] @@ -122,12 +125,14 @@ def __init__( doc: DocumentationMeta = None, scatter: ScatterDescription = None, when: Operator = None, + _foreach=None, ): super().__init__(wf, NodeType.STEP, identifier) self.tool = tool self.doc = doc self.scatter = scatter self.when = when + self.foreach = _foreach self.parent_has_conditionals = False self.has_conditionals = when is not None @@ -417,17 +422,9 @@ def output( while isinstance(sourceoperator, list): sourceoperator: Selector = sourceoperator[0] - datatype: DataType = copy.copy( - get_instantiated_type(sourceoperator.returntype()).received_type() - ) - if ( - isinstance(sourceoperator, InputNodeSelector) - and sourceoperator.input_node.default is not None - ): - datatype.optional = False - - elif isinstance(sourceoperator, StepNode) and sourceoperator.scatter: - datatype = Array(datatype) + datatype: DataType = get_instantiated_type( + sourceoperator.returntype() + ).received_type() skip_typecheck = True @@ -679,6 +676,7 @@ def step( identifier: str, tool: Tool, scatter: Union[str, List[str], ScatterDescription] = None, + _foreach: Union[Selector, List[Selector]] = None, when: Optional[Operator] = None, ignore_missing=False, doc: str = None, @@ -693,6 +691,10 @@ def step( :param when: An operator / condition that determines whether the step should run :type when: Optional[Operator] :param ignore_missing: Don't throw an error if required params are missing from this function + :param _foreach: NB: this is unimplemented. Iterate for each value of this resolves list, where + you should use the "ForEachSelector" to select each value in this iterable. + + :return: """ @@ -712,6 +714,11 @@ def step( scatter = ScatterDescription(fields, method=ScatterMethod.dot) + if scatter is not None and _foreach is not None: + raise Exception( + f"Can't supply 'scatter' and 'foreach' value to step with id: {identifier} for tool: {tool.id()}" + ) + # verify scatter if scatter: ins = set(tool.inputs_map().keys()) @@ -756,7 +763,13 @@ def step( d = doc if isinstance(doc, DocumentationMeta) else DocumentationMeta(doc=doc) stp = StepNode( - self, identifier=identifier, tool=tool, scatter=scatter, when=when, doc=d + self, + identifier=identifier, + tool=tool, + scatter=scatter, + when=when, + doc=d, + _foreach=_foreach, ) added_edges = [] @@ -803,7 +816,9 @@ def step( si = e.finish.sources[e.ftag] if e.ftag else first_value(e.finish.sources) self.has_multiple_inputs = self.has_multiple_inputs or si.multiple_inputs - self.has_scatter = self.has_scatter or scatter is not None + self.has_scatter = ( + self.has_scatter or scatter is not None or _foreach is not None + ) self.has_subworkflow = self.has_subworkflow or isinstance(tool, WorkflowBase) self.nodes[identifier] = stp self.step_nodes[identifier] = stp