Skip to content

Fix input resolution for steps with dynamic artifact names #3228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/zenml/models/v2/core/pipeline_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,15 @@ def is_templatable(self) -> bool:
"""
return self.get_metadata().is_templatable

@property
def step_substitutions(self) -> Dict[str, Dict[str, str]]:
    """Per-step substitution mappings of this pipeline run.

    Returns:
        A mapping from step name to that step's substitution dictionary,
        as stored in the run's metadata.
    """
    metadata = self.get_metadata()
    return metadata.step_substitutions

@property
def model_version(self) -> Optional[ModelVersionResponse]:
"""The `model_version` property.
Expand Down
25 changes: 19 additions & 6 deletions src/zenml/orchestrators/input_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from zenml.config.step_configurations import Step
from zenml.enums import ArtifactSaveType, StepRunInputArtifactType
from zenml.exceptions import InputResolutionError
from zenml.utils import pagination_utils
from zenml.utils import pagination_utils, string_utils

if TYPE_CHECKING:
from zenml.models import PipelineRunResponse
Expand Down Expand Up @@ -53,7 +53,8 @@ def resolve_step_inputs(
current_run_steps = {
run_step.name: run_step
for run_step in pagination_utils.depaginate(
Client().list_run_steps, pipeline_run_id=pipeline_run.id
Client().list_run_steps,
pipeline_run_id=pipeline_run.id,
)
}

Expand All @@ -66,11 +67,23 @@ def resolve_step_inputs(
f"No step `{input_.step_name}` found in current run."
)

# Try to get the substitutions from the pipeline run first, as we
# already have a hydrated version of that. In the unlikely case
# that the pipeline run is outdated, we fetch it from the step
# run instead, which will cost us one hydration call.
substitutions = (
pipeline_run.step_substitutions.get(step_run.name)
or step_run.config.substitutions
)
output_name = string_utils.format_name_template(
input_.output_name, substitutions=substitutions
)

try:
outputs = step_run.outputs[input_.output_name]
outputs = step_run.outputs[output_name]
except KeyError:
raise InputResolutionError(
f"No step output `{input_.output_name}` found for step "
f"No step output `{output_name}` found for step "
f"`{input_.step_name}`."
)

Expand All @@ -83,12 +96,12 @@ def resolve_step_inputs(
# This should never happen, there can only be a single regular step
# output for a name
raise InputResolutionError(
f"Too many step outputs for output `{input_.output_name}` of "
f"Too many step outputs for output `{output_name}` of "
f"step `{input_.step_name}`."
)
elif len(step_outputs) == 0:
raise InputResolutionError(
f"No step output `{input_.output_name}` found for step "
f"No step output `{output_name}` found for step "
f"`{input_.step_name}`."
)

Expand Down
3 changes: 3 additions & 0 deletions src/zenml/orchestrators/step_run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,9 @@ def create_cached_step_runs(
for invocation_id in cache_candidates:
visited_invocations.add(invocation_id)

# Make sure the request factory has the most up to date pipeline
# run to avoid hydration calls
request_factory.pipeline_run = pipeline_run
try:
step_run_request = request_factory.create_request(
invocation_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.

from contextlib import ExitStack as does_not_raise
from typing import Callable, Tuple

import pytest
Expand Down Expand Up @@ -122,6 +123,11 @@ def mixed_with_unannotated_returns() -> (
)


@step
def step_with_string_input(input_: str) -> None:
    """No-op step that consumes a single string input.

    Args:
        input_: Arbitrary string; intentionally ignored.
    """


@pytest.mark.parametrize(
"step",
[
Expand Down Expand Up @@ -362,3 +368,17 @@ def _inner(pass_to_step: str = ""):
assert p2_step_subs["date"] == "step_level"
assert p1_step_subs["funny_name"] == "pipeline_level"
assert p2_step_subs["funny_name"] == "step_level"


def test_dynamically_named_artifacts_in_downstream_steps(
    clean_client: "Client",
):
    """Test that dynamically named artifacts can be used in downstream steps."""

    @pipeline(enable_cache=False)
    def _pipeline_with_dynamic_artifact(ret: str):
        # The upstream step produces an artifact whose name is resolved
        # dynamically; the downstream step must still receive it as input.
        dynamic_output = dynamic_single_string_standard()
        step_with_string_input(dynamic_output)

    with does_not_raise():
        _pipeline_with_dynamic_artifact("output_1")
13 changes: 13 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
from collections import defaultdict
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional
from uuid import uuid4
Expand Down Expand Up @@ -416,6 +417,12 @@ def sample_pipeline_run(
sample_workspace_model: WorkspaceResponse,
) -> PipelineRunResponse:
"""Return sample pipeline run view for testing purposes."""
now = datetime.utcnow()
substitutions = {
"date": now.strftime("%Y_%m_%d"),
"time": now.strftime("%H_%M_%S_%f"),
}

return PipelineRunResponse(
id=uuid4(),
name="sample_run_name",
Expand All @@ -430,6 +437,7 @@ def sample_pipeline_run(
workspace=sample_workspace_model,
config=PipelineConfiguration(name="aria_pipeline"),
is_templatable=False,
steps_substitutions=defaultdict(lambda: substitutions.copy()),
),
resources=PipelineRunResponseResources(tags=[]),
)
Expand Down Expand Up @@ -543,10 +551,15 @@ def f(
spec = StepSpec.model_validate(
{"source": "module.step_class", "upstream_steps": []}
)
now = datetime.utcnow()
config = StepConfiguration.model_validate(
{
"name": step_name,
"outputs": outputs or {},
"substitutions": {
"date": now.strftime("%Y_%m_%d"),
"time": now.strftime("%H_%M_%S_%f"),
},
}
)
return StepRunResponse(
Expand Down
20 changes: 11 additions & 9 deletions tests/unit/orchestrators/test_input_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.

from uuid import uuid4

import pytest

from zenml.config.step_configurations import Step
from zenml.enums import StepRunInputArtifactType
from zenml.exceptions import InputResolutionError
from zenml.models import Page, PipelineRunResponse
from zenml.models import Page
from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
from zenml.models.v2.core.step_run import StepRunInputResponse
from zenml.orchestrators import input_utils
Expand All @@ -29,6 +28,7 @@ def test_input_resolution(
mocker,
sample_artifact_version_model: ArtifactVersionResponse,
create_step_run,
sample_pipeline_run,
):
"""Tests that input resolution works if the correct models exist in the
zen store."""
Expand Down Expand Up @@ -60,7 +60,7 @@ def test_input_resolution(
)

input_artifacts, parent_ids = input_utils.resolve_step_inputs(
step=step, pipeline_run=PipelineRunResponse(id=uuid4(), name="foo")
step=step, pipeline_run=sample_pipeline_run
)
assert input_artifacts == {
"input_name": StepRunInputResponse(
Expand All @@ -71,7 +71,7 @@ def test_input_resolution(
assert parent_ids == [step_run.id]


def test_input_resolution_with_missing_step_run(mocker):
def test_input_resolution_with_missing_step_run(mocker, sample_pipeline_run):
"""Tests that input resolution fails if the upstream step run is missing."""
mocker.patch(
"zenml.zen_stores.sql_zen_store.SqlZenStore.list_run_steps",
Expand All @@ -97,11 +97,13 @@ def test_input_resolution_with_missing_step_run(mocker):

with pytest.raises(InputResolutionError):
input_utils.resolve_step_inputs(
step=step, pipeline_run=PipelineRunResponse(id=uuid4(), name="foo")
step=step, pipeline_run=sample_pipeline_run
)


def test_input_resolution_with_missing_artifact(mocker, create_step_run):
def test_input_resolution_with_missing_artifact(
mocker, create_step_run, sample_pipeline_run
):
"""Tests that input resolution fails if the upstream step run output
artifact is missing."""
step_run = create_step_run(
Expand Down Expand Up @@ -132,12 +134,12 @@ def test_input_resolution_with_missing_artifact(mocker, create_step_run):

with pytest.raises(InputResolutionError):
input_utils.resolve_step_inputs(
step=step, pipeline_run=PipelineRunResponse(id=uuid4(), name="foo")
step=step, pipeline_run=sample_pipeline_run
)


def test_input_resolution_fetches_all_run_steps(
mocker, sample_artifact_version_model, create_step_run
mocker, sample_artifact_version_model, create_step_run, sample_pipeline_run
):
"""Tests that input resolution fetches all step runs of the pipeline run."""
step_run = create_step_run(
Expand Down Expand Up @@ -178,7 +180,7 @@ def test_input_resolution_fetches_all_run_steps(
)

input_utils.resolve_step_inputs(
step=step, pipeline_run=PipelineRunResponse(id=uuid4(), name="foo")
step=step, pipeline_run=sample_pipeline_run
)

# `resolve_step_inputs(...)` depaginates the run steps so we fetch all
Expand Down
Loading