Skip to content

Commit

Permalink
Fix bug in parsing steps due to different data formats across metadata services
Browse files Browse the repository at this point in the history
  • Loading branch information
talsperre committed Jan 22, 2025
1 parent 2c9a81b commit 7833e40
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 18 deletions.
36 changes: 21 additions & 15 deletions metaflow/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,7 +1120,6 @@ class Task(MetaflowObject):
def __init__(self, *args, **kwargs):
super(Task, self).__init__(*args, **kwargs)
# We want to cache metadata dictionary since it's used in many places
self._metadata_dict = None

def _iter_filter(self, x):
# exclude private data artifacts
Expand All @@ -1140,6 +1139,7 @@ def _get_metadata_query_vals(
cur_foreach_stack_len: int,
steps: List[str],
is_ancestor: bool,
metadata_dict: Dict[str, Any],
):
"""
Returns the field name and field value to be used for querying metadata of successor or ancestor tasks.
Expand All @@ -1157,7 +1157,10 @@ def _get_metadata_query_vals(
ancestors and successors across multiple steps.
is_ancestor : bool
If we are querying for ancestor tasks, set this to True.
metadata_dict : Dict[str, Any]
Cached metadata dictionary of the current task
"""

# For each task, we also log additional metadata fields such as foreach-indices and foreach-indices-truncated
# which help us in querying ancestor and successor tasks.
# `foreach-indices`: contains the indices of the foreach stack at the time of task execution.
Expand All @@ -1181,7 +1184,7 @@ def _get_metadata_query_vals(
if query_foreach_stack_len == cur_foreach_stack_len:
# The successor or ancestor tasks belong to the same foreach stack level
field_name = "foreach-indices"
field_value = self.metadata_dict.get(field_name)
field_value = metadata_dict.get(field_name)
elif is_ancestor:
if query_foreach_stack_len > cur_foreach_stack_len:
# This is a foreach join
Expand All @@ -1190,15 +1193,15 @@ def _get_metadata_query_vals(
# We will compare the foreach-indices-truncated value of ancestor task with the
# foreach-indices value of current task
field_name = "foreach-indices-truncated"
field_value = self.metadata_dict.get("foreach-indices")
field_value = metadata_dict.get("foreach-indices")
else:
# This is a foreach split
# Current Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1]
# Ancestor Task: foreach-indices = [0, 1], foreach-indices-truncated = [0]
# We will compare the foreach-indices value of ancestor task with the
# foreach-indices-truncated value of current task
field_name = "foreach-indices"
field_value = self.metadata_dict.get("foreach-indices-truncated")
field_value = metadata_dict.get("foreach-indices-truncated")
else:
if query_foreach_stack_len > cur_foreach_stack_len:
# This is a foreach split
Expand All @@ -1207,34 +1210,40 @@ def _get_metadata_query_vals(
# We will compare the foreach-indices value of current task with the
# foreach-indices-truncated value of successor tasks
field_name = "foreach-indices-truncated"
field_value = self.metadata_dict.get("foreach-indices")
field_value = metadata_dict.get("foreach-indices")
else:
# This is a foreach join
# Current Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1]
# Successor Task: foreach-indices = [0, 1], foreach-indices-truncated = [0]
# We will compare the foreach-indices-truncated value of current task with the
# foreach-indices value of successor tasks
field_name = "foreach-indices"
field_value = self.metadata_dict.get("foreach-indices-truncated")
field_value = metadata_dict.get("foreach-indices-truncated")
return field_name, field_value

def _get_related_tasks(self, is_ancestor: bool) -> Dict[str, List[str]]:
flow_id, run_id, _, _ = self.path_components
metadata_dict = self.metadata_dict
steps = (
self.metadata_dict.get("previous-steps")
metadata_dict.get("previous-steps")
if is_ancestor
else self.metadata_dict.get("successor-steps")
else metadata_dict.get("successor-steps")
)

if not steps:
return {}

# Convert steps to a list if it's stored as a string in the metadata
if is_stringish(steps):
steps = [steps]

field_name, field_value = self._get_metadata_query_vals(
flow_id,
run_id,
len(self.metadata_dict.get("foreach-indices", [])),
len(metadata_dict.get("foreach-indices", [])),
steps,
is_ancestor=is_ancestor,
metadata_dict=metadata_dict,
)

return {
Expand Down Expand Up @@ -1419,12 +1428,9 @@ def metadata_dict(self) -> Dict[str, str]:
Dictionary mapping metadata name with value
"""
# use the newest version of each key, hence sorting
if self._metadata_dict is None:
self._metadata_dict = {
m.name: m.value
for m in sorted(self.metadata, key=lambda m: m.created_at)
}
return self._metadata_dict
return {
m.name: m.value for m in sorted(self.metadata, key=lambda m: m.created_at)
}

@property
def index(self) -> Optional[int]:
Expand Down
7 changes: 4 additions & 3 deletions metaflow/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,15 @@ def _dynamic_runtime_metadata(foreach_stack):
foreach_step_names = [foreach_frame.step for foreach_frame in foreach_stack]
return foreach_indices, foreach_indices_truncated, foreach_step_names

def _static_runtime_metadata(self, graph_info, step_name):
@staticmethod
def _static_runtime_metadata(graph_info, step_name):
prev_steps = [
node_name
for node_name, attributes in graph_info["steps"].items()
if step_name in attributes["next"]
]
succesor_steps = graph_info["steps"][step_name]["next"]
return prev_steps, succesor_steps
successor_steps = graph_info["steps"][step_name]["next"]
return prev_steps, successor_steps

def __init__(
self,
Expand Down

0 comments on commit 7833e40

Please sign in to comment.