Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 107 additions & 12 deletions airflow_pydantic/operators/ssh.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
from logging import getLogger
from types import FunctionType, MethodType
from typing import Any
Expand All @@ -7,7 +8,7 @@
from ..airflow import SSHHook as BaseSSHHook
from ..core import Task, TaskArgs
from ..extras import BalancerHostQueryConfiguration, Host
from ..utils import BashCommands, CallablePath, ImportPath, SSHHook, get_import_path
from ..utils import BashCommands, CallablePath, ImportPath, SSHHook, Variable, get_import_path

__all__ = (
"SSHOperator",
Expand All @@ -27,6 +28,10 @@ class SSHTaskArgs(TaskArgs):
)
ssh_hook_host: Host | None = Field(default=None, exclude=True)

# Hosts to fan out over when ssh_hook is a host filter (BalancerHostQueryConfiguration kind="filter").
# Rendered / instantiated as `Operator.partial(...).expand(remote_host=[...])`.
ssh_hook_hosts: list[Host] | None = Field(default=None, exclude=True)

# Track source of hook in order to defer
ssh_hook_foo: CallablePath | None = Field(default=None, exclude=True)
ssh_hook_external: bool | None = Field(
Expand Down Expand Up @@ -75,29 +80,60 @@ def validate_command(cls, v: Any) -> Any:
else:
raise ValueError("command must be a string, list of strings, or a BashCommands model")

@staticmethod
def _assert_uniform_credentials(hosts: list[Host]) -> None:
# All hosts in a fan-out are reached through a single base SSHHook (only `remote_host` is
# mapped per task instance), so they must agree on username / password / key_file.
def _creds(host: Host):
password = host.password
password_key = password.key if isinstance(password, Variable) else password
return (host.username, password_key, host.key_file)

if len({_creds(host) for host in hosts}) > 1:
names = ", ".join(sorted(host.name for host in hosts))
raise ValueError(
f"Host filter for SSH expansion matched hosts with differing credentials ({names}); "
"all matching hosts must share username/password/key_file."
)

@model_validator(mode="before")
@classmethod
def _extract_host_from_ssh_hook(cls, data: Any) -> Any:
if isinstance(data, dict) and "ssh_hook" in data:
ssh_hook = data["ssh_hook"]
if isinstance(ssh_hook, (BalancerHostQueryConfiguration, Host)):
if isinstance(ssh_hook, BalancerHostQueryConfiguration):
# Ensure that the BalancerHostQueryConfiguration is of kind 'select'
if not ssh_hook.kind == "select":
raise ValueError("BalancerHostQueryConfiguration must be of kind 'select'")
if ssh_hook.kind == "select":
# Execute the query to get the Host, just set as it will
# be handled by the field validator
data["ssh_hook"] = ssh_hook.execute()

# Execute the query to get the Host, just set as it will
# be handled by the field validator
data["ssh_hook"] = ssh_hook.execute()

# Save the host for later use
data["ssh_hook_host"] = data["ssh_hook"]
# Save the host for later use
data["ssh_hook_host"] = data["ssh_hook"]
elif ssh_hook.kind == "filter":
# Fan out: run this task on EVERY matching host via Airflow dynamic task
# mapping (rendered as `.partial(...).expand(remote_host=[...])`).
hosts = ssh_hook.execute()
if isinstance(hosts, Host):
hosts = [hosts]
# A single base hook supplies the shared credentials while only
# `remote_host` is mapped, so all matching hosts must share credentials.
cls._assert_uniform_credentials(hosts)
if data.get("remote_host"):
raise ValueError("Cannot set 'remote_host' together with a host filter; remote_host is mapped per-host during expansion.")
data["ssh_hook_hosts"] = hosts
# Base hook carries the shared credentials; remote_host is overridden per mapped instance.
data["ssh_hook"] = hosts[0]
data["ssh_hook_host"] = hosts[0]
else:
raise ValueError("BalancerHostQueryConfiguration must be of kind 'select' or 'filter'")
else:
# If it's a Host instance, set it for later use
data["ssh_hook_host"] = ssh_hook

# Override pool from host if not otherwise set
if data["ssh_hook_host"].pool and not data.get("pool"):
# Override pool from host if not otherwise set.
# Skipped when fanning out: per-host pools cannot be mapped over a single task.
if not data.get("ssh_hook_hosts") and data["ssh_hook_host"].pool and not data.get("pool"):
data["pool"] = data["ssh_hook"].pool

if isinstance(ssh_hook, str):
Expand Down Expand Up @@ -183,6 +219,65 @@ def validate_operator(cls, v: type) -> type:
raise TypeError(f"operator must be 'airflow.providers.ssh.operators.ssh.SSHOperator', got: {v}")
return v

@staticmethod
def _remote_host_name(host: Host) -> str:
# Mirror Host.hook(use_local=True): append ".local" to an unqualified hostname.
return f"{host.name}.local" if host.name.count(".") == 0 else host.name

def instantiate(self, **kwargs):
# When ssh_hook resolves to a host filter, fan out across every matching host using Airflow
# dynamic task mapping. A single base hook supplies the shared credentials and only
# `remote_host` is mapped, producing one (independently retryable) task instance per host.
if not self.ssh_hook_hosts:
return super().instantiate(**kwargs)

from ..airflow import Pool as AirflowPool
from ..utils import Pool

args = {**self.model_dump(exclude_unset=True, exclude=["type_", "operator", "dependencies"]), **kwargs}
if "pool" in args:
if isinstance(args["pool"], dict):
args["pool"] = self.pool
if isinstance(args["pool"], (AirflowPool, Pool)):
args["pool"] = args["pool"].pool
# `remote_host` is supplied per mapped task instance, not in the partial.
args.pop("remote_host", None)
remote_hosts = [self._remote_host_name(host) for host in self.ssh_hook_hosts]
return self.operator.partial(**args).expand(remote_host=remote_hosts)

def render(self, raw: bool = False, dag_from_context: bool = False, airflow_major_version: int = 2, **kwargs):
# Default (single-host) rendering is unchanged.
if not self.ssh_hook_hosts:
return super().render(raw=raw, dag_from_context=dag_from_context, airflow_major_version=airflow_major_version, **kwargs)

# Build the regular `Operator(**args)` call, then rewrite it as
# `Operator.partial(**args).expand(remote_host=[...])`.
imports, globals_, call = super().render(raw=True, dag_from_context=dag_from_context, airflow_major_version=airflow_major_version, **kwargs)

call.keywords = [keyword for keyword in call.keywords if keyword.arg != "remote_host"]
partial_call = ast.Call(
func=ast.Attribute(value=call.func, attr="partial", ctx=ast.Load()),
args=call.args,
keywords=call.keywords,
)
remote_hosts = [self._remote_host_name(host) for host in self.ssh_hook_hosts]
call = ast.Call(
func=ast.Attribute(value=partial_call, attr="expand", ctx=ast.Load()),
args=[],
keywords=[
ast.keyword(
arg="remote_host",
value=ast.List(elts=[ast.Constant(value=name) for name in remote_hosts], ctx=ast.Load()),
)
],
)

if not raw:
imports = [ast.unparse(i) for i in imports]
globals_ = [ast.unparse(i) for i in globals_]
call = ast.unparse(call)
return imports, globals_, call


# Alias
SSHOperator = SSHTask
36 changes: 36 additions & 0 deletions airflow_pydantic/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,42 @@ def ssh_operator_balancer_template(ssh_operator_balancer):
)


@fixture
def balancer_multi():
with pools():
return BalancerConfiguration(
hosts=[
Host(
name="test_host_1",
username="test_user",
password=Variable(key="VAR", deserialize_json=True),
tags=["worker"],
),
Host(
name="test_host_2",
username="test_user",
password=Variable(key="VAR", deserialize_json=True),
tags=["worker"],
),
]
)


@fixture
def ssh_operator_balancer_filter(ssh_operator_args, balancer_multi):
ssh_operator_args.command = BashCommands(commands=["test1", "test2"], login=True, cwd="/tmp", env={"var": "{{ ti.blerg }}"})
with pools(), variables({"user": "test", "password": "password"}):
return SSHTask(
task_id="test-ssh-operator",
**ssh_operator_args.model_dump(exclude_unset=True, exclude=["ssh_hook", "pool"]),
ssh_hook=BalancerHostQueryConfiguration(
kind="filter",
balancer=balancer_multi,
tag="worker",
),
)


@fixture
def time_sensor_args():
return TimeSensorArgs(
Expand Down
18 changes: 18 additions & 0 deletions airflow_pydantic/tests/core/test_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,24 @@ def test_render_operator_ssh_host_variable_from_template(self, ssh_operator_bala
== "SSHOperator(pool=Pool.create_or_update_pool(name='test_host', slots=8, description='Balancer pool for host(test_host)', include_deferred=False).pool, do_xcom_push=True, ssh_hook=SSHHook(remote_host='test_host.local', username='test_user', password=AirflowVariable.get('VAR', deserialize_json=True)['password']), ssh_conn_id='test', command='bash -lc \\'export var=\"{{ ti.blerg }}\"\\ncd /tmp\\nset -ex\\ntest1\\ntest2\\'', cmd_timeout=10, environment={'test': 'test'}, get_pty=True, task_id='test-ssh-operator')"
)

def test_render_operator_ssh_host_filter(self, ssh_operator_balancer_filter):
# A host filter (kind="filter") fans the task out across every matching host via
# Airflow dynamic task mapping: `Operator.partial(...).expand(remote_host=[...])`.
imports, globals_, task = ssh_operator_balancer_filter.render()
assert "from airflow.providers.ssh.operators.ssh import SSHOperator" in imports
assert "from airflow.providers.ssh.hooks.ssh import SSHHook" in imports
assert "from airflow.models.variable import Variable as AirflowVariable" in imports
assert globals_ == []
assert task.startswith("SSHOperator.partial(")
assert task.endswith(".expand(remote_host=['test_host_1.local', 'test_host_2.local'])")
# A single base hook (the first matching host) carries the shared credentials.
assert (
"ssh_hook=SSHHook(remote_host='test_host_1.local', username='test_user', "
"password=AirflowVariable.get('VAR', deserialize_json=True)['password'])"
) in task
# Per-host pools cannot be mapped, so no pool is emitted for the fan-out.
assert "pool=" not in task

def test_render_dag(self, dag):
assert isinstance(dag, Dag)
assert (
Expand Down
40 changes: 40 additions & 0 deletions airflow_pydantic/tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,46 @@ def test_ssh_operator(self, ssh_operator):
return pytest.skip("Airflow not installed")
ssh_operator.instantiate()

def test_ssh_operator_host_filter_instantiate(self, ssh_operator_balancer_filter, dag_args):
if _airflow_3() is None:
return pytest.skip("Airflow not installed")
from airflow.models.mappedoperator import MappedOperator

from airflow_pydantic import Dag
from airflow_pydantic.testing import pools, variables

with pools(), variables({"user": "test", "password": "password"}):
d = Dag(
dag_id="filter-dag",
schedule=None,
**dag_args.model_dump(exclude_unset=True, exclude={"schedule"}),
tasks={"check": ssh_operator_balancer_filter},
)
dag_instance = d.instantiate()
# A host filter produces a single mapped task that fans out across every matching host.
task = dag_instance.get_task("test-ssh-operator")
assert isinstance(task, MappedOperator)

def test_ssh_operator_host_filter_requires_uniform_credentials(self):
from pydantic import ValidationError

from airflow_pydantic import BalancerConfiguration, BalancerHostQueryConfiguration, Host, SSHTask, Variable
from airflow_pydantic.testing import pools, variables

with pools(), variables({"user": "test", "password": "password"}):
balancer = BalancerConfiguration(
hosts=[
Host(name="host_a", username="user_a", password=Variable(key="VAR", deserialize_json=True), tags=["worker"]),
Host(name="host_b", username="user_b", password=Variable(key="VAR", deserialize_json=True), tags=["worker"]),
]
)
with pytest.raises((ValidationError, ValueError), match="differing credentials"):
SSHTask(
task_id="test-ssh-operator",
command="echo hi",
ssh_hook=BalancerHostQueryConfiguration(kind="filter", balancer=balancer, tag="worker"),
)

def test_bash_sensor_args(self, bash_sensor_args):
o = bash_sensor_args

Expand Down
Loading