Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed and improved sorting #3266

Open
wants to merge 10 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
)
FILTERING_DATETIME_FORMAT: str = "%Y-%m-%d %H:%M:%S"
SORT_PIPELINES_BY_LATEST_RUN_KEY = "latest_run"
SORT_BY_LATEST_VERSION_KEY = "latest_version"

# Metadata constants
METADATA_ORCHESTRATOR_URL = "orchestrator_url"
Expand Down
60 changes: 32 additions & 28 deletions src/zenml/models/v2/base/scoped.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ def apply_sorting(
UserSchema, getattr(table, "user_id") == UserSchema.id
)

query = query.add_columns(UserSchema.name)
schustmi marked this conversation as resolved.
Show resolved Hide resolved

if operand == SorterOps.ASCENDING:
query = query.order_by(asc(column))
else:
Expand Down Expand Up @@ -449,6 +451,8 @@ def apply_sorting(
getattr(table, "workspace_id") == WorkspaceSchema.id,
)

query = query.add_columns(WorkspaceSchema.name)

if operand == SorterOps.ASCENDING:
query = query.order_by(asc(column))
else:
Expand All @@ -470,10 +474,9 @@ class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter):
*WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
"tag",
]

CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [
*WorkspaceScopedFilter.CUSTOM_SORTING_OPTIONS,
"tag",
"tags",
]

def apply_filter(
Expand Down Expand Up @@ -540,8 +543,8 @@ def apply_sorting(
"""
sort_by, operand = self.sorting_params

if sort_by == "tag":
from sqlmodel import and_, asc, desc, func
if sort_by == "tags":
from sqlmodel import asc, desc, func, select

from zenml.enums import SorterOps, TaggableResourceTypes
from zenml.zen_stores.schemas import (
Expand All @@ -566,35 +569,36 @@ def apply_sorting(
RunTemplateSchema: TaggableResourceTypes.RUN_TEMPLATE,
}

query = (
query.outerjoin(
TagResourceSchema,
and_(
table.id == TagResourceSchema.resource_id,
TagResourceSchema.resource_type
== resource_type_mapping[table],
),
sorted_tags = (
select(TagResourceSchema.resource_id, TagSchema.name)
.join(TagSchema, TagResourceSchema.tag_id == TagSchema.id) # type: ignore[arg-type]
.filter(
TagResourceSchema.resource_type # type: ignore[arg-type]
== resource_type_mapping[table]
)
.order_by(
asc(TagResourceSchema.resource_id), asc(TagSchema.name)
)
.outerjoin(TagSchema, TagResourceSchema.tag_id == TagSchema.id)
.group_by(table.id)
).alias("sorted_tags")

tags_subquery = (
select(
sorted_tags.c.resource_id,
func.group_concat(sorted_tags.c.name, ", ").label(
"tags_list"
),
).group_by(sorted_tags.c.resource_id)
).alias("tags_subquery")

query = query.add_columns(tags_subquery.c.tags_list).outerjoin(
tags_subquery, table.id == tags_subquery.c.resource_id
)

# Apply ordering based on the tags list
if operand == SorterOps.ASCENDING:
query = query.order_by(
asc(
func.group_concat(TagSchema.name, ",").label(
"tags_list"
)
)
)
query = query.order_by(asc("tags_list"))
else:
query = query.order_by(
desc(
func.group_concat(TagSchema.name, ",").label(
"tags_list"
)
)
)
query = query.order_by(desc("tags_list"))

return query

Expand Down
88 changes: 86 additions & 2 deletions src/zenml/models/v2/core/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,21 @@
# permissions and limitations under the License.
"""Models representing artifacts."""

from typing import TYPE_CHECKING, Dict, List, Optional
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
List,
Optional,
Type,
TypeVar,
)
from uuid import UUID

from pydantic import BaseModel, Field

from zenml.constants import STR_FIELD_MAX_LENGTH
from zenml.constants import SORT_BY_LATEST_VERSION_KEY, STR_FIELD_MAX_LENGTH
from zenml.models.v2.base.base import (
BaseDatedResponseBody,
BaseIdentifiedResponse,
Expand All @@ -31,6 +40,11 @@

if TYPE_CHECKING:
from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
from zenml.zen_stores.schemas import BaseSchema

AnySchema = TypeVar("AnySchema", bound=BaseSchema)

AnyQuery = TypeVar("AnyQuery", bound=Any)

# ------------------ Request Model ------------------

Expand Down Expand Up @@ -174,3 +188,73 @@ class ArtifactFilter(WorkspaceScopedTaggableFilter):

name: Optional[str] = None
has_custom_name: Optional[bool] = None

CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [
*WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS,
SORT_BY_LATEST_VERSION_KEY,
]

def apply_sorting(
self,
query: AnyQuery,
table: Type["AnySchema"],
) -> AnyQuery:
"""Apply sorting to the query for Artifacts.

Args:
query: The query to which to apply the sorting.
table: The query table.

Returns:
The query with sorting applied.
"""
from sqlmodel import asc, case, col, desc, func, select

from zenml.enums import SorterOps
from zenml.zen_stores.schemas import (
ArtifactSchema,
ArtifactVersionSchema,
)

sort_by, operand = self.sorting_params

if sort_by == SORT_BY_LATEST_VERSION_KEY:
# Subquery to find the latest version per artifact
latest_version_subquery = (
select(
ArtifactSchema.id,
case(
(
func.max(ArtifactVersionSchema.created).is_(None),
ArtifactSchema.created,
),
else_=func.max(ArtifactVersionSchema.created),
).label("latest_version_created"),
)
.outerjoin(
ArtifactVersionSchema,
ArtifactSchema.id == ArtifactVersionSchema.artifact_id, # type: ignore[arg-type]
)
.group_by(col(ArtifactSchema.id))
.subquery()
)

query = query.add_columns(
latest_version_subquery.c.latest_version_created,
).where(ArtifactSchema.id == latest_version_subquery.c.id)

# Apply sorting based on the operand
if operand == SorterOps.ASCENDING:
query = query.order_by(
asc(latest_version_subquery.c.latest_version_created),
asc(ArtifactSchema.id),
)
else:
query = query.order_by(
desc(latest_version_subquery.c.latest_version_created),
desc(ArtifactSchema.id),
)
return query

# For other sorting cases, delegate to the parent class
return super().apply_sorting(query=query, table=table)
83 changes: 81 additions & 2 deletions src/zenml/models/v2/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@
# permissions and limitations under the License.
"""Models representing models."""

from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Any, ClassVar, List, Optional, Type, TypeVar
from uuid import UUID

from pydantic import BaseModel, Field

from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
from zenml.constants import (
SORT_BY_LATEST_VERSION_KEY,
STR_FIELD_MAX_LENGTH,
TEXT_FIELD_MAX_LENGTH,
)
from zenml.models.v2.base.scoped import (
WorkspaceScopedRequest,
WorkspaceScopedResponse,
Expand All @@ -32,6 +36,11 @@
if TYPE_CHECKING:
from zenml.model.model import Model
from zenml.models.v2.core.tag import TagResponse
from zenml.zen_stores.schemas import BaseSchema

AnySchema = TypeVar("AnySchema", bound=BaseSchema)

AnyQuery = TypeVar("AnyQuery", bound=Any)

# ------------------ Request Model ------------------

Expand Down Expand Up @@ -320,3 +329,73 @@ class ModelFilter(WorkspaceScopedTaggableFilter):
default=None,
description="Name of the Model",
)

CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [
*WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS,
SORT_BY_LATEST_VERSION_KEY,
]

def apply_sorting(
self,
query: AnyQuery,
table: Type["AnySchema"],
) -> AnyQuery:
"""Apply sorting to the query for Models.

Args:
query: The query to which to apply the sorting.
table: The query table.

Returns:
The query with sorting applied.
"""
from sqlmodel import asc, case, col, desc, func, select

from zenml.enums import SorterOps
from zenml.zen_stores.schemas import (
ModelSchema,
ModelVersionSchema,
)

sort_by, operand = self.sorting_params

if sort_by == SORT_BY_LATEST_VERSION_KEY:
# Subquery to find the latest version per model
latest_version_subquery = (
select(
ModelSchema.id,
case(
(
func.max(ModelVersionSchema.created).is_(None),
ModelSchema.created,
),
else_=func.max(ModelVersionSchema.created),
).label("latest_version_created"),
)
.outerjoin(
ModelVersionSchema,
ModelSchema.id == ModelVersionSchema.model_id, # type: ignore[arg-type]
)
.group_by(col(ModelSchema.id))
.subquery()
)

query = query.add_columns(
latest_version_subquery.c.latest_version_created,
).where(ModelSchema.id == latest_version_subquery.c.id)

# Apply sorting based on the operand
if operand == SorterOps.ASCENDING:
query = query.order_by(
asc(latest_version_subquery.c.latest_version_created),
asc(ModelSchema.id),
)
else:
query = query.order_by(
desc(latest_version_subquery.c.latest_version_created),
desc(ModelSchema.id),
)
return query

# For other sorting cases, delegate to the parent class
return super().apply_sorting(query=query, table=table)
27 changes: 15 additions & 12 deletions src/zenml/models/v2/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def apply_sorting(
# Subquery to find the latest run per pipeline
latest_run_subquery = (
select(
PipelineRunSchema.pipeline_id,
PipelineSchema.id,
case(
(
func.max(PipelineRunSchema.created).is_(None),
Expand All @@ -366,25 +366,28 @@ def apply_sorting(
else_=func.max(PipelineRunSchema.created),
).label("latest_run"),
)
.group_by(col(PipelineRunSchema.pipeline_id))
.outerjoin(
PipelineRunSchema,
PipelineSchema.id == PipelineRunSchema.pipeline_id, # type: ignore[arg-type]
)
.group_by(col(PipelineSchema.id))
.subquery()
)

# Join the subquery with the pipelines
query = query.outerjoin(
latest_run_subquery,
PipelineSchema.id == latest_run_subquery.c.pipeline_id,
)
query = query.add_columns(
latest_run_subquery.c.latest_run,
).where(PipelineSchema.id == latest_run_subquery.c.id)

if operand == SorterOps.ASCENDING:
query = query.order_by(
asc(latest_run_subquery.c.latest_run)
).order_by(col(PipelineSchema.id))
asc(latest_run_subquery.c.latest_run),
asc(PipelineSchema.id),
)
else:
query = query.order_by(
desc(latest_run_subquery.c.latest_run)
).order_by(col(PipelineSchema.id))

desc(latest_run_subquery.c.latest_run),
desc(PipelineSchema.id),
)
return query
else:
return super().apply_sorting(query=query, table=table)
2 changes: 2 additions & 0 deletions src/zenml/models/v2/core/pipeline_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,8 @@ def apply_sorting(
else:
return super().apply_sorting(query=query, table=table)

query = query.add_columns(column)

if operand == SorterOps.ASCENDING:
query = query.order_by(asc(column))
else:
Expand Down
7 changes: 7 additions & 0 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@
)
from sqlalchemy.orm import Mapped, noload
from sqlalchemy.util import immutabledict

# Important to note: The select function of SQLModel works slightly differently
# from the select function of sqlalchemy. If you input only one entity on the
# select function of SQLModel, it automatically maps it to a SelectOfScalar.
# As a result, it will not return a tuple as a result, but the first entity in
# the tuple. While this is convenient in most cases, in unique cases like using
# the "add_columns" functionality, one might encounter unexpected results.
from sqlmodel import (
Session,
SQLModel,
Expand Down
Loading
Loading