Better input artifacts typing #3099

Merged
merged 67 commits into from
Nov 7, 2024
9bdb5f1
multi-versioned outputs
avishniakov Oct 10, 2024
4dd7e90
more artifact save types
avishniakov Oct 10, 2024
eb7f691
fix for the llm template
avishniakov Oct 10, 2024
49f0f24
Auto-update of LLM Finetuning template
actions-user Oct 10, 2024
5d81b5d
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
avishniakov Oct 10, 2024
ac9a7bd
Auto-update of Starter template
actions-user Oct 10, 2024
9607df5
Auto-update of E2E template
actions-user Oct 10, 2024
fc19dc3
fix tests notation
avishniakov Oct 10, 2024
9f04da1
Merge branch 'feature/PRD-663-multiple-output-versions-for-a-step' of…
avishniakov Oct 10, 2024
feb3f53
Refactor artifact saving logic to use save types
avishniakov Oct 17, 2024
5557837
Refactor artifact saving logic to use save types
avishniakov Oct 17, 2024
eac4aea
Remove unneeded TODO
avishniakov Oct 17, 2024
e62d8eb
Refactor artifact saving logic to use list instead of set for output …
avishniakov Oct 17, 2024
f4bf134
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
avishniakov Oct 17, 2024
9a8cf20
Refactor artifact saving logic to use outputs instead of saved_artifa…
avishniakov Oct 17, 2024
ecb9ff0
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
avishniakov Oct 17, 2024
5683081
Auto-update of LLM Finetuning template
actions-user Oct 17, 2024
f02fee1
Refactor artifact saving logic to use 'external' as the default save …
avishniakov Oct 18, 2024
0a2dfa1
Refactor artifact saving logic to use outputs instead of saved_artifa…
avishniakov Oct 18, 2024
77b6a79
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
avishniakov Oct 18, 2024
27099b4
Refactor artifact saving logic to define input types. WIP
avishniakov Oct 18, 2024
a704f95
Refactor StepRunRequestFactory to correctly assign output artifacts
avishniakov Oct 18, 2024
e991fd5
Merge branch 'feature/PRD-663-multiple-output-versions-for-a-step' of…
avishniakov Oct 18, 2024
e7d7a62
Merge branch 'feature/PRD-663-multiple-output-versions-for-a-step' in…
avishniakov Oct 18, 2024
9ce83ee
Refactor StepRunRequestFactory to correctly assign output artifacts
avishniakov Oct 18, 2024
8930510
Refactor StepRunRequestFactory to correctly assign input artifact types
avishniakov Oct 18, 2024
fc81ecf
Refactor StepRunRequestFactory to correctly assign input artifact typ…
avishniakov Oct 18, 2024
dec188c
Add test case
avishniakov Oct 18, 2024
d036010
bugs, bugs
avishniakov Oct 18, 2024
a2d735c
mypy
avishniakov Oct 18, 2024
980fac6
Merge branch 'feature/PRD-663-multiple-output-versions-for-a-step' in…
avishniakov Oct 18, 2024
53f5657
mypy
avishniakov Oct 18, 2024
56a37a8
Remove unused arg
schustmi Oct 30, 2024
b5c8ad0
Improve input resolution
schustmi Oct 30, 2024
3d2443b
Rename save type
schustmi Oct 30, 2024
aef08fc
Only apply artifact config to step outputs
schustmi Oct 30, 2024
239122e
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
schustmi Oct 30, 2024
eaf2958
Fix alembic order
schustmi Oct 30, 2024
fc0c141
Fix DB migration
schustmi Oct 30, 2024
54527dc
Merge branch 'feature/PRD-663-multiple-output-versions-for-a-step' in…
schustmi Oct 30, 2024
1b246d2
Formatting
schustmi Oct 30, 2024
65b1443
Migrate default input type
schustmi Oct 30, 2024
ce76a9c
Detect input type server side
schustmi Oct 30, 2024
8b51d47
Fix input type detection
schustmi Oct 30, 2024
5f1c7fc
Don't fail for cached step runs
schustmi Oct 30, 2024
1ecefd8
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
schustmi Oct 31, 2024
9a39cd4
Fix alembic order
schustmi Oct 31, 2024
48b28f2
Merge branch 'feature/PRD-663-multiple-output-versions-for-a-step' in…
schustmi Oct 31, 2024
3193c20
Refactor artifact saving in cacheable_multiple_versioned_producer
avishniakov Nov 5, 2024
3797a07
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
avishniakov Nov 5, 2024
e49ff02
Merge branch 'feature/PRD-663-multiple-output-versions-for-a-step' in…
avishniakov Nov 6, 2024
2986c7a
remove duplicated logic from `load_artifact`
avishniakov Nov 6, 2024
5b6bf57
add `StepRunInputResponse`
avishniakov Nov 6, 2024
f9c4455
add `StepRunInputResponse`
avishniakov Nov 6, 2024
e3e9320
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
avishniakov Nov 6, 2024
5cf1089
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
avishniakov Nov 6, 2024
ad7b3da
Merge branch 'feature/PRD-663-multiple-output-versions-for-a-step' in…
avishniakov Nov 6, 2024
58e75f2
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
schustmi Nov 6, 2024
d0e2633
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
avishniakov Nov 6, 2024
aeafc20
Merge branch 'feature/PRD-663-multiple-output-versions-for-a-step' in…
avishniakov Nov 6, 2024
3e3c32d
lint
avishniakov Nov 6, 2024
ca27fad
Merge branch 'develop' into feature/PRD-668-better-input-artifacts-ty…
avishniakov Nov 7, 2024
4b44d39
fix typing issues
avishniakov Nov 7, 2024
4b6c922
Refactor resolve_step_inputs function to use StepRunInputResponse for…
avishniakov Nov 7, 2024
e59793c
Merge branch 'develop' into feature/PRD-668-better-input-artifacts-ty…
avishniakov Nov 7, 2024
f23772c
fix issues caught in testing
avishniakov Nov 7, 2024
612e281
fix migration for mysql
avishniakov Nov 7, 2024
11 changes: 0 additions & 11 deletions src/zenml/artifacts/utils.py
@@ -387,17 +387,6 @@ def load_artifact(
         The loaded artifact.
     """
     artifact = Client().get_artifact_version(name_or_id, version)
-    try:
-        step_run = get_step_context().step_run
-        client = Client()
-        client.zen_store.update_run_step(
-            step_run_id=step_run.id,
-            step_run_update=StepRunUpdate(
-                loaded_artifact_versions={artifact.name: artifact.id}
-            ),
-        )
-    except RuntimeError:
-        pass  # Cannot link to step run if called outside of a step
     return load_artifact_from_response(artifact)

18 changes: 17 additions & 1 deletion src/zenml/client.py
@@ -188,6 +188,7 @@
     WorkspaceResponse,
     WorkspaceUpdate,
 )
+from zenml.models.v2.core.step_run import StepRunUpdate
 from zenml.services.service import ServiceConfig
 from zenml.services.service_status import ServiceState
 from zenml.services.service_type import ServiceType
@@ -4164,20 +4165,35 @@ def get_artifact_version(
         Returns:
             The artifact version.
         """
+        from zenml import get_step_context
+
         if cll := client_lazy_loader(
             method_name="get_artifact_version",
             name_id_or_prefix=name_id_or_prefix,
             version=version,
             hydrate=hydrate,
         ):
             return cll  # type: ignore[return-value]
-        return self._get_entity_version_by_id_or_name_or_prefix(
+
+        artifact = self._get_entity_version_by_id_or_name_or_prefix(
             get_method=self.zen_store.get_artifact_version,
             list_method=self.list_artifact_versions,
             name_id_or_prefix=name_id_or_prefix,
             version=version,
             hydrate=hydrate,
         )
+        try:
+            step_run = get_step_context().step_run
+            client = Client()
+            client.zen_store.update_run_step(
+                step_run_id=step_run.id,
+                step_run_update=StepRunUpdate(
+                    loaded_artifact_versions={artifact.name: artifact.id}
+                ),
+            )
+        except RuntimeError:
+            pass  # Cannot link to step run if called outside of a step
+        return artifact

     def list_artifact_versions(
         self,

6 changes: 5 additions & 1 deletion src/zenml/enums.py
@@ -34,8 +34,12 @@ class ArtifactType(StrEnum):
 class StepRunInputArtifactType(StrEnum):
     """All possible types of a step run input artifact."""

-    DEFAULT = "default"  # input argument that is the output of a previous step
+    STEP_OUTPUT = (
+        "step_output"  # input argument that is the output of a previous step
+    )
     MANUAL = "manual"  # manually loaded via `zenml.load_artifact()`
+    EXTERNAL = "external"  # loaded via `ExternalArtifact(value=...)`
+    LAZY_LOADED = "lazy"  # loaded via various lazy methods


 class ArtifactSaveType(StrEnum):

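Reviewer note: the enum hunk renames `DEFAULT = "default"` to `STEP_OUTPUT = "step_output"` and adds two members. The sketch below mirrors the four values from the diff in a plain `str`-based enum and shows one way stored strings could be mapped back to members. The class name `InputArtifactType`, the helper `parse_input_type`, and the handling of the legacy `"default"` value are assumptions for illustration; the PR itself handles old rows in a DB migration (see the "Migrate default input type" commit).

```python
from enum import Enum


class InputArtifactType(str, Enum):
    """Stand-in for `StepRunInputArtifactType`; values copied from the diff."""

    STEP_OUTPUT = "step_output"
    MANUAL = "manual"
    EXTERNAL = "external"
    LAZY_LOADED = "lazy"


def parse_input_type(raw: str) -> InputArtifactType:
    """Map a stored string to a member.

    Assumption: a legacy "default" value is treated as STEP_OUTPUT.
    """
    if raw == "default":
        return InputArtifactType.STEP_OUTPUT
    return InputArtifactType(raw)
```

Because the members subclass `str`, comparisons such as `InputArtifactType.EXTERNAL == "external"` hold, which keeps filtering by raw string value straightforward.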
35 changes: 27 additions & 8 deletions src/zenml/models/v2/core/step_run.py
@@ -21,7 +21,7 @@

 from zenml.config.step_configurations import StepConfiguration, StepSpec
 from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
-from zenml.enums import ExecutionStatus
+from zenml.enums import ExecutionStatus, StepRunInputArtifactType
 from zenml.metadata.metadata_types import MetadataType
 from zenml.models.v2.base.scoped import (
     WorkspaceScopedFilter,
@@ -31,18 +31,37 @@
     WorkspaceScopedResponseMetadata,
     WorkspaceScopedResponseResources,
 )
+from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
 from zenml.models.v2.core.model_version import ModelVersionResponse

 if TYPE_CHECKING:
     from sqlalchemy.sql.elements import ColumnElement

-    from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
     from zenml.models.v2.core.logs import (
         LogsRequest,
         LogsResponse,
     )


+class StepRunInputResponse(ArtifactVersionResponse):
+    """Response model for step run inputs."""
+
+    input_type: StepRunInputArtifactType
+
+    def get_hydrated_version(self) -> "StepRunInputResponse":
+        """Get the hydrated version of this step run input.
+
+        Returns:
+            an instance of the same entity with the metadata field attached.
+        """
+        from zenml.client import Client
+
+        return StepRunInputResponse(
+            input_type=self.input_type,
+            **Client().zen_store.get_artifact_version(self.id).model_dump(),
+        )
+
+
 # ------------------ Request Model ------------------

@@ -160,11 +179,11 @@ class StepRunResponseBody(WorkspaceScopedResponseBody):
         title="The end time of the step run.",
         default=None,
     )
-    inputs: Dict[str, "ArtifactVersionResponse"] = Field(
+    inputs: Dict[str, StepRunInputResponse] = Field(
         title="The input artifact versions of the step run.",
         default_factory=dict,
     )
-    outputs: Dict[str, List["ArtifactVersionResponse"]] = Field(
+    outputs: Dict[str, List[ArtifactVersionResponse]] = Field(
         title="The output artifact versions of the step run.",
         default_factory=dict,
     )
@@ -268,7 +287,7 @@ def get_hydrated_version(self) -> "StepRunResponse":

     # Helper properties
     @property
-    def input(self) -> "ArtifactVersionResponse":
+    def input(self) -> ArtifactVersionResponse:
         """Returns the input artifact that was used to run this step.

         Returns:
@@ -287,7 +306,7 @@ def input(self) -> "ArtifactVersionResponse":
         return next(iter(self.inputs.values()))

     @property
-    def output(self) -> "ArtifactVersionResponse":
+    def output(self) -> ArtifactVersionResponse:
         """Returns the output artifact that was written by this step.

         Returns:
@@ -319,7 +338,7 @@ def status(self) -> ExecutionStatus:
         return self.get_body().status

     @property
-    def inputs(self) -> Dict[str, "ArtifactVersionResponse"]:
+    def inputs(self) -> Dict[str, StepRunInputResponse]:
         """The `inputs` property.

         Returns:
@@ -328,7 +347,7 @@ def inputs(self) -> Dict[str, "ArtifactVersionResponse"]:
         return self.get_body().inputs

     @property
-    def outputs(self) -> Dict[str, List["ArtifactVersionResponse"]]:
+    def outputs(self) -> Dict[str, List[ArtifactVersionResponse]]:
         """The `outputs` property.

         Returns:

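Reviewer note: `StepRunInputResponse` in `step_run.py` subclasses `ArtifactVersionResponse`, adds one `input_type` field, and is built by unpacking a dump of an existing response. The sketch below shows the same "subclass plus re-wrap" shape with standard-library dataclasses standing in for the Pydantic models, and `dataclasses.asdict` standing in for `model_dump()`. The names `ArtifactVersion`, `StepRunInput`, and `as_step_input` are hypothetical.

```python
from dataclasses import asdict, dataclass


@dataclass
class ArtifactVersion:
    """Hypothetical stand-in for `ArtifactVersionResponse`."""

    id: str
    name: str
    version: str


@dataclass
class StepRunInput(ArtifactVersion):
    """Same fields as the base, plus how the input reached the step."""

    input_type: str


def as_step_input(artifact: ArtifactVersion, input_type: str) -> StepRunInput:
    """Re-wrap an artifact as a typed step input by copying its fields."""
    return StepRunInput(input_type=input_type, **asdict(artifact))
```

Since the wrapper is still an instance of the base class, code written against the base type keeps working; the trade-off in this sketch is that every wrap copies the fields rather than sharing the original object.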
30 changes: 21 additions & 9 deletions src/zenml/orchestrators/input_utils.py
@@ -18,18 +18,19 @@

 from zenml.client import Client
 from zenml.config.step_configurations import Step
-from zenml.enums import ArtifactSaveType
+from zenml.enums import ArtifactSaveType, StepRunInputArtifactType
 from zenml.exceptions import InputResolutionError
 from zenml.utils import pagination_utils

 if TYPE_CHECKING:
-    from zenml.models import ArtifactVersionResponse, PipelineRunResponse
+    from zenml.models import PipelineRunResponse
+    from zenml.models.v2.core.step_run import StepRunInputResponse


 def resolve_step_inputs(
     step: "Step",
     pipeline_run: "PipelineRunResponse",
-) -> Tuple[Dict[str, "ArtifactVersionResponse"], List[UUID]]:
+) -> Tuple[Dict[str, "StepRunInputResponse"], List[UUID]]:
     """Resolves inputs for the current step.

     Args:
@@ -47,6 +48,7 @@ def resolve_step_inputs(
         the current step.
     """
     from zenml.models import ArtifactVersionResponse
+    from zenml.models.v2.core.step_run import StepRunInputResponse

     current_run_steps = {
         run_step.name: run_step
@@ -55,7 +57,7 @@ def resolve_step_inputs(
         )
     }

-    input_artifacts: Dict[str, "ArtifactVersionResponse"] = {}
+    input_artifacts: Dict[str, StepRunInputResponse] = {}
     for name, input_ in step.spec.inputs.items():
         try:
             step_run = current_run_steps[input_.step_name]
@@ -90,15 +92,19 @@ def resolve_step_inputs(
                 f"`{input_.step_name}`."
             )

-        input_artifacts[name] = step_outputs[0]
+        input_artifacts[name] = StepRunInputResponse(
+            input_type=StepRunInputArtifactType.STEP_OUTPUT,
+            **step_outputs[0].model_dump(),
+        )

     for (
         name,
         external_artifact,
     ) in step.config.external_input_artifacts.items():
         artifact_version_id = external_artifact.get_artifact_version_id()
-        input_artifacts[name] = Client().get_artifact_version(
-            artifact_version_id
+        input_artifacts[name] = StepRunInputResponse(
+            input_type=StepRunInputArtifactType.EXTERNAL,
+            **Client().get_artifact_version(artifact_version_id).model_dump(),
         )

     for name, config_ in step.config.model_artifacts_or_metadata.items():
@@ -129,7 +135,10 @@ def resolve_step_inputs(
                 config_.artifact_name, config_.artifact_version
             ):
                 if config_.metadata_name is None:
-                    input_artifacts[name] = artifact_
+                    input_artifacts[name] = StepRunInputResponse(
+                        input_type=StepRunInputArtifactType.LAZY_LOADED,
+                        **artifact_.model_dump(),
+                    )
                 elif config_.metadata_name:
                     # metadata values should go directly in parameters, as primitive types
                     try:
@@ -156,7 +165,10 @@ def resolve_step_inputs(
     for name, cll_ in step.config.client_lazy_loaders.items():
         value_ = cll_.evaluate()
         if isinstance(value_, ArtifactVersionResponse):
-            input_artifacts[name] = value_
+            input_artifacts[name] = StepRunInputResponse(
+                input_type=StepRunInputArtifactType.LAZY_LOADED,
+                **value_.model_dump(),
+            )
         else:
             step.config.parameters[name] = value_

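Reviewer note: after this change `resolve_step_inputs` tags each resolved input with the source it came from: upstream step outputs, external artifacts, then the lazily loaded sources. A simplified sketch of that tagging order is shown below. The function name `tag_inputs`, the tuple representation, and the flat `name -> artifact id` dictionaries are assumptions for illustration and ignore the metadata and parameter branches of the real function.

```python
from typing import Dict, Tuple


def tag_inputs(
    step_outputs: Dict[str, str],
    external: Dict[str, str],
    lazy_loaded: Dict[str, str],
) -> Dict[str, Tuple[str, str]]:
    """Merge inputs from three sources, recording where each one came from.

    Sources are processed in a fixed order, so on a name collision the
    later source overwrites the earlier one (a property of this sketch).
    """
    resolved: Dict[str, Tuple[str, str]] = {}
    for name, artifact_id in step_outputs.items():
        resolved[name] = (artifact_id, "step_output")
    for name, artifact_id in external.items():
        resolved[name] = (artifact_id, "external")
    for name, artifact_id in lazy_loaded.items():
        resolved[name] = (artifact_id, "lazy")
    return resolved
```

Note that nothing here assigns `"manual"`: per the enum comment, that type is reserved for artifacts loaded via `zenml.load_artifact()` during step execution rather than inputs resolved up front.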
4 changes: 2 additions & 2 deletions src/zenml/orchestrators/step_launcher.py
@@ -33,13 +33,13 @@
 from zenml.logger import get_logger
 from zenml.logging import step_logging
 from zenml.models import (
-    ArtifactVersionResponse,
     LogsRequest,
     PipelineDeploymentResponse,
     PipelineRunRequest,
     PipelineRunResponse,
     StepRunResponse,
 )
+from zenml.models.v2.core.step_run import StepRunInputResponse
 from zenml.orchestrators import output_utils, publish_utils, step_run_utils
 from zenml.orchestrators import utils as orchestrator_utils
 from zenml.orchestrators.step_runner import StepRunner
@@ -442,7 +442,7 @@ def _run_step_without_step_operator(
         pipeline_run: PipelineRunResponse,
         step_run: StepRunResponse,
         step_run_info: StepRunInfo,
-        input_artifacts: Dict[str, ArtifactVersionResponse],
+        input_artifacts: Dict[str, StepRunInputResponse],
         output_artifact_uris: Dict[str, str],
         last_retry: bool,
     ) -> None:

1 change: 1 addition & 0 deletions src/zenml/orchestrators/step_run_utils.py
@@ -104,6 +104,7 @@ def populate_request(self, request: StepRunRequest) -> None:
             input_name: artifact.id
             for input_name, artifact in input_artifacts.items()
         }
+
         request.inputs = input_artifact_ids
         request.parent_step_ids = parent_step_ids

5 changes: 3 additions & 2 deletions src/zenml/orchestrators/step_runner.py
@@ -42,6 +42,7 @@
 from zenml.logger import get_logger
 from zenml.logging.step_logging import StepLogsStorageContext, redirected
 from zenml.materializers.base_materializer import BaseMaterializer
+from zenml.models.v2.core.step_run import StepRunInputResponse
 from zenml.orchestrators.publish_utils import (
     publish_step_run_metadata,
     publish_successful_step_run,
@@ -100,7 +101,7 @@ def run(
         self,
         pipeline_run: "PipelineRunResponse",
         step_run: "StepRunResponse",
-        input_artifacts: Dict[str, "ArtifactVersionResponse"],
+        input_artifacts: Dict[str, StepRunInputResponse],
         output_artifact_uris: Dict[str, str],
         step_run_info: StepRunInfo,
     ) -> None:
@@ -306,7 +307,7 @@ def _parse_inputs(
         self,
         args: List[str],
         annotations: Dict[str, Any],
-        input_artifacts: Dict[str, "ArtifactVersionResponse"],
+        input_artifacts: Dict[str, StepRunInputResponse],
     ) -> Dict[str, Any]:
         """Parses the inputs for a step entrypoint function.

5 changes: 3 additions & 2 deletions src/zenml/steps/step_context.py
@@ -35,11 +35,12 @@
     from zenml.metadata.metadata_types import MetadataType
     from zenml.model.model import Model
     from zenml.models import (
-        ArtifactVersionResponse,
         PipelineResponse,
         PipelineRunResponse,
         StepRunResponse,
     )
+    from zenml.models.v2.core.step_run import StepRunInputResponse
+

 logger = get_logger(__name__)

@@ -191,7 +192,7 @@ def model(self) -> "Model":
         return self.model_version.to_model_class()

     @property
-    def inputs(self) -> Dict[str, "ArtifactVersionResponse"]:
+    def inputs(self) -> Dict[str, "StepRunInputResponse"]:
         """Returns the input artifacts of the current step.

         Returns:

Original file line number Diff line number Diff line change
@@ -26,9 +26,10 @@ def upgrade() -> None:
     op.execute("""
         UPDATE artifact_version
         SET save_type = (
-            SELECT step_run_output_artifact.type
+            SELECT max(step_run_output_artifact.type)
             FROM step_run_output_artifact
             WHERE step_run_output_artifact.artifact_id = artifact_version.id
+            GROUP BY artifact_id
         )
     """)
     op.execute("""
@@ -71,9 +72,10 @@ def downgrade() -> None:
     op.execute("""
         UPDATE step_run_output_artifact
         SET type = (
-            SELECT artifact_version.save_type
+            SELECT max(artifact_version.save_type)
             FROM artifact_version
             WHERE step_run_output_artifact.artifact_id = artifact_version.id
+            GROUP BY artifact_id
         )
     """)
     op.execute("""

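Reviewer note: the migration change wraps the selected column in `max(...)` and adds `GROUP BY artifact_id`, so the correlated subquery yields exactly one value per target row. MySQL rejects a scalar subquery that returns more than one row, which is presumably what the "fix migration for mysql" commit addresses. The sketch below replays the `upgrade()` statement against an in-memory SQLite database. The reduced table definitions and the sample rows are assumptions for illustration, and SQLite is only a convenient stand-in that does not reproduce the MySQL error.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    -- Reduced, hypothetical versions of the two tables touched by the migration.
    CREATE TABLE artifact_version (id INTEGER PRIMARY KEY, save_type TEXT);
    CREATE TABLE step_run_output_artifact (artifact_id INTEGER, type TEXT);
    INSERT INTO artifact_version (id) VALUES (1), (2);
    -- Artifact 1 is linked to two step runs with different types.
    INSERT INTO step_run_output_artifact VALUES
        (1, 'manual'), (1, 'step_output'), (2, 'step_output');
    """
)

# Same statement as in upgrade(): the aggregate collapses duplicates to one value.
conn.execute(
    """
    UPDATE artifact_version
    SET save_type = (
        SELECT max(step_run_output_artifact.type)
        FROM step_run_output_artifact
        WHERE step_run_output_artifact.artifact_id = artifact_version.id
        GROUP BY artifact_id
    )
    """
)

rows = conn.execute(
    "SELECT id, save_type FROM artifact_version ORDER BY id"
).fetchall()
conn.close()
```

With text values, `max` picks the lexicographically greatest string, so when an artifact has conflicting types the winner is determined by string ordering rather than by recency.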