Skip to content

Commit

Permalink
Refactor ancestor and successor client code
Browse files Browse the repository at this point in the history
  • Loading branch information
talsperre committed Nov 1, 2024
1 parent ec43f14 commit a84f463
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 103 deletions.
191 changes: 95 additions & 96 deletions metaflow/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,64 +1123,108 @@ def _iter_filter(self, x):
# exclude private data artifacts
return x.id[0] != "_"

def immediate_ancestors(self) -> Dict[str, List[str]]:
def _get_task_for_queried_step(self, flow_id, run_id, query_step):
"""
Returns a dictionary of immediate ancestors task ids of this task for each
previous step.
Returns
-------
Dict[str, List[str]]
Dictionary of immediate ancestors of this task. The keys are the
names of the ancestors steps and the values are the corresponding
task ids of the ancestors.
Returns a Task object corresponding to the queried step.
If the queried step has several tasks, the first task is returned.
"""
# Find any task for the queried step
step = Step(f"{flow_id}/{run_id}/{query_step}", _namespace_check=False)
task = next(iter(step.tasks()), None)
if task:
return task
raise MetaflowNotFound(f"No task found for the queried step {query_step}")

def _prev_task(flow_id, run_id, previous_step):
# Find any previous task for current step
step = Step(f"{flow_id}/{run_id}/{previous_step}", _namespace_check=False)
task = next(iter(step.tasks()), None)
if task:
return task
raise MetaflowNotFound(f"No previous task found for step {previous_step}")

flow_id, run_id, step_name, task_id = self.path_components
previous_steps = self.metadata_dict.get("previous_steps", None)

if not previous_steps or len(previous_steps) == 0:
return

cur_foreach_stack_len = len(self.metadata_dict.get("foreach-stack", []))
ancestor_iters = {}
if len(previous_steps) > 1:
def _get_filter_query_value(
self, flow_id, run_id, cur_foreach_stack_len, query_steps, query_type
):
"""
For a given query type, returns the field name and value to be used for filtering tasks
based on the task's metadata.
"""
if len(query_steps) > 1:
# This is a static join (ancestor query) or a static split (successor query), so there is no change in foreach stack length
prev_foreach_stack_len = cur_foreach_stack_len
query_foreach_stack_len = cur_foreach_stack_len
else:
prev_task = _prev_task(flow_id, run_id, previous_steps[0])
prev_foreach_stack_len = len(
prev_task.metadata_dict.get("foreach-stack", [])
query_task = self._get_task_for_queried_step(
flow_id, run_id, query_steps[0]
)
query_foreach_stack_len = len(
query_task.metadata_dict.get("foreach-stack", [])
)

if prev_foreach_stack_len == cur_foreach_stack_len:
# print(f"query_foreach_stack_len: {query_foreach_stack_len} cur_foreach_stack_len: {cur_foreach_stack_len}")
if query_foreach_stack_len == cur_foreach_stack_len:
field_name = "foreach-indices"
field_value = self.metadata_dict.get(field_name)
elif prev_foreach_stack_len > cur_foreach_stack_len:
field_name = "foreach-indices-truncated"
field_value = self.metadata_dict.get("foreach-indices")
elif query_type == "ancestor":
if query_foreach_stack_len > cur_foreach_stack_len:
# This is a foreach join
# We will compare the foreach-indices-truncated value of ancestor task with the
# foreach-indices value of current task
field_name = "foreach-indices-truncated"
field_value = self.metadata_dict.get("foreach-indices")
else:
# This is a foreach split
# We will compare the foreach-indices value of ancestor task with the
# foreach-indices-truncated value of current task
field_name = "foreach-indices"
field_value = self.metadata_dict.get("foreach-indices-truncated")
else:
# We will compare the foreach-indices-truncated value of current task with the
# foreach-indices value of tasks in previous steps
field_name = "foreach-indices"
field_value = self.metadata_dict.get("foreach-indices-truncated")
if query_foreach_stack_len > cur_foreach_stack_len:
# This is a foreach split
# We will compare the foreach-indices value of current task with the
# foreach-indices-truncated value of successor tasks
field_name = "foreach-indices-truncated"
field_value = self.metadata_dict.get("foreach-indices")
else:
# This is a foreach join
# We will compare the foreach-indices-truncated value of current task with the
# foreach-indices value of successor tasks
field_name = "foreach-indices"
field_value = self.metadata_dict.get("foreach-indices-truncated")
return field_name, field_value

def _get_related_tasks(
self, steps_key: str, relation_type: str
) -> Dict[str, List[str]]:
flow_id, run_id, _, _ = self.path_components
query_steps = self.metadata_dict.get(steps_key)

if not query_steps:
return {}

field_name, field_value = self._get_filter_query_value(
flow_id,
run_id,
len(self.metadata_dict.get("foreach-stack", [])),
query_steps,
relation_type,
)

for prev_step in previous_steps:
ancestor_iters[prev_step] = (
self._metaflow.metadata.filter_tasks_by_metadata(
flow_id, run_id, step_name, prev_step, field_name, field_value
)
return {
query_step: self._metaflow.metadata.filter_tasks_by_metadata(
flow_id, run_id, query_step, field_name, field_value
)
return ancestor_iters
for query_step in query_steps
}

@property
def immediate_ancestors(self) -> Dict[str, List[str]]:
"""
Returns a dictionary of the immediate ancestor task ids of this task for each
previous step.
Returns
-------
Dict[str, List[str]]
Dictionary of immediate ancestors of this task. The keys are the
names of the ancestor steps and the values are the corresponding
task ids of the ancestors.
"""
return self._get_related_tasks("previous_steps", "ancestor")

@property
def immediate_successors(self) -> Dict[str, List[str]]:
"""
Returns a dictionary of immediate successors task ids of this task for each
Expand All @@ -1193,55 +1237,10 @@ def immediate_successors(self) -> Dict[str, List[str]]:
names of the successors steps and the values are the corresponding
task ids of the successors.
"""
return self._get_related_tasks("successor_steps", "successor")

def _successor_task(flow_id, run_id, successor_step):
# Find any previous task for current step
step = Step(f"{flow_id}/{run_id}/{successor_step}", _namespace_check=False)
task = next(iter(step.tasks()), None)
if task:
return task
raise MetaflowNotFound(f"No successor task found for step {successor_step}")

flow_id, run_id, step_name, task_id = self.path_components
successor_steps = self.metadata_dict.get("successor_steps", None)

if not successor_steps or len(successor_steps) == 0:
return

cur_foreach_stack_len = len(self.metadata_dict.get("foreach-stack", []))
successor_iters = {}
if len(successor_steps) > 1:
# This is a static split, so there is no change in foreach stack length
successor_foreach_stack_len = cur_foreach_stack_len
else:
successor_task = _successor_task(flow_id, run_id, successor_steps[0])
successor_foreach_stack_len = len(
successor_task.metadata_dict.get("foreach-stack", [])
)

if successor_foreach_stack_len == cur_foreach_stack_len:
field_name = "foreach-indices"
field_value = self.metadata_dict.get(field_name)
elif successor_foreach_stack_len > cur_foreach_stack_len:
# We will compare the foreach-indices value of current task with the
# foreach-indices-truncated value of tasks in successor steps
field_name = "foreach-indices-truncated"
field_value = self.metadata_dict.get("foreach-indices")
else:
# We will compare the foreach-indices-truncated value of current task with the
# foreach-indices value of tasks in successor steps
field_name = "foreach-indices"
field_value = self.metadata_dict.get("foreach-indices-truncated")

for successor_step in successor_steps:
successor_iters[successor_step] = (
self._metaflow.metadata.filter_tasks_by_metadata(
flow_id, run_id, step_name, successor_step, field_name, field_value
)
)
return successor_iters

def closest_siblings(self) -> Dict[str, List[str]]:
@property
def immediate_siblings(self) -> Dict[str, List[str]]:
"""
Returns a dictionary of closest siblings of this task for each step.
Expand All @@ -1252,13 +1251,13 @@ def closest_siblings(self) -> Dict[str, List[str]]:
names of the current step and the values are the corresponding
task ids of the siblings.
"""
flow_id, run_id, step_name, task_id = self.path_components
flow_id, run_id, step_name, _ = self.path_components

foreach_stack = self.metadata_dict.get("foreach-stack", [])
foreach_step_names = self.metadata_dict.get("foreach-step-names", [])
if len(foreach_stack) == 0:
raise MetaflowInternalError("Task is not part of any foreach split")
elif step_name != foreach_step_names[-1]:
if step_name != foreach_step_names[-1]:
raise MetaflowInternalError(
f"Step {step_name} does not have any direct siblings since it is not part "
f"of a new foreach split."
Expand All @@ -1269,7 +1268,7 @@ def closest_siblings(self) -> Dict[str, List[str]]:
# We find all tasks of the same step that have the same foreach-indices-truncated value
return {
step_name: self._metaflow.metadata.filter_tasks_by_metadata(
flow_id, run_id, step_name, step_name, field_name, field_value
flow_id, run_id, step_name, field_name, field_value
)
}

Expand Down
7 changes: 4 additions & 3 deletions metaflow/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,16 +674,17 @@ def _register_system_metadata(self, run_id, step_name, task_id, attempt):

@classmethod
def _filter_tasks_by_metadata(
cls, flow_id, run_id, step_name, query_step, field_name, field_value
cls, flow_id, run_id, query_step, field_name, field_value
):
raise NotImplementedError()

@classmethod
def filter_tasks_by_metadata(
cls, flow_id, run_id, step_name, query_step, field_name, field_value
cls, flow_id, run_id, query_step, field_name, field_value
):
# TODO: Do we need to do anything with respect to the task attempt?
task_ids = cls._filter_tasks_by_metadata(
flow_id, run_id, step_name, query_step, field_name, field_value
flow_id, run_id, query_step, field_name, field_value
)
return task_ids

Expand Down
4 changes: 0 additions & 4 deletions metaflow/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,10 +611,6 @@ def run_step(
# user artifacts in the user's step code.

if join_type:
if join_type == "foreach":
# We only want to persist one of the input paths
self.flow._input_paths = str(input_paths[0])

# Join step:

# Ensure that we have the right number of inputs. The
Expand Down

0 comments on commit a84f463

Please sign in to comment.