From 335a8f2090a9f0e106c799e5ca992ab1c007c4c8 Mon Sep 17 00:00:00 2001 From: Tim Paine <3105306+timkpaine@users.noreply.github.com> Date: Tue, 16 Jun 2026 22:26:07 -0400 Subject: [PATCH] Add SSH host-filter fan-out via dynamic task mapping An SSHTask whose ssh_hook is a HostQuery(kind="filter") now fans out across every matching host using Airflow dynamic task mapping, rendered and instantiated as Operator.partial(...).expand(remote_host=[...]). - One mapped task with per-host, independently-retryable instances - A single base hook supplies shared credentials; only remote_host is mapped, so matching hosts must share username/password/key_file (validated, otherwise a clear error is raised) - Opt-in: the default kind="select" single-host behaviour is unchanged - Add render/instantiate tests and multi-host balancer fixtures --- airflow_pydantic/operators/ssh.py | 119 ++++++++++++++++++--- airflow_pydantic/tests/conftest.py | 36 +++++++ airflow_pydantic/tests/core/test_render.py | 18 ++++ airflow_pydantic/tests/test_operators.py | 40 +++++++ 4 files changed, 201 insertions(+), 12 deletions(-) diff --git a/airflow_pydantic/operators/ssh.py b/airflow_pydantic/operators/ssh.py index d6ed99ee..3bb903af 100644 --- a/airflow_pydantic/operators/ssh.py +++ b/airflow_pydantic/operators/ssh.py @@ -1,3 +1,4 @@ +import ast from logging import getLogger from types import FunctionType, MethodType from typing import Any @@ -7,7 +8,7 @@ from ..airflow import SSHHook as BaseSSHHook from ..core import Task, TaskArgs from ..extras import BalancerHostQueryConfiguration, Host -from ..utils import BashCommands, CallablePath, ImportPath, SSHHook, get_import_path +from ..utils import BashCommands, CallablePath, ImportPath, SSHHook, Variable, get_import_path __all__ = ( "SSHOperator", @@ -27,6 +28,10 @@ class SSHTaskArgs(TaskArgs): ) ssh_hook_host: Host | None = Field(default=None, exclude=True) + # Hosts to fan out over when ssh_hook is a host filter (BalancerHostQueryConfiguration kind="filter"). + # Rendered / instantiated as `Operator.partial(...).expand(remote_host=[...])`. + ssh_hook_hosts: list[Host] | None = Field(default=None, exclude=True) + # Track source of hook in order to defer ssh_hook_foo: CallablePath | None = Field(default=None, exclude=True) ssh_hook_external: bool | None = Field( @@ -75,6 +80,22 @@ def validate_command(cls, v: Any) -> Any: else: raise ValueError("command must be a string, list of strings, or a BashCommands model") + @staticmethod + def _assert_uniform_credentials(hosts: list[Host]) -> None: + # All hosts in a fan-out are reached through a single base SSHHook (only `remote_host` is + # mapped per task instance), so they must agree on username / password / key_file. + def _creds(host: Host): + password = host.password + password_key = password.key if isinstance(password, Variable) else password + return (host.username, password_key, host.key_file) + + if len({_creds(host) for host in hosts}) > 1: + names = ", ".join(sorted(host.name for host in hosts)) + raise ValueError( + f"Host filter for SSH expansion matched hosts with differing credentials ({names}); " + "all matching hosts must share username/password/key_file." + ) + @model_validator(mode="before") @classmethod def _extract_host_from_ssh_hook(cls, data: Any) -> Any: @@ -82,22 +103,37 @@ def _extract_host_from_ssh_hook(cls, data: Any) -> Any: ssh_hook = data["ssh_hook"] if isinstance(ssh_hook, (BalancerHostQueryConfiguration, Host)): if isinstance(ssh_hook, BalancerHostQueryConfiguration): - # Ensure that the BalancerHostQueryConfiguration is of kind 'select' - if not ssh_hook.kind == "select": - raise ValueError("BalancerHostQueryConfiguration must be of kind 'select'") + if ssh_hook.kind == "select": + # Execute the query to get the Host, just set as it will + # be handled by the field validator + data["ssh_hook"] = ssh_hook.execute() - # Execute the query to get the Host, just set as it will - # be handled by the field validator - data["ssh_hook"] = ssh_hook.execute() - - # Save the host for later use - data["ssh_hook_host"] = data["ssh_hook"] + # Save the host for later use + data["ssh_hook_host"] = data["ssh_hook"] + elif ssh_hook.kind == "filter": + # Fan out: run this task on EVERY matching host via Airflow dynamic task + # mapping (rendered as `.partial(...).expand(remote_host=[...])`). + hosts = ssh_hook.execute() + if isinstance(hosts, Host): + hosts = [hosts] + # A single base hook supplies the shared credentials while only + # `remote_host` is mapped, so all matching hosts must share credentials. + cls._assert_uniform_credentials(hosts) + if data.get("remote_host"): + raise ValueError("Cannot set 'remote_host' together with a host filter; remote_host is mapped per-host during expansion.") + data["ssh_hook_hosts"] = hosts + # Base hook carries the shared credentials; remote_host is overridden per mapped instance. + data["ssh_hook"] = hosts[0] + data["ssh_hook_host"] = hosts[0] + else: + raise ValueError("BalancerHostQueryConfiguration must be of kind 'select' or 'filter'") else: # If it's a Host instance, set it for later use data["ssh_hook_host"] = ssh_hook - # Override pool from host if not otherwise set - if data["ssh_hook_host"].pool and not data.get("pool"): + # Override pool from host if not otherwise set. + # Skipped when fanning out: per-host pools cannot be mapped over a single task. + if not data.get("ssh_hook_hosts") and data["ssh_hook_host"].pool and not data.get("pool"): data["pool"] = data["ssh_hook"].pool if isinstance(ssh_hook, str): @@ -183,6 +219,65 @@ def validate_operator(cls, v: type) -> type: raise TypeError(f"operator must be 'airflow.providers.ssh.operators.ssh.SSHOperator', got: {v}") return v + @staticmethod + def _remote_host_name(host: Host) -> str: + # Mirror Host.hook(use_local=True): append ".local" to an unqualified hostname. + return f"{host.name}.local" if host.name.count(".") == 0 else host.name + + def instantiate(self, **kwargs): + # When ssh_hook resolves to a host filter, fan out across every matching host using Airflow + # dynamic task mapping. A single base hook supplies the shared credentials and only + # `remote_host` is mapped, producing one (independently retryable) task instance per host. + if not self.ssh_hook_hosts: + return super().instantiate(**kwargs) + + from ..airflow import Pool as AirflowPool + from ..utils import Pool + + args = {**self.model_dump(exclude_unset=True, exclude=["type_", "operator", "dependencies"]), **kwargs} + if "pool" in args: + if isinstance(args["pool"], dict): + args["pool"] = self.pool + if isinstance(args["pool"], (AirflowPool, Pool)): + args["pool"] = args["pool"].pool + # `remote_host` is supplied per mapped task instance, not in the partial. + args.pop("remote_host", None) + remote_hosts = [self._remote_host_name(host) for host in self.ssh_hook_hosts] + return self.operator.partial(**args).expand(remote_host=remote_hosts) + + def render(self, raw: bool = False, dag_from_context: bool = False, airflow_major_version: int = 2, **kwargs): + # Default (single-host) rendering is unchanged. + if not self.ssh_hook_hosts: + return super().render(raw=raw, dag_from_context=dag_from_context, airflow_major_version=airflow_major_version, **kwargs) + + # Build the regular `Operator(**args)` call, then rewrite it as + # `Operator.partial(**args).expand(remote_host=[...])`. + imports, globals_, call = super().render(raw=True, dag_from_context=dag_from_context, airflow_major_version=airflow_major_version, **kwargs) + + call.keywords = [keyword for keyword in call.keywords if keyword.arg != "remote_host"] + partial_call = ast.Call( + func=ast.Attribute(value=call.func, attr="partial", ctx=ast.Load()), + args=call.args, + keywords=call.keywords, + ) + remote_hosts = [self._remote_host_name(host) for host in self.ssh_hook_hosts] + call = ast.Call( + func=ast.Attribute(value=partial_call, attr="expand", ctx=ast.Load()), + args=[], + keywords=[ + ast.keyword( + arg="remote_host", + value=ast.List(elts=[ast.Constant(value=name) for name in remote_hosts], ctx=ast.Load()), + ) + ], + ) + + if not raw: + imports = [ast.unparse(i) for i in imports] + globals_ = [ast.unparse(i) for i in globals_] + call = ast.unparse(call) + return imports, globals_, call + # Alias SSHOperator = SSHTask diff --git a/airflow_pydantic/tests/conftest.py b/airflow_pydantic/tests/conftest.py index 88628eee..2cd2270b 100644 --- a/airflow_pydantic/tests/conftest.py +++ b/airflow_pydantic/tests/conftest.py @@ -236,6 +236,42 @@ def ssh_operator_balancer_template(ssh_operator_balancer): ) +@fixture +def balancer_multi(): + with pools(): + return BalancerConfiguration( + hosts=[ + Host( + name="test_host_1", + username="test_user", + password=Variable(key="VAR", deserialize_json=True), + tags=["worker"], + ), + Host( + name="test_host_2", + username="test_user", + password=Variable(key="VAR", deserialize_json=True), + tags=["worker"], + ), + ] + ) + + +@fixture +def ssh_operator_balancer_filter(ssh_operator_args, balancer_multi): + ssh_operator_args.command = BashCommands(commands=["test1", "test2"], login=True, cwd="/tmp", env={"var": "{{ ti.blerg }}"}) + with pools(), variables({"user": "test", "password": "password"}): + return SSHTask( + task_id="test-ssh-operator", + **ssh_operator_args.model_dump(exclude_unset=True, exclude=["ssh_hook", "pool"]), + ssh_hook=BalancerHostQueryConfiguration( + kind="filter", + balancer=balancer_multi, + tag="worker", + ), + ) + + @fixture def time_sensor_args(): return TimeSensorArgs( diff --git a/airflow_pydantic/tests/core/test_render.py b/airflow_pydantic/tests/core/test_render.py index fcf62404..5e2d4d33 100644 --- a/airflow_pydantic/tests/core/test_render.py +++ b/airflow_pydantic/tests/core/test_render.py @@ -87,6 +87,24 @@ def test_render_operator_ssh_host_variable_from_template(self, ssh_operator_bala == "SSHOperator(pool=Pool.create_or_update_pool(name='test_host', slots=8, description='Balancer pool for host(test_host)', include_deferred=False).pool, do_xcom_push=True, ssh_hook=SSHHook(remote_host='test_host.local', username='test_user', password=AirflowVariable.get('VAR', deserialize_json=True)['password']), ssh_conn_id='test', command='bash -lc \\'export var=\"{{ ti.blerg }}\"\\ncd /tmp\\nset -ex\\ntest1\\ntest2\\'', cmd_timeout=10, environment={'test': 'test'}, get_pty=True, task_id='test-ssh-operator')" ) + def test_render_operator_ssh_host_filter(self, ssh_operator_balancer_filter): + # A host filter (kind="filter") fans the task out across every matching host via + # Airflow dynamic task mapping: `Operator.partial(...).expand(remote_host=[...])`. + imports, globals_, task = ssh_operator_balancer_filter.render() + assert "from airflow.providers.ssh.operators.ssh import SSHOperator" in imports + assert "from airflow.providers.ssh.hooks.ssh import SSHHook" in imports + assert "from airflow.models.variable import Variable as AirflowVariable" in imports + assert globals_ == [] + assert task.startswith("SSHOperator.partial(") + assert task.endswith(".expand(remote_host=['test_host_1.local', 'test_host_2.local'])") + # A single base hook (the first matching host) carries the shared credentials. + assert ( + "ssh_hook=SSHHook(remote_host='test_host_1.local', username='test_user', " + "password=AirflowVariable.get('VAR', deserialize_json=True)['password'])" + ) in task + # Per-host pools cannot be mapped, so no pool is emitted for the fan-out. + assert "pool=" not in task + def test_render_dag(self, dag): assert isinstance(dag, Dag) assert ( diff --git a/airflow_pydantic/tests/test_operators.py b/airflow_pydantic/tests/test_operators.py index 51a50671..4599f85c 100644 --- a/airflow_pydantic/tests/test_operators.py +++ b/airflow_pydantic/tests/test_operators.py @@ -51,6 +51,46 @@ def test_ssh_operator(self, ssh_operator): return pytest.skip("Airflow not installed") ssh_operator.instantiate() + def test_ssh_operator_host_filter_instantiate(self, ssh_operator_balancer_filter, dag_args): + if _airflow_3() is None: + return pytest.skip("Airflow not installed") + from airflow.models.mappedoperator import MappedOperator + + from airflow_pydantic import Dag + from airflow_pydantic.testing import pools, variables + + with pools(), variables({"user": "test", "password": "password"}): + d = Dag( + dag_id="filter-dag", + schedule=None, + **dag_args.model_dump(exclude_unset=True, exclude={"schedule"}), + tasks={"check": ssh_operator_balancer_filter}, + ) + dag_instance = d.instantiate() + # A host filter produces a single mapped task that fans out across every matching host. + task = dag_instance.get_task("test-ssh-operator") + assert isinstance(task, MappedOperator) + + def test_ssh_operator_host_filter_requires_uniform_credentials(self): + from pydantic import ValidationError + + from airflow_pydantic import BalancerConfiguration, BalancerHostQueryConfiguration, Host, SSHTask, Variable + from airflow_pydantic.testing import pools, variables + + with pools(), variables({"user": "test", "password": "password"}): + balancer = BalancerConfiguration( + hosts=[ + Host(name="host_a", username="user_a", password=Variable(key="VAR", deserialize_json=True), tags=["worker"]), + Host(name="host_b", username="user_b", password=Variable(key="VAR", deserialize_json=True), tags=["worker"]), + ] + ) + with pytest.raises((ValidationError, ValueError), match="differing credentials"): + SSHTask( + task_id="test-ssh-operator", + command="echo hi", + ssh_hook=BalancerHostQueryConfiguration(kind="filter", balancer=balancer, tag="worker"), + ) + def test_bash_sensor_args(self, bash_sensor_args): o = bash_sensor_args