Skip to content

Commit fecb121

Browse files
authored
Add filter option for runs linked to model version (#4003)
* Add filter option for runs linked to model version * Tests and filter on client method * Typo
1 parent 1bdfff4 commit fecb121

File tree

3 files changed

+62
-0
lines changed

3 files changed

+62
-0
lines changed

src/zenml/client.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4168,6 +4168,7 @@ def list_pipeline_runs(
41684168
template_id: Optional[Union[str, UUID]] = None,
41694169
source_snapshot_id: Optional[Union[str, UUID]] = None,
41704170
model_version_id: Optional[Union[str, UUID]] = None,
4171+
linked_to_model_version_id: Optional[Union[str, UUID]] = None,
41714172
orchestrator_run_id: Optional[str] = None,
41724173
status: Optional[str] = None,
41734174
start_time: Optional[Union[datetime, str]] = None,
@@ -4210,6 +4211,11 @@ def list_pipeline_runs(
42104211
template_id: The ID of the template to filter by.
42114212
source_snapshot_id: The ID of the source snapshot to filter by.
42124213
model_version_id: The ID of the model version to filter by.
4214+
linked_to_model_version_id: Filter by model version linked to the
4215+
pipeline run. The difference to `model_version_id` is that this
4216+
filter will not only include pipeline runs which are directly
4217+
linked to the model version, but also if any step run is linked
4218+
to the model version.
42134219
orchestrator_run_id: The run id of the orchestrator to filter by.
42144220
name: The name of the run to filter by.
42154221
status: The status of the pipeline run
@@ -4256,6 +4262,7 @@ def list_pipeline_runs(
42564262
template_id=template_id,
42574263
source_snapshot_id=source_snapshot_id,
42584264
model_version_id=model_version_id,
4265+
linked_to_model_version_id=linked_to_model_version_id,
42594266
orchestrator_run_id=orchestrator_run_id,
42604267
stack_id=stack_id,
42614268
status=status,

src/zenml/models/v2/core/pipeline_run.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,7 @@ class PipelineRunFilter(
636636
"pipeline_name",
637637
"templatable",
638638
"triggered_by_step_run_id",
639+
"linked_to_model_version_id",
639640
]
640641
CLI_EXCLUDE_FIELDS = [
641642
*ProjectScopedFilter.CLI_EXCLUDE_FIELDS,
@@ -702,6 +703,14 @@ class PipelineRunFilter(
702703
description="Model version associated with the pipeline run.",
703704
union_mode="left_to_right",
704705
)
706+
linked_to_model_version_id: Optional[Union[UUID, str]] = Field(
707+
default=None,
708+
description="Filter by model version linked to the pipeline run. "
709+
"The difference to `model_version_id` is that this filter will "
710+
"not only include pipeline runs which are directly linked to the model "
711+
"version, but also if any step run is linked to the model version.",
712+
union_mode="left_to_right",
713+
)
705714
status: Optional[str] = Field(
706715
default=None,
707716
description="Name of the Pipeline Run",
@@ -777,6 +786,7 @@ def get_custom_filters(
777786
CodeReferenceSchema,
778787
CodeRepositorySchema,
779788
ModelSchema,
789+
ModelVersionPipelineRunSchema,
780790
ModelVersionSchema,
781791
PipelineBuildSchema,
782792
PipelineRunSchema,
@@ -960,6 +970,20 @@ def get_custom_filters(
960970
)
961971
custom_filters.append(trigger_filter)
962972

973+
if self.linked_to_model_version_id:
974+
linked_to_model_version_filter = and_(
975+
PipelineRunSchema.id
976+
== ModelVersionPipelineRunSchema.pipeline_run_id,
977+
ModelVersionPipelineRunSchema.model_version_id
978+
== ModelVersionSchema.id,
979+
self.generate_custom_query_conditions_for_column(
980+
value=self.linked_to_model_version_id,
981+
table=ModelVersionSchema,
982+
column="id",
983+
),
984+
)
985+
custom_filters.append(linked_to_model_version_filter)
986+
963987
return custom_filters
964988

965989
def apply_sorting(

tests/integration/functional/zen_stores/test_zen_store.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5075,6 +5075,37 @@ def test_link_create_duplicated(self):
50755075
)
50765076
assert link_1.id == link_2.id
50775077

5078+
def test_linked_pipeline_run_fetching(self):
5079+
"""Test that model version pipeline run links can be fetched using
5080+
the list_pipeline_runs method.
5081+
"""
5082+
with ModelContext(True, create_prs=2) as (
5083+
model_version,
5084+
prs,
5085+
):
5086+
zs = Client().zen_store
5087+
zs.create_model_version_pipeline_run_link(
5088+
ModelVersionPipelineRunRequest(
5089+
model_version=model_version.id,
5090+
pipeline_run=prs[0].id,
5091+
)
5092+
)
5093+
zs.create_model_version_pipeline_run_link(
5094+
ModelVersionPipelineRunRequest(
5095+
model_version=model_version.id,
5096+
pipeline_run=prs[1].id,
5097+
)
5098+
)
5099+
5100+
assert (
5101+
Client()
5102+
.list_pipeline_runs(
5103+
linked_to_model_version_id=model_version.id
5104+
)
5105+
.total
5106+
== 2
5107+
)
5108+
50785109
def test_link_delete_found(self):
50795110
with ModelContext(True, create_prs=1) as (
50805111
model_version,

0 commit comments

Comments
 (0)