Microbatch first last batch serial #11072
base: main
Changes from all commits: 0d61609, bec5d57, 32002ea, 37dbd11
```diff
@@ -602,15 +602,15 @@
         )
         return relation is not None

-    def _should_run_in_parallel(
-        self,
-        relation_exists: bool,
-    ) -> bool:
+    def should_run_in_parallel(self) -> bool:
         if not self.adapter.supports(Capability.MicrobatchConcurrency):
             run_in_parallel = False
-        elif not relation_exists:
+        elif not self.relation_exists:
             # If the relation doesn't exist, we can't run in parallel
             run_in_parallel = False
+        elif self.batch_idx == 0 or self.batch_idx == len(self.batches) - 1:
+            # First and last batch don't run in parallel
+            run_in_parallel = False
         elif self.node.config.concurrent_batches is not None:
             # If the relation exists and the `concurrent_batches` config isn't None, use the config value
             run_in_parallel = self.node.config.concurrent_batches
```
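To see the new decision order at a glance, here is a minimal standalone sketch of the cascade. The parameter names are stand-ins for the runner state above, and the final `return True` is an assumption, since the fallback branch lies outside this hunk:

```python
from typing import Optional

def should_run_in_parallel(
    supports_concurrency: bool,  # stands in for adapter.supports(Capability.MicrobatchConcurrency)
    relation_exists: bool,
    batch_idx: int,
    num_batches: int,
    concurrent_batches: Optional[bool],  # stands in for node.config.concurrent_batches
) -> bool:
    if not supports_concurrency:
        # Adapter can't run microbatches concurrently at all
        return False
    if not relation_exists:
        # If the relation doesn't exist, we can't run in parallel
        return False
    if batch_idx in (0, num_batches - 1):
        # First and last batch don't run in parallel (the check added in this PR)
        return False
    if concurrent_batches is not None:
        # An explicit model-level config wins once the relation exists
        return concurrent_batches
    # Assumed default; the actual fallback branch is not visible in this hunk
    return True
```

Note that the batch-position check sits above the user config, so even `concurrent_batches: true` cannot parallelize the first or last batch.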
```diff
@@ -703,52 +703,79 @@
         runner: MicrobatchModelRunner,
         pool: ThreadPool,
     ) -> RunResult:
-        # Initial run computes batch metadata, unless model is skipped
+        # Initial run computes batch metadata
         result = self.call_runner(runner)
+        batches, node, relation_exists = runner.batches, runner.node, runner.relation_exists

         # Return early if model should be skipped, or there are no batches to execute
         if result.status == RunStatus.Skipped:
             return result
         elif len(runner.batches) == 0:
             return result

         batch_results: List[RunResult] = []

-        # Execute batches serially until a relation exists, at which point future batches are run in parallel
-        relation_exists = runner.relation_exists
         batch_idx = 0
-        while batch_idx < len(runner.batches):
-            batch_runner = MicrobatchModelRunner(
-                self.config, runner.adapter, deepcopy(runner.node), self.run_count, self.num_nodes
+        # Run all batches except last batch, in parallel if possible
+        while batch_idx < len(runner.batches) - 1:
+            relation_exists = self._submit_batch(
+                node, relation_exists, batches, batch_idx, batch_results, pool
```
Review comment on lines +719 to +721: Another reason for splitting out the first batch:
```diff
             )
-            batch_runner.set_batch_idx(batch_idx)
-            batch_runner.set_relation_exists(relation_exists)
-            batch_runner.set_batches(runner.batches)
-
-            if runner._should_run_in_parallel(relation_exists):
-                fire_event(
-                    MicrobatchExecutionDebug(
-                        msg=f"{batch_runner.describe_batch} is being run concurrently"
-                    )
-                )
-                self._submit(pool, [batch_runner], batch_results.append)
-            else:
-                fire_event(
-                    MicrobatchExecutionDebug(
-                        msg=f"{batch_runner.describe_batch} is being run sequentially"
-                    )
-                )
-                batch_results.append(self.call_runner(batch_runner))
-                relation_exists = batch_runner.relation_exists
-
             batch_idx += 1

-        # Wait until all batches have completed
-        while len(batch_results) != len(runner.batches):
+        # Wait until all submitted batches have completed
+        while len(batch_results) != batch_idx:
             pass
+        # Final batch runs once all others complete to ensure post_hook runs at the end
+        self._submit_batch(node, relation_exists, batches, batch_idx, batch_results, pool)

         # Finalize run: merge results, track model run, and print final result line
         runner.merge_batch_results(result, batch_results)
         track_model_run(runner.node_index, runner.num_nodes, result, adapter=runner.adapter)
         runner.print_result_line(result)

         return result

+    def _submit_batch(
+        self,
+        node: ModelNode,
+        relation_exists: bool,
+        batches: Dict[int, BatchType],
+        batch_idx: int,
+        batch_results: List[RunResult],
+        pool: ThreadPool,
+    ):
+        node_copy = deepcopy(node)
+        # Only run pre_hook(s) for first batch
+        if batch_idx != 0:
+            node_copy.config.pre_hook = []
+        # Only run post_hook(s) for last batch
+        if batch_idx != len(batches) - 1:
+            node_copy.config.post_hook = []
+
+        batch_runner = self.get_runner(node_copy)
+        assert isinstance(batch_runner, MicrobatchModelRunner)
+        batch_runner.set_batch_idx(batch_idx)
+        batch_runner.set_relation_exists(relation_exists)
+        batch_runner.set_batches(batches)
+
+        if batch_runner.should_run_in_parallel():
+            fire_event(
+                MicrobatchExecutionDebug(
+                    msg=f"{batch_runner.describe_batch} is being run concurrently"
+                )
+            )
+            self._submit(pool, [batch_runner], batch_results.append)
+        else:
+            fire_event(
+                MicrobatchExecutionDebug(
+                    msg=f"{batch_runner.describe_batch} is being run sequentially"
+                )
+            )
+            batch_results.append(self.call_runner(batch_runner))
+            relation_exists = batch_runner.relation_exists
+
+        return relation_exists
+
     def _hook_keyfunc(self, hook: HookNode) -> Tuple[str, Optional[int]]:
         package_name = hook.package_name
         if package_name == self.config.project_name:
```
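The loop above yields a distinctive shape: the first batch runs serially (it must create the relation), middle batches fan out to the thread pool, and the last batch waits for everything else so its post_hook fires at the end. Here is a toy simulation of just that shape (a sketch; `run_batch`, the pool size, and the timings are invented, and only the submit/busy-wait pattern mirrors the code above):

```python
from multiprocessing.pool import ThreadPool
import time

def run_batch(batch_idx: int) -> str:
    """Stand-in for one microbatch execution; the work here is fake."""
    time.sleep(0.05)
    return f"batch {batch_idx} done"

def run_all(num_batches: int) -> list:
    results: list = []
    pool = ThreadPool(4)
    relation_exists = False  # no relation until the first batch materializes it

    # Run all batches except the last, in parallel once the relation exists
    for batch_idx in range(num_batches - 1):
        if not relation_exists:
            # Serial, like the first batch in the PR
            results.append(run_batch(batch_idx))
            relation_exists = True
        else:
            pool.apply_async(run_batch, (batch_idx,), callback=results.append)

    # Wait until all submitted batches have completed, as in the PR
    while len(results) != num_batches - 1:
        time.sleep(0.01)

    # Last batch runs alone so post-hook behavior lands after everything else
    results.append(run_batch(num_batches - 1))
    pool.close()
    pool.join()
    return results

if __name__ == "__main__":
    print(run_all(5))
```

The real code busy-waits with a bare `pass`; the `sleep` here only keeps the toy loop from spinning hot.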
Review comment (on the new first/last-batch check in `should_run_in_parallel`): This check could also be skipped if we're instead handling `force_sequential` to determine if we should even check `should_run_in_parallel` in `_submit_batch`. It'd be nice for this function to be less dependent on "where" it is, and I think this check breaks that.
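A possible shape for that suggestion (hypothetical sketch; `force_sequential`, `Runner`, and `submit_batch` below are illustrative names, not code from this PR): the call site owns the first/last-batch rule, and `should_run_in_parallel` stays position-agnostic.

```python
class Runner:
    """Minimal stand-in for MicrobatchModelRunner, just enough for the sketch."""

    def __init__(self, supports_concurrency: bool, relation_exists: bool):
        self.supports_concurrency = supports_concurrency
        self.relation_exists = relation_exists

    def should_run_in_parallel(self) -> bool:
        # No knowledge of batch position here anymore
        return self.supports_concurrency and self.relation_exists

def submit_batch(runner: Runner, batch_idx: int, num_batches: int) -> str:
    # The caller decides when sequencing is mandatory: first and last batch
    force_sequential = batch_idx in (0, num_batches - 1)
    if not force_sequential and runner.should_run_in_parallel():
        return "concurrent"
    return "sequential"

# First and last batch come back sequential even with a capable adapter
runner = Runner(supports_concurrency=True, relation_exists=True)
assert submit_batch(runner, 0, 5) == "sequential"
assert submit_batch(runner, 2, 5) == "concurrent"
assert submit_batch(runner, 4, 5) == "sequential"
```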