Skip to content

Commit

Permalink
Add More Typing to the dbt.task Module (#10622)
Browse files Browse the repository at this point in the history
* Add typing to task module.

* More typing in the task module

* Still more types for task module
  • Loading branch information
peterallenwebb authored Aug 28, 2024
1 parent e1fa461 commit f7d21e0
Show file tree
Hide file tree
Showing 16 changed files with 107 additions and 92 deletions.
2 changes: 2 additions & 0 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,8 @@ def compile_node(
the node's raw_code into compiled_code, and then calls the
recursive method to "prepend" the ctes.
"""
# REVIEW: UnitTestDefinition shouldn't be possible here because of the
# type of node, and it is likewise an invalid return type.
if isinstance(node, UnitTestDefinition):
return node

Expand Down
2 changes: 1 addition & 1 deletion core/dbt/plugins/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ class PluginNodes:
def add_model(self, model_args: ModelNodeArgs) -> None:
self.models[model_args.unique_id] = model_args

def update(self, other: "PluginNodes"):
def update(self, other: "PluginNodes") -> None:
self.models.update(other.models)
42 changes: 18 additions & 24 deletions core/dbt/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,10 @@
from dbt.task.printer import print_run_result_error
from dbt_common.events.contextvars import get_node_info
from dbt_common.events.functions import fire_event
from dbt_common.exceptions import (
CompilationError,
DbtInternalError,
DbtRuntimeError,
NotImplementedError,
)
from dbt_common.exceptions import DbtInternalError, DbtRuntimeError, NotImplementedError


def read_profiles(profiles_dir=None):
def read_profiles(profiles_dir: Optional[str] = None) -> Dict[str, Any]:
"""This is only used for some error handling"""
if profiles_dir is None:
profiles_dir = get_flags().PROFILES_DIR
Expand Down Expand Up @@ -123,7 +118,7 @@ def __init__(
self.manifest = manifest
self.compiler = Compiler(self.config)

def compile_manifest(self):
def compile_manifest(self) -> None:
if self.manifest is None:
raise DbtInternalError("compile_manifest called before manifest was loaded")

Expand Down Expand Up @@ -165,7 +160,7 @@ def __init__(self, node) -> None:


class BaseRunner(metaclass=ABCMeta):
def __init__(self, config, adapter, node, node_index, num_nodes) -> None:
def __init__(self, config, adapter, node, node_index: int, num_nodes: int) -> None:
self.config = config
self.compiler = Compiler(config)
self.adapter = adapter
Expand Down Expand Up @@ -272,7 +267,7 @@ def from_run_result(self, result, start_time, timing_info):
failures=result.failures,
)

def compile_and_execute(self, manifest, ctx):
def compile_and_execute(self, manifest: Manifest, ctx: ExecutionContext):
result = None
with (
self.adapter.connection_named(self.node.unique_id, self.node)
Expand Down Expand Up @@ -305,7 +300,7 @@ def compile_and_execute(self, manifest, ctx):

return result

def _handle_catchable_exception(self, e, ctx):
def _handle_catchable_exception(self, e: DbtRuntimeError, ctx: ExecutionContext) -> str:
if e.node is None:
e.add_node(ctx.node)

Expand All @@ -316,15 +311,15 @@ def _handle_catchable_exception(self, e, ctx):
)
return str(e)

def _handle_internal_exception(self, e, ctx):
def _handle_internal_exception(self, e: DbtInternalError, ctx: ExecutionContext) -> str:
fire_event(
InternalErrorOnRun(
build_path=self._node_build_path(), exc=str(e), node_info=get_node_info()
)
)
return str(e)

def _handle_generic_exception(self, e, ctx):
def _handle_generic_exception(self, e: Exception, ctx: ExecutionContext) -> str:
fire_event(
GenericExceptionOnRun(
build_path=self._node_build_path(),
Expand All @@ -337,17 +332,16 @@ def _handle_generic_exception(self, e, ctx):

return str(e)

def handle_exception(self, e, ctx):
catchable_errors = (CompilationError, DbtRuntimeError)
if isinstance(e, catchable_errors):
def handle_exception(self, e: Exception, ctx: ExecutionContext) -> str:
if isinstance(e, DbtRuntimeError):
error = self._handle_catchable_exception(e, ctx)
elif isinstance(e, DbtInternalError):
error = self._handle_internal_exception(e, ctx)
else:
error = self._handle_generic_exception(e, ctx)
return error

def safe_run(self, manifest):
def safe_run(self, manifest: Manifest):
started = time.time()
ctx = ExecutionContext(self.node)
error = None
Expand Down Expand Up @@ -394,19 +388,19 @@ def _safe_release_connection(self):

return None

def before_execute(self):
raise NotImplementedError()
def before_execute(self) -> None:
raise NotImplementedError("before_execute is not implemented")

def execute(self, compiled_node, manifest):
raise NotImplementedError()
raise NotImplementedError("execute is not implemented")

def run(self, compiled_node, manifest):
return self.execute(compiled_node, manifest)

def after_execute(self, result):
raise NotImplementedError()
def after_execute(self, result) -> None:
raise NotImplementedError("after_execute is not implemented")

def _skip_caused_by_ephemeral_failure(self):
def _skip_caused_by_ephemeral_failure(self) -> bool:
if self.skip_cause is None or self.skip_cause.node is None:
return False
return self.skip_cause.node.is_ephemeral_model
Expand Down Expand Up @@ -461,7 +455,7 @@ def on_skip(self):
node_result = RunResult.from_node(self.node, RunStatus.Skipped, error_message)
return node_result

def do_skip(self, cause=None):
def do_skip(self, cause=None) -> None:
self.skip = True
self.skip_cause = cause

Expand Down
14 changes: 7 additions & 7 deletions core/dbt/task/build.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import threading
from typing import Dict, List, Set
from typing import Dict, List, Optional, Set, Type

from dbt.artifacts.schemas.results import NodeStatus, RunStatus
from dbt.artifacts.schemas.run import RunResult
Expand All @@ -24,16 +24,16 @@
class SavedQueryRunner(BaseRunner):
# Stub. No-op Runner for Saved Queries, which require MetricFlow for execution.
@property
def description(self):
def description(self) -> str:
return f"saved query {self.node.name}"

def before_execute(self):
def before_execute(self) -> None:
pass

def compile(self, manifest):
def compile(self, manifest: Manifest):
return self.node

def after_execute(self, result):
def after_execute(self, result) -> None:
fire_event(
LogNodeNoOpResult(
description=self.description,
Expand Down Expand Up @@ -83,7 +83,7 @@ def __init__(self, args: Flags, config: RuntimeConfig, manifest: Manifest) -> No
self.selected_unit_tests: Set = set()
self.model_to_unit_test_map: Dict[str, List] = {}

def resource_types(self, no_unit_tests=False):
def resource_types(self, no_unit_tests: bool = False) -> List[NodeType]:
resource_types = resource_types_from_args(
self.args, set(self.ALL_RESOURCE_VALUES), set(self.ALL_RESOURCE_VALUES)
)
Expand Down Expand Up @@ -210,7 +210,7 @@ def get_node_selector(self, no_unit_tests=False) -> ResourceTypeSelector:
resource_types=resource_types,
)

def get_runner_type(self, node):
def get_runner_type(self, node) -> Optional[Type[BaseRunner]]:
return self.RUNNER_MAP.get(node.resource_type)

# Special build compile_manifest method to pass add_test_edges to the compiler
Expand Down
17 changes: 9 additions & 8 deletions core/dbt/task/clone.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import threading
from typing import AbstractSet, Any, Iterable, List, Optional, Set
from typing import AbstractSet, Any, Collection, Iterable, List, Optional, Set, Type

from dbt.adapters.base import BaseRelation
from dbt.artifacts.resources.types import NodeType
from dbt.artifacts.schemas.run import RunResult, RunStatus
from dbt.clients.jinja import MacroGenerator
from dbt.context.providers import generate_runtime_model_context
Expand All @@ -16,10 +17,10 @@


class CloneRunner(BaseRunner):
def before_execute(self):
def before_execute(self) -> None:
pass

def after_execute(self, result):
def after_execute(self, result) -> None:
pass

def _build_run_model_result(self, model, context):
Expand All @@ -44,7 +45,7 @@ def _build_run_model_result(self, model, context):
failures=None,
)

def compile(self, manifest):
def compile(self, manifest: Manifest):
# no-op
return self.node

Expand Down Expand Up @@ -91,7 +92,7 @@ def execute(self, model, manifest):


class CloneTask(GraphRunnableTask):
def raise_on_first_error(self):
def raise_on_first_error(self) -> bool:
return False

def get_run_mode(self) -> GraphRunnableMode:
Expand Down Expand Up @@ -133,8 +134,8 @@ def before_run(self, adapter, selected_uids: AbstractSet[str]):
self.populate_adapter_cache(adapter, schemas_to_cache)

@property
def resource_types(self):
resource_types = resource_types_from_args(
def resource_types(self) -> List[NodeType]:
resource_types: Collection[NodeType] = resource_types_from_args(
self.args, set(REFABLE_NODE_TYPES), set(REFABLE_NODE_TYPES)
)

Expand All @@ -154,5 +155,5 @@ def get_node_selector(self) -> ResourceTypeSelector:
resource_types=resource_types,
)

def get_runner_type(self, _):
def get_runner_type(self, _) -> Optional[Type[BaseRunner]]:
return CloneRunner
18 changes: 10 additions & 8 deletions core/dbt/task/compile.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import threading
from typing import Optional, Type

from dbt.artifacts.schemas.run import RunResult, RunStatus
from dbt.contracts.graph.manifest import Manifest
from dbt.events.types import CompiledNode, ParseInlineNodeError
from dbt.graph import ResourceTypeSelector
from dbt.node_types import EXECUTABLE_NODE_TYPES, NodeType
Expand All @@ -17,10 +19,10 @@


class CompileRunner(BaseRunner):
def before_execute(self):
def before_execute(self) -> None:
pass

def after_execute(self, result):
def after_execute(self, result) -> None:
pass

def execute(self, compiled_node, manifest):
Expand All @@ -35,7 +37,7 @@ def execute(self, compiled_node, manifest):
failures=None,
)

def compile(self, manifest):
def compile(self, manifest: Manifest):
return self.compiler.compile_node(self.node, manifest, {})


Expand All @@ -44,7 +46,7 @@ class CompileTask(GraphRunnableTask):
# it should be removed before the task is complete
_inline_node_id = None

def raise_on_first_error(self):
def raise_on_first_error(self) -> bool:
return True

def get_node_selector(self) -> ResourceTypeSelector:
Expand All @@ -62,10 +64,10 @@ def get_node_selector(self) -> ResourceTypeSelector:
resource_types=resource_types,
)

def get_runner_type(self, _):
def get_runner_type(self, _) -> Optional[Type[BaseRunner]]:
return CompileRunner

def task_end_messages(self, results):
def task_end_messages(self, results) -> None:
is_inline = bool(getattr(self.args, "inline", None))
output_format = getattr(self.args, "output", "text")

Expand Down Expand Up @@ -127,14 +129,14 @@ def _runtime_initialize(self):
raise DbtException("Error parsing inline query")
super()._runtime_initialize()

def after_run(self, adapter, results):
def after_run(self, adapter, results) -> None:
# remove inline node from manifest
if self._inline_node_id:
self.manifest.nodes.pop(self._inline_node_id)
self._inline_node_id = None
super().after_run(adapter, results)

def _handle_result(self, result):
def _handle_result(self, result) -> None:
super()._handle_result(result)

if (
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/task/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def test_connection(self) -> SubtaskStatus:
return status

@classmethod
def validate_connection(cls, target_dict):
def validate_connection(cls, target_dict) -> None:
"""Validate a connection dictionary. On error, raises a DbtConfigError."""
target_name = "test"
# make a fake profile that we can parse
Expand Down
Loading

0 comments on commit f7d21e0

Please sign in to comment.