Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EXPERIMENTAL] Add transfer nodes #21

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@ jobs:
matrix:
python: ['3.9', '3.10']
experimental: [false]
experimental-features: [false, true]
include:
- python: '3.11.0-alpha - 3.11.0'
experimental: true
experimental-features: false
- python: '3.11.0-alpha - 3.11.0'
experimental: true
experimental-features: true
continue-on-error: ${{ matrix.experimental }}
steps:
- name: Check out repository
Expand Down Expand Up @@ -44,6 +49,8 @@ jobs:
- name: Install library
run: poetry install --no-interaction
- name: Run tests
env:
AIRFLOW_DIAGRAMS__EXPERIMENTAL: ${{ matrix.experimental-features }}
run: poetry run pytest --cov-report=xml --cov=airflow_diagrams tests/
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v2
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ Then just call it like this:

_Examples of generated diagrams can be found in the [examples](examples) directory._

## 🧪 Experimental Features

* **Transfer Nodes**: Convert Airflow transfer operators into two tasks i.e. source & destination grouped in a cluster

## 🤔 How it Works

1. ℹ️ It connects, by using the official [Apache Airflow Python Client](https://github.com/apache/airflow-client-python), to your Airflow installation to retrieve all DAGs (in case you don't specify any `dag_id`) and all Tasks for the DAG(s).
Expand Down
2 changes: 2 additions & 0 deletions Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ tasks:
# 4. Render diagram
- cd examples && python3 dbt_diagrams.py
fake-dag:
env:
AIRFLOW_DIAGRAMS__EXPERIMENTAL: true
cmds:
# 1. Create fake dag
- python3 dev/airflow/airflow_dags_creator.py
Expand Down
7 changes: 6 additions & 1 deletion airflow_diagrams/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
"""Top-level package for airflow-diagrams."""
from importlib.metadata import version
from os import getcwd
from os import getcwd, getenv
from os.path import dirname, join, realpath

__app_name__ = "airflow-diagrams"
__version__ = version(__name__)
__location__ = realpath(join(getcwd(), dirname(__file__)))
__experimental__ = getenv("AIRFLOW_DIAGRAMS__EXPERIMENTAL", "False").lower() in (
"true",
"1",
"t",
)
63 changes: 63 additions & 0 deletions airflow_diagrams/airflow.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import re
from dataclasses import dataclass
from typing import Generator, Optional

from airflow_client.client.api.dag_api import DAGApi
from airflow_client.client.api_client import ApiClient, Configuration

from airflow_diagrams.class_ref import ClassRef
from airflow_diagrams.utils import experimental


@dataclass
Expand All @@ -16,6 +18,21 @@ class AirflowTask:
downstream_task_ids: list[str]
group_name: Optional[str]

def __hash__(self) -> int:
"""
Build a hash based on all attributes.

:returns: a hash of all attributes.
"""
return (
hash(self.class_ref)
^ hash(self.task_id)
^ hash(
downstream_task_id for downstream_task_id in self.downstream_task_ids
)
^ hash(self.group_name)
)

def __str__(self) -> str:
"""
Define pretty string reprenstation.
Expand Down Expand Up @@ -64,6 +81,52 @@ def get_tasks(self) -> list[AirflowTask]:
]


@experimental
def transfer_nodes(tasks: list[AirflowTask]) -> None:
"""
Transfer Nodes replaces an Airflow transfer task by two tasks i.e. source & destination clustered.

:param tasks: The tasks to modify.
"""
transfer_tasks = [
(task, match.groups())
for task in tasks
if task.class_ref.module_path
and ".transfers." in task.class_ref.module_path
and (match := re.search(r"(\w+)To(\w+)", task.class_ref.class_name))
]

for task, (source_class_name, destination_class_name) in transfer_tasks:
source_task_id = f"[SOURCE] {task.task_id}"
destination_task_id = f"[DESTINATION] {task.task_id}"
source = AirflowTask(
class_ref=ClassRef(
module_path=None, # We don't know if the original module_path belongs to source or destination
class_name=source_class_name,
),
task_id=source_task_id,
downstream_task_ids=[destination_task_id],
group_name=task.task_id,
)
destination = AirflowTask(
class_ref=ClassRef(
module_path=None, # We don't know if the original module_path belongs to source or destination
class_name=destination_class_name,
),
task_id=destination_task_id,
downstream_task_ids=task.downstream_task_ids,
group_name=task.task_id,
)
tasks.extend([source, destination])
tasks.remove(task)

transfer_task_ids = list(map(lambda task: task[0].task_id, transfer_tasks))
for t_idx, t in enumerate(tasks):
for dt_idx, dt_id in enumerate(t.downstream_task_ids):
if dt_id in transfer_task_ids:
tasks[t_idx].downstream_task_ids[dt_idx] = f"[SOURCE] {dt_id}"


class AirflowApiTree:
"""Retrieve Airflow Api information as a Tree."""

Expand Down
11 changes: 7 additions & 4 deletions airflow_diagrams/class_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class ClassRef:
"""A unique reference to a python class."""

module_path: str
module_path: Optional[str]
class_name: str

def __hash__(self) -> int:
Expand All @@ -29,7 +29,9 @@ def __str__(self) -> str:

:returns: the string representation of the class ref.
"""
return f"{self.module_path}.{self.class_name}"
if self.module_path:
return f"{self.module_path}.{self.class_name}"
return self.class_name

@staticmethod
def from_string(string: str) -> "ClassRef":
Expand All @@ -40,8 +42,9 @@ def from_string(string: str) -> "ClassRef":

:returns: the ClassRef object.
"""
module_path, class_name = string.rsplit(".", 1)
return ClassRef(module_path, class_name)
if "." in string:
return ClassRef(*string.rsplit(".", 1))
return ClassRef(module_path=None, class_name=string)


@dataclass
Expand Down
12 changes: 11 additions & 1 deletion airflow_diagrams/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typer import Argument, Exit, Option

from airflow_diagrams import __app_name__, __version__
from airflow_diagrams.airflow import retrieve_airflow_info
from airflow_diagrams.airflow import retrieve_airflow_info, transfer_nodes
from airflow_diagrams.class_ref import (
ClassRef,
ClassRefMatcher,
Expand Down Expand Up @@ -119,6 +119,11 @@ def generate( # dead: disable
exists=True,
dir_okay=False,
),
experimental: bool = Argument(
False,
envvar="AIRFLOW_DIAGRAMS__EXPERIMENTAL",
help="Enable experimental features by setting the variable to 'true'.",
),
) -> None:
if verbose:
rprint("💬Running with verbose output...")
Expand All @@ -130,6 +135,9 @@ def generate( # dead: disable
)
install(max_frames=0)

if experimental:
rprint("🧪Running with experimental features...")

mappings: dict = load_mappings(mapping_file) if mapping_file else {}

diagrams_class_refs: list[ClassRef] = retrieve_class_refs(
Expand Down Expand Up @@ -189,6 +197,8 @@ def generate( # dead: disable
rprint(f"[blue]🪄 Processing Airflow DAG {airflow_dag_id}...")
diagram_context = DiagramContext(airflow_dag_id)

transfer_nodes(airflow_tasks)

for airflow_task in airflow_tasks:
rprint(f"[blue dim] 🪄 Processing {airflow_task}...")
class_ref_matcher = ClassRefMatcher(
Expand Down
2 changes: 1 addition & 1 deletion airflow_diagrams/diagram.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ with Diagram("{{ name }}", show=False):
{% for node in nodes -%}
{% if node.cluster -%}
with {{ node.cluster.get_variable() }}:
{{ node.get_variable() }} = {{ node.class_name }}("{{ node.get_label(label_wrap) }}")
{{ node.get_variable() }} = {{ node.class_name }}()
{% else -%}
{{ node.get_variable() }} = {{ node.class_name }}("{{ node.get_label(label_wrap) }}")
{% endif -%}
Expand Down
14 changes: 13 additions & 1 deletion airflow_diagrams/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
import os
from pathlib import Path

import yaml

from airflow_diagrams import __location__
from airflow_diagrams import __experimental__, __location__


def load_abbreviations() -> dict:
Expand Down Expand Up @@ -31,3 +32,14 @@ def load_mappings(file: Path) -> dict:
"r",
) as mapping_yaml:
return yaml.safe_load(mapping_yaml)


def experimental(func):
"""Decorate experimental features."""

def wrapper(*args, **kwargs):
if __experimental__:
logging.debug("Calling experimental feature: %s", func.__name__)
func(*args, **kwargs)

return wrapper
Binary file modified assets/images/usage.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
78 changes: 77 additions & 1 deletion tests/test_airflow.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from airflow_diagrams.airflow import AirflowDag, AirflowTask
import pytest

from airflow_diagrams import __experimental__
from airflow_diagrams.airflow import AirflowDag, AirflowTask, transfer_nodes
from airflow_diagrams.class_ref import ClassRef


Expand Down Expand Up @@ -38,6 +41,79 @@ def test_airflow_dag_get_tasks(airflow_api_tree):
]


@pytest.mark.skipif(not __experimental__, reason="Transfer nodes are experimental.")
def test_transfer_nodes():
"""Test getting tasks from Airflow DAG"""
tasks = [
AirflowTask(
class_ref=ClassRef(
module_path=None,
class_name="Fizz",
),
task_id="test_task_0",
downstream_task_ids=["test_task_1"],
group_name=None,
),
AirflowTask(
class_ref=ClassRef(
module_path="foo.transfers.bar",
class_name="FooToBar",
),
task_id="test_task_1",
downstream_task_ids=["test_task_2"],
group_name=None,
),
AirflowTask(
class_ref=ClassRef(
module_path=None,
class_name="Fizz",
),
task_id="test_task_2",
downstream_task_ids=[],
group_name=None,
),
]
transfer_nodes(tasks)
assert set(tasks) == {
AirflowTask(
class_ref=ClassRef(
module_path=None,
class_name="Fizz",
),
task_id="test_task_0",
downstream_task_ids=["[SOURCE] test_task_1"],
group_name=None,
),
AirflowTask(
class_ref=ClassRef(
module_path=None,
class_name="Foo",
),
task_id="[SOURCE] test_task_1",
downstream_task_ids=["[DESTINATION] test_task_1"],
group_name="test_task_1",
),
AirflowTask(
class_ref=ClassRef(
module_path=None,
class_name="Bar",
),
task_id="[DESTINATION] test_task_1",
downstream_task_ids=["test_task_2"],
group_name="test_task_1",
),
AirflowTask(
class_ref=ClassRef(
module_path=None,
class_name="Fizz",
),
task_id="test_task_2",
downstream_task_ids=[],
group_name=None,
),
}


def test_airflow_api_tree_get_dags(airflow_api_tree):
"""Test getting dags from Airflow API Tree"""
airflow_api_tree.dag_api.get_dags.return_value = dict(
Expand Down
18 changes: 18 additions & 0 deletions tests/test_class_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ def class_ref():
)


@pytest.fixture()
def class_ref_without_module_path():
return ClassRef(
module_path=None,
class_name="ClassNameOperator",
)


@pytest.fixture()
def class_ref_matcher(class_ref):
return ClassRefMatcher(
Expand All @@ -39,6 +47,16 @@ def test_class_ref_str_and_from_string(class_ref):
assert ClassRef.from_string(str(class_ref)) == class_ref


def test_class_ref_str_and_from_string_without_module_path(
class_ref_without_module_path,
):
"""Test converting a ClassRef to str & creating a ClassRef from a string"""
assert (
ClassRef.from_string(str(class_ref_without_module_path))
== class_ref_without_module_path
)


def test_class_ref_matcher_match(class_ref_matcher):
"""Test matching"""
assert class_ref_matcher.match() == class_ref_matcher.choices[0]
Expand Down
Loading