From 0d61609acd817a4f276213e2805bbe772d9b0693 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Thu, 28 Nov 2024 16:01:43 -0500 Subject: [PATCH 1/4] microbatch: split out first and last batch to run in serial --- core/dbt/task/run.py | 85 ++++++++++++++++++++++++++++---------------- 1 file changed, 55 insertions(+), 30 deletions(-) diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index b9cf7e22a08..414009dc73e 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -703,52 +703,77 @@ def handle_microbatch_model( runner: MicrobatchModelRunner, pool: ThreadPool, ) -> RunResult: - # Initial run computes batch metadata, unless model is skipped + # Initial run computes batch metadata result = self.call_runner(runner) + # Return early if model should be skipped, or there are no batches to execute if result.status == RunStatus.Skipped: return result + elif len(runner.batches) == 0: + return result - batch_results: List[RunResult] = [] - - # Execute batches serially until a relation exists, at which point future batches are run in parallel relation_exists = runner.relation_exists + batch_results: List[RunResult] = [] batch_idx = 0 - while batch_idx < len(runner.batches): - batch_runner = MicrobatchModelRunner( - self.config, runner.adapter, deepcopy(runner.node), self.run_count, self.num_nodes - ) - batch_runner.set_batch_idx(batch_idx) - batch_runner.set_relation_exists(relation_exists) - batch_runner.set_batches(runner.batches) - - if runner._should_run_in_parallel(relation_exists): - fire_event( - MicrobatchExecutionDebug( - msg=f"{batch_runner.describe_batch} is being run concurrently" - ) - ) - self._submit(pool, [batch_runner], batch_results.append) - else: - fire_event( - MicrobatchExecutionDebug( - msg=f"{batch_runner.describe_batch} is being run sequentially" - ) - ) - batch_results.append(self.call_runner(batch_runner)) - relation_exists = batch_runner.relation_exists + # Run first batch runs in serial + relation_exists = self._submit_batch( + runner, relation_exists, batch_idx, batch_results, pool, parallel=False + ) + batch_idx += 1 + # Subsequent batches can be run in parallel + while batch_idx < len(runner.batches) - 1: + parallel = runner._should_run_in_parallel(relation_exists) + relation_exists = self._submit_batch( + runner, relation_exists, batch_idx, batch_results, pool, parallel + ) batch_idx += 1 - - # Wait until all batches have completed - while len(batch_results) != len(runner.batches): + # Wait until all submitted batches have completed + while len(batch_results) != batch_idx: pass + # Final batch runs in serial + self._submit_batch(runner, relation_exists, batch_idx, batch_results, pool, parallel=False) + runner.merge_batch_results(result, batch_results) track_model_run(runner.node_index, runner.num_nodes, result, adapter=runner.adapter) runner.print_result_line(result) return result + def _submit_batch( + self, + runner: MicrobatchModelRunner, + relation_exists: bool, + batch_idx: int, + batch_results: List[RunResult], + pool: ThreadPool, + parallel: bool, + ): + batch_runner = MicrobatchModelRunner( + self.config, runner.adapter, deepcopy(runner.node), self.run_count, self.num_nodes + ) + batch_runner.set_batch_idx(batch_idx) + batch_runner.set_relation_exists(relation_exists) + batch_runner.set_batches(runner.batches) + + if parallel: + fire_event( + MicrobatchExecutionDebug( + msg=f"{batch_runner.describe_batch} is being run concurrently" + ) + ) + self._submit(pool, [batch_runner], batch_results.append) + else: + fire_event( + MicrobatchExecutionDebug( + msg=f"{batch_runner.describe_batch} is being run sequentially" + ) + ) + batch_results.append(self.call_runner(batch_runner)) + relation_exists = batch_runner.relation_exists + + return relation_exists + def _hook_keyfunc(self, hook: HookNode) -> Tuple[str, Optional[int]]: package_name = hook.package_name if package_name == self.config.project_name: From bec5d57114cb04a08c37df438ac957f9d23a3cb5 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Thu, 28 Nov 2024 16:23:05 -0500 Subject: [PATCH 2/4] use Task.get_runner --- core/dbt/task/run.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index 414009dc73e..1d8300884e6 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -705,34 +705,41 @@ def handle_microbatch_model( ) -> RunResult: # Initial run computes batch metadata result = self.call_runner(runner) + batches = runner.batches + node = runner.node + relation_exists = runner.relation_exists + # Return early if model should be skipped, or there are no batches to execute if result.status == RunStatus.Skipped: return result elif len(runner.batches) == 0: return result - relation_exists = runner.relation_exists batch_results: List[RunResult] = [] batch_idx = 0 # Run first batch runs in serial relation_exists = self._submit_batch( - runner, relation_exists, batch_idx, batch_results, pool, parallel=False + node, relation_exists, batches, batch_idx, batch_results, pool, parallel=False ) batch_idx += 1 + # Subsequent batches can be run in parallel while batch_idx < len(runner.batches) - 1: parallel = runner._should_run_in_parallel(relation_exists) relation_exists = self._submit_batch( - runner, relation_exists, batch_idx, batch_results, pool, parallel + node, relation_exists, batches, batch_idx, batch_results, pool, parallel ) batch_idx += 1 + # Wait until all submitted batches have completed while len(batch_results) != batch_idx: pass # Final batch runs in serial - self._submit_batch(runner, relation_exists, batch_idx, batch_results, pool, parallel=False) + self._submit_batch( + node, relation_exists, batches, batch_idx, batch_results, pool, parallel=False + ) runner.merge_batch_results(result, batch_results) track_model_run(runner.node_index, runner.num_nodes, result, adapter=runner.adapter) @@ -742,19 +749,20 @@ def handle_microbatch_model( def _submit_batch( self, - runner: MicrobatchModelRunner, + node: ModelNode, relation_exists: bool, + batches: Dict[int, BatchType], batch_idx: int, batch_results: List[RunResult], pool: ThreadPool, parallel: bool, ): - batch_runner = MicrobatchModelRunner( - self.config, runner.adapter, deepcopy(runner.node), self.run_count, self.num_nodes - ) + batch_runner = self.get_runner(deepcopy(node)) + assert isinstance(batch_runner, MicrobatchModelRunner) + batch_runner.set_batch_idx(batch_idx) batch_runner.set_relation_exists(relation_exists) - batch_runner.set_batches(runner.batches) + batch_runner.set_batches(batches) if parallel: fire_event( From 32002ea69f075050c81d8842f0c770b04845a8e8 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Thu, 28 Nov 2024 16:43:30 -0500 Subject: [PATCH 3/4] only run pre_hook on first batch, post_hook on last batch --- core/dbt/task/run.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index 1d8300884e6..28b0b9734c7 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -757,9 +757,16 @@ def _submit_batch( pool: ThreadPool, parallel: bool, ): - batch_runner = self.get_runner(deepcopy(node)) + node_copy = deepcopy(node) + # Only run pre_hook(s) for first batch + if batch_idx != 0: + node_copy.config.pre_hook = [] + # Only run post_hook(s) for last batch + elif batch_idx != len(batches) - 1: + node_copy.config.post_hook = [] + + batch_runner = self.get_runner(node_copy) assert isinstance(batch_runner, MicrobatchModelRunner) - batch_runner.set_batch_idx(batch_idx) batch_runner.set_relation_exists(relation_exists) batch_runner.set_batches(batches) From 37dbd118d3043ce1ebe67d19ec121396275bdbf8 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Thu, 28 Nov 2024 17:06:51 -0500 Subject: [PATCH 4/4] refactor: internalize parallel to RunTask._submit_batch --- core/dbt/task/run.py | 39 +++++++------------ .../functional/microbatch/test_microbatch.py | 2 +- tests/unit/task/test_run.py | 8 ++-- 3 files changed, 19 insertions(+), 30 deletions(-) diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index 28b0b9734c7..f5bcc578541 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -602,15 +602,15 @@ def _has_relation(self, model) -> bool: ) return relation is not None - def _should_run_in_parallel( - self, - relation_exists: bool, - ) -> bool: + def should_run_in_parallel(self) -> bool: if not self.adapter.supports(Capability.MicrobatchConcurrency): run_in_parallel = False - elif not relation_exists: + elif not self.relation_exists: # If the relation doesn't exist, we can't run in parallel run_in_parallel = False + elif self.batch_idx == 0 or self.batch_idx == len(self.batches) - 1: + # First and last batch don't run in parallel + run_in_parallel = False elif self.node.config.concurrent_batches is not None: # If the relation exists and the `concurrent_batches` config isn't None, use the config value run_in_parallel = self.node.config.concurrent_batches @@ -705,9 +705,7 @@ def handle_microbatch_model( ) -> RunResult: # Initial run computes batch metadata result = self.call_runner(runner) - batches = runner.batches - node = runner.node - relation_exists = runner.relation_exists + batches, node, relation_exists = runner.batches, runner.node, runner.relation_exists # Return early if model should be skipped, or there are no batches to execute if result.status == RunStatus.Skipped: @@ -717,30 +715,20 @@ def handle_microbatch_model( batch_results: List[RunResult] = [] batch_idx = 0 - - # Run first batch runs in serial - relation_exists = self._submit_batch( - node, relation_exists, batches, batch_idx, batch_results, pool, parallel=False - ) - batch_idx += 1 - - # Subsequent batches can be run in parallel + # Run all batches except last batch, in parallel if possible while batch_idx < len(runner.batches) - 1: - parallel = runner._should_run_in_parallel(relation_exists) relation_exists = self._submit_batch( - node, relation_exists, batches, batch_idx, batch_results, pool, parallel + node, relation_exists, batches, batch_idx, batch_results, pool ) batch_idx += 1 # Wait until all submitted batches have completed while len(batch_results) != batch_idx: pass + # Final batch runs once all others complete to ensure post_hook runs at the end + self._submit_batch(node, relation_exists, batches, batch_idx, batch_results, pool) - # Final batch runs in serial - self._submit_batch( - node, relation_exists, batches, batch_idx, batch_results, pool, parallel=False - ) - + # Finalize run: merge results, track model run, and print final result line runner.merge_batch_results(result, batch_results) track_model_run(runner.node_index, runner.num_nodes, result, adapter=runner.adapter) runner.print_result_line(result) @@ -755,7 +743,6 @@ def _submit_batch( batch_idx: int, batch_results: List[RunResult], pool: ThreadPool, - parallel: bool, ): node_copy = deepcopy(node) # Only run pre_hook(s) for first batch @@ -764,14 +751,14 @@ def _submit_batch( # Only run post_hook(s) for last batch elif batch_idx != len(batches) - 1: node_copy.config.post_hook = [] - + batch_runner = self.get_runner(node_copy) assert isinstance(batch_runner, MicrobatchModelRunner) batch_runner.set_batch_idx(batch_idx) batch_runner.set_relation_exists(relation_exists) batch_runner.set_batches(batches) - if parallel: + if batch_runner.should_run_in_parallel(): fire_event( MicrobatchExecutionDebug( msg=f"{batch_runner.describe_batch} is being run concurrently" diff --git a/tests/functional/microbatch/test_microbatch.py b/tests/functional/microbatch/test_microbatch.py index e3acc415273..56537ad48cf 100644 --- a/tests/functional/microbatch/test_microbatch.py +++ b/tests/functional/microbatch/test_microbatch.py @@ -875,7 +875,7 @@ def batch_exc_catcher(self) -> EventCatcher: def test_microbatch( self, mocker: MockerFixture, project, batch_exc_catcher: EventCatcher ) -> None: - mocked_srip = mocker.patch("dbt.task.run.MicrobatchModelRunner._should_run_in_parallel") + mocked_srip = mocker.patch("dbt.task.run.MicrobatchModelRunner.should_run_in_parallel") # Should be run in parallel mocked_srip.return_value = True diff --git a/tests/unit/task/test_run.py b/tests/unit/task/test_run.py index b33e6f57ffe..b28ac505a7f 100644 --- a/tests/unit/task/test_run.py +++ b/tests/unit/task/test_run.py @@ -264,7 +264,7 @@ class Relation: (False, False, False, True, False), ], ) - def test__should_run_in_parallel( + def test_should_run_in_parallel( self, mocker: MockerFixture, model_runner: MicrobatchModelRunner, @@ -276,11 +276,13 @@ def test__should_run_in_parallel( ) -> None: model_runner.node._has_this = has_this model_runner.node.config = ModelConfig(concurrent_batches=concurrent_batches) + model_runner.set_relation_exists(has_relation) + mocked_supports = mocker.patch.object(model_runner.adapter, "supports") mocked_supports.return_value = adapter_microbatch_concurrency - # Assert result of _should_run_in_parallel - assert model_runner._should_run_in_parallel(has_relation) == expectation + # Assert result of should_run_in_parallel + assert model_runner.should_run_in_parallel() == expectation class TestRunTask: