diff --git a/src/zenml/client.py b/src/zenml/client.py index 995f2d8bdb3..8006a4285c7 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -2663,11 +2663,13 @@ def list_builds( user_id: Optional[Union[str, UUID]] = None, pipeline_id: Optional[Union[str, UUID]] = None, stack_id: Optional[Union[str, UUID]] = None, + container_registry_id: Optional[Union[UUID, str]] = None, is_local: Optional[bool] = None, contains_code: Optional[bool] = None, zenml_version: Optional[str] = None, python_version: Optional[str] = None, checksum: Optional[str] = None, + stack_checksum: Optional[str] = None, hydrate: bool = False, ) -> Page[PipelineBuildResponse]: """List all builds. @@ -2684,11 +2686,14 @@ def list_builds( user_id: The id of the user to filter by. pipeline_id: The id of the pipeline to filter by. stack_id: The id of the stack to filter by. + container_registry_id: The id of the container registry to + filter by. is_local: Use to filter local builds. contains_code: Use to filter builds that contain code. zenml_version: The version of ZenML to filter by. python_version: The Python version to filter by. checksum: The build checksum to filter by. + stack_checksum: The stack checksum to filter by. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. @@ -2707,11 +2712,13 @@ def list_builds( user_id=user_id, pipeline_id=pipeline_id, stack_id=stack_id, + container_registry_id=container_registry_id, is_local=is_local, contains_code=contains_code, zenml_version=zenml_version, python_version=python_version, checksum=checksum, + stack_checksum=stack_checksum, ) build_filter_model.set_scope_workspace(self.active_workspace.id) return self.zen_store.list_builds( diff --git a/src/zenml/models/v2/core/pipeline_build.py b/src/zenml/models/v2/core/pipeline_build.py index 3cb6dcb4e47..93c0ff63a8c 100644 --- a/src/zenml/models/v2/core/pipeline_build.py +++ b/src/zenml/models/v2/core/pipeline_build.py @@ -14,7 +14,7 @@ """Models representing pipeline builds.""" import json -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union from uuid import UUID from pydantic import Field @@ -31,6 +31,8 @@ from zenml.models.v2.misc.build_item import BuildItem if TYPE_CHECKING: + from sqlalchemy.sql.elements import ColumnElement + from zenml.models.v2.core.pipeline import PipelineResponse from zenml.models.v2.core.stack import StackResponse @@ -446,6 +448,11 @@ def contains_code(self) -> bool: class PipelineBuildFilter(WorkspaceScopedFilter): """Model to enable advanced filtering of all pipeline builds.""" + FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, + "container_registry_id", + ] + workspace_id: Optional[Union[UUID, str]] = Field( description="Workspace for this pipeline build.", default=None, @@ -462,7 +469,12 @@ class PipelineBuildFilter(WorkspaceScopedFilter): union_mode="left_to_right", ) stack_id: Optional[Union[UUID, str]] = Field( - description="Stack used for the Pipeline Run", + description="Stack associated with the pipeline build.", + default=None, + union_mode="left_to_right", + ) + container_registry_id: Optional[Union[UUID, str]] = Field( + description="Container registry associated with the pipeline build.", default=None, union_mode="left_to_right", ) @@ -484,3 +496,39 @@ class PipelineBuildFilter(WorkspaceScopedFilter): checksum: Optional[str] = Field( description="The build checksum.", default=None ) + stack_checksum: Optional[str] = Field( + description="The stack checksum.", default=None + ) + + def get_custom_filters( + self, + ) -> List["ColumnElement[bool]"]: + """Get custom filters. + + Returns: + A list of custom filters. + """ + custom_filters = super().get_custom_filters() + + from sqlmodel import and_ + + from zenml.enums import StackComponentType + from zenml.zen_stores.schemas import ( + PipelineBuildSchema, + StackComponentSchema, + StackCompositionSchema, + StackSchema, + ) + + if self.container_registry_id: + container_registry_filter = and_( + PipelineBuildSchema.stack_id == StackSchema.id, + StackSchema.id == StackCompositionSchema.stack_id, + StackCompositionSchema.component_id == StackComponentSchema.id, + StackComponentSchema.type + == StackComponentType.CONTAINER_REGISTRY.value, + StackComponentSchema.id == self.container_registry_id, + ) + custom_filters.append(container_registry_filter) + + return custom_filters diff --git a/src/zenml/pipelines/build_utils.py b/src/zenml/pipelines/build_utils.py index eacbd1d07da..810f8d5f177 100644 --- a/src/zenml/pipelines/build_utils.py +++ b/src/zenml/pipelines/build_utils.py @@ -249,6 +249,11 @@ def find_existing_build( client = Client() stack = client.active_stack + if not stack.container_registry: + # There can be no non-local builds that we can reuse if there is no + # container registry in the stack. + return None + python_version_prefix = ".".join(platform.python_version_tuple()[:2]) required_builds = stack.get_docker_builds(deployment=deployment) @@ -263,6 +268,13 @@ def find_existing_build( sort_by="desc:created", size=1, stack_id=stack.id, + # Until we implement stack versioning, users can still update their + # stack to update/remove the container registry. In that case, we might + # try to pull an image from a container registry that we don't have + # access to. This is why we add an additional check for the container + # registry ID here. (This is still not perfect as users can update the + # container registry URI or config, but the best we can do) + container_registry_id=stack.container_registry.id, # The build is local and it's not clear whether the images # exist on the current machine or if they've been overwritten. # TODO: Should we support this by storing the unique Docker ID for diff --git a/tests/unit/pipelines/test_build_utils.py b/tests/unit/pipelines/test_build_utils.py index 73684af306d..de278fac778 100644 --- a/tests/unit/pipelines/test_build_utils.py +++ b/tests/unit/pipelines/test_build_utils.py @@ -518,7 +518,9 @@ def test_local_repo_verification( assert isinstance(code_repo, StubCodeRepository) -def test_finding_existing_build(mocker, sample_deployment_response_model): +def test_finding_existing_build( + mocker, sample_deployment_response_model, remote_container_registry +): """Tests finding an existing build.""" mock_list_builds = mocker.patch( "zenml.client.Client.list_builds", @@ -551,14 +553,30 @@ def test_finding_existing_build(mocker, sample_deployment_response_model): ], ) + build_utils.find_existing_build( + deployment=sample_deployment_response_model, + code_repository=StubCodeRepository(), + ) + # No container registry -> no non-local build to pull + mock_list_builds.assert_not_called() + + mocker.patch.object( + Stack, + "container_registry", + new_callable=mocker.PropertyMock, + return_value=remote_container_registry, + ) + build = build_utils.find_existing_build( deployment=sample_deployment_response_model, code_repository=StubCodeRepository(), ) + mock_list_builds.assert_called_once_with( sort_by="desc:created", size=1, stack_id=Client().active_stack.id, + container_registry_id=remote_container_registry.id, is_local=False, contains_code=False, zenml_version=zenml.__version__,