diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index b9cf7e22a08..f5bcc578541 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -602,15 +602,15 @@ def _has_relation(self, model) -> bool: ) return relation is not None - def _should_run_in_parallel( - self, - relation_exists: bool, - ) -> bool: + def should_run_in_parallel(self) -> bool: if not self.adapter.supports(Capability.MicrobatchConcurrency): run_in_parallel = False - elif not relation_exists: + elif not self.relation_exists: # If the relation doesn't exist, we can't run in parallel run_in_parallel = False + elif self.batch_idx == 0 or self.batch_idx == len(self.batches) - 1: + # First and last batch don't run in parallel + run_in_parallel = False elif self.node.config.concurrent_batches is not None: # If the relation exists and the `concurrent_batches` config isn't None, use the config value run_in_parallel = self.node.config.concurrent_batches @@ -703,52 +703,79 @@ def handle_microbatch_model( runner: MicrobatchModelRunner, pool: ThreadPool, ) -> RunResult: - # Initial run computes batch metadata, unless model is skipped + # Initial run computes batch metadata result = self.call_runner(runner) + batches, node, relation_exists = runner.batches, runner.node, runner.relation_exists + + # Return early if model should be skipped, or there are no batches to execute if result.status == RunStatus.Skipped: return result + elif len(runner.batches) == 0: + return result batch_results: List[RunResult] = [] - - # Execute batches serially until a relation exists, at which point future batches are run in parallel - relation_exists = runner.relation_exists batch_idx = 0 - while batch_idx < len(runner.batches): - batch_runner = MicrobatchModelRunner( - self.config, runner.adapter, deepcopy(runner.node), self.run_count, self.num_nodes + # Run all batches except last batch, in parallel if possible + while batch_idx < len(runner.batches) - 1: + relation_exists = self._submit_batch( + node, 
relation_exists, batches, batch_idx, batch_results, pool ) - batch_runner.set_batch_idx(batch_idx) - batch_runner.set_relation_exists(relation_exists) - batch_runner.set_batches(runner.batches) - - if runner._should_run_in_parallel(relation_exists): - fire_event( - MicrobatchExecutionDebug( - msg=f"{batch_runner.describe_batch} is being run concurrently" - ) - ) - self._submit(pool, [batch_runner], batch_results.append) - else: - fire_event( - MicrobatchExecutionDebug( - msg=f"{batch_runner.describe_batch} is being run sequentially" - ) - ) - batch_results.append(self.call_runner(batch_runner)) - relation_exists = batch_runner.relation_exists - batch_idx += 1 - # Wait until all batches have completed - while len(batch_results) != len(runner.batches): + # Wait until all submitted batches have completed + while len(batch_results) != batch_idx: pass + # Final batch runs once all others complete to ensure post_hook runs at the end + self._submit_batch(node, relation_exists, batches, batch_idx, batch_results, pool) + # Finalize run: merge results, track model run, and print final result line runner.merge_batch_results(result, batch_results) track_model_run(runner.node_index, runner.num_nodes, result, adapter=runner.adapter) runner.print_result_line(result) return result + def _submit_batch( + self, + node: ModelNode, + relation_exists: bool, + batches: Dict[int, BatchType], + batch_idx: int, + batch_results: List[RunResult], + pool: ThreadPool, + ): + node_copy = deepcopy(node) + # Only run pre_hook(s) for first batch + if batch_idx != 0: + node_copy.config.pre_hook = [] + # Only run post_hook(s) for last batch + if batch_idx != len(batches) - 1: + node_copy.config.post_hook = [] + + batch_runner = self.get_runner(node_copy) + assert isinstance(batch_runner, MicrobatchModelRunner) + batch_runner.set_batch_idx(batch_idx) + batch_runner.set_relation_exists(relation_exists) + batch_runner.set_batches(batches) + + if batch_runner.should_run_in_parallel(): + fire_event( + 
MicrobatchExecutionDebug( + msg=f"{batch_runner.describe_batch} is being run concurrently" + ) + ) + self._submit(pool, [batch_runner], batch_results.append) + else: + fire_event( + MicrobatchExecutionDebug( + msg=f"{batch_runner.describe_batch} is being run sequentially" + ) + ) + batch_results.append(self.call_runner(batch_runner)) + relation_exists = batch_runner.relation_exists + + return relation_exists + def _hook_keyfunc(self, hook: HookNode) -> Tuple[str, Optional[int]]: package_name = hook.package_name if package_name == self.config.project_name: diff --git a/tests/functional/microbatch/test_microbatch.py b/tests/functional/microbatch/test_microbatch.py index e3acc415273..56537ad48cf 100644 --- a/tests/functional/microbatch/test_microbatch.py +++ b/tests/functional/microbatch/test_microbatch.py @@ -875,7 +875,7 @@ def batch_exc_catcher(self) -> EventCatcher: def test_microbatch( self, mocker: MockerFixture, project, batch_exc_catcher: EventCatcher ) -> None: - mocked_srip = mocker.patch("dbt.task.run.MicrobatchModelRunner._should_run_in_parallel") + mocked_srip = mocker.patch("dbt.task.run.MicrobatchModelRunner.should_run_in_parallel") # Should be run in parallel mocked_srip.return_value = True diff --git a/tests/unit/task/test_run.py b/tests/unit/task/test_run.py index b33e6f57ffe..b28ac505a7f 100644 --- a/tests/unit/task/test_run.py +++ b/tests/unit/task/test_run.py @@ -264,7 +264,7 @@ class Relation: (False, False, False, True, False), ], ) - def test__should_run_in_parallel( + def test_should_run_in_parallel( self, mocker: MockerFixture, model_runner: MicrobatchModelRunner, @@ -276,11 +276,13 @@ def test__should_run_in_parallel( ) -> None: model_runner.node._has_this = has_this model_runner.node.config = ModelConfig(concurrent_batches=concurrent_batches) + model_runner.set_relation_exists(has_relation) + mocked_supports = mocker.patch.object(model_runner.adapter, "supports") mocked_supports.return_value = adapter_microbatch_concurrency - # Assert 
result of _should_run_in_parallel - assert model_runner._should_run_in_parallel(has_relation) == expectation + # Assert result of should_run_in_parallel + assert model_runner.should_run_in_parallel() == expectation class TestRunTask: