Skip to content

Commit e9c1d2c

Browse files
authored
Add filter option for templatable runs (#3000)
* Add filter option for templatable runs * Apply some feedback * Fix mypy * Fix unit test * Formatting
1 parent 3ffc2d5 commit e9c1d2c

File tree

6 files changed

+67
-1
lines changed

6 files changed

+67
-1
lines changed

src/zenml/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3748,6 +3748,7 @@ def list_pipeline_runs(
37483748
end_time: Optional[Union[datetime, str]] = None,
37493749
num_steps: Optional[Union[int, str]] = None,
37503750
unlisted: Optional[bool] = None,
3751+
templatable: Optional[bool] = None,
37513752
tag: Optional[str] = None,
37523753
hydrate: bool = False,
37533754
) -> Page[PipelineRunResponse]:
@@ -3778,6 +3779,7 @@ def list_pipeline_runs(
37783779
end_time: The end_time for the pipeline run
37793780
num_steps: The number of steps for the pipeline run
37803781
unlisted: If the runs should be unlisted or not.
3782+
templatable: If the runs should be templatable or not.
37813783
tag: Tag to filter by.
37823784
hydrate: Flag deciding whether to hydrate the output model(s)
37833785
by including metadata fields in the response.
@@ -3811,6 +3813,7 @@ def list_pipeline_runs(
38113813
num_steps=num_steps,
38123814
tag=tag,
38133815
unlisted=unlisted,
3816+
templatable=templatable,
38143817
)
38153818
runs_filter_model.set_scope_workspace(self.active_workspace.id)
38163819
return self.zen_store.list_runs(

src/zenml/models/v2/core/pipeline_run.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,9 @@ class PipelineRunResponseMetadata(WorkspaceScopedResponseMetadata):
235235
default=None,
236236
description="Template used for the pipeline run.",
237237
)
238+
is_templatable: bool = Field(
239+
description="Whether a template can be created from this run.",
240+
)
238241

239242

240243
class PipelineRunResponseResources(WorkspaceScopedResponseResources):
@@ -477,6 +480,15 @@ def template_id(self) -> Optional[UUID]:
477480
"""
478481
return self.get_metadata().template_id
479482

483+
@property
484+
def is_templatable(self) -> bool:
485+
"""The `is_templatable` property.
486+
487+
Returns:
488+
the value of the property.
489+
"""
490+
return self.get_metadata().is_templatable
491+
480492
@property
481493
def model_version(self) -> Optional[ModelVersionResponse]:
482494
"""The `model_version` property.
@@ -511,6 +523,7 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
511523
"stack_id",
512524
"template_id",
513525
"pipeline_name",
526+
"templatable",
514527
]
515528
name: Optional[str] = Field(
516529
default=None,
@@ -584,6 +597,7 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
584597
union_mode="left_to_right",
585598
)
586599
unlisted: Optional[bool] = None
600+
templatable: Optional[bool] = None
587601

588602
def get_custom_filters(
589603
self,
@@ -595,7 +609,7 @@ def get_custom_filters(
595609
"""
596610
custom_filters = super().get_custom_filters()
597611

598-
from sqlmodel import and_
612+
from sqlmodel import and_, col, or_
599613

600614
from zenml.zen_stores.schemas import (
601615
CodeReferenceSchema,
@@ -668,4 +682,40 @@ def get_custom_filters(
668682
)
669683
custom_filters.append(run_template_filter)
670684

685+
if self.templatable is not None:
686+
if self.templatable is True:
687+
templatable_filter = and_(
688+
# The following condition is not perfect as it does not
689+
# consider stacks with custom flavor components or local
690+
# components, but the best we can do currently with our
691+
# table columns.
692+
PipelineRunSchema.deployment_id
693+
== PipelineDeploymentSchema.id,
694+
PipelineDeploymentSchema.build_id
695+
== PipelineBuildSchema.id,
696+
col(PipelineBuildSchema.is_local).is_(False),
697+
col(PipelineBuildSchema.stack_id).is_not(None),
698+
)
699+
else:
700+
templatable_filter = or_(
701+
col(PipelineRunSchema.deployment_id).is_(None),
702+
and_(
703+
PipelineRunSchema.deployment_id
704+
== PipelineDeploymentSchema.id,
705+
col(PipelineDeploymentSchema.build_id).is_(None),
706+
),
707+
and_(
708+
PipelineRunSchema.deployment_id
709+
== PipelineDeploymentSchema.id,
710+
PipelineDeploymentSchema.build_id
711+
== PipelineBuildSchema.id,
712+
or_(
713+
col(PipelineBuildSchema.is_local).is_(True),
714+
col(PipelineBuildSchema.stack_id).is_(None),
715+
),
716+
),
717+
)
718+
719+
custom_filters.append(templatable_filter)
720+
671721
return custom_filters

src/zenml/zen_stores/schemas/pipeline_run_schemas.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,15 @@ def to_model(
328328
)
329329
metadata = None
330330
if include_metadata:
331+
is_templatable = False
332+
if (
333+
self.deployment
334+
and self.deployment.build
335+
and not self.deployment.build.is_local
336+
and self.deployment.build.stack
337+
):
338+
is_templatable = True
339+
331340
steps = {step.name: step.to_model() for step in self.step_runs}
332341

333342
metadata = PipelineRunResponseMetadata(
@@ -346,6 +355,7 @@ def to_model(
346355
template_id=self.deployment.template_id
347356
if self.deployment
348357
else None,
358+
is_templatable=is_templatable,
349359
)
350360

351361
resources = None

src/zenml/zen_stores/schemas/run_template_schemas.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def to_model(
177177
if (
178178
self.source_deployment
179179
and self.source_deployment.build
180+
and not self.source_deployment.build.is_local
180181
and self.source_deployment.build.stack
181182
):
182183
runnable = True

src/zenml/zen_stores/sql_zen_store.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,7 @@ def filter_and_paginate(
974974
RuntimeError: if the schema does not have a `to_model` method.
975975
"""
976976
query = filter_model.apply_filter(query=query, table=table)
977+
query = query.distinct()
977978

978979
# Get the total amount of items in the database for a given query
979980
custom_fetch_result: Optional[Sequence[Any]] = None

tests/unit/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,7 @@ def sample_pipeline_run(
447447
metadata=PipelineRunResponseMetadata(
448448
workspace=sample_workspace_model,
449449
config=PipelineConfiguration(name="aria_pipeline"),
450+
is_templatable=False,
450451
),
451452
)
452453

0 commit comments

Comments
 (0)