Skip to content

Commit 694aeda

Browse files
authored
Utilizing cascading tags for cached step runs (#3655)
* cascading tags for cached step runs * moved tags out * fixed the tests * another pipeline
1 parent fdca351 commit 694aeda

File tree

5 files changed

+141
-10
lines changed

5 files changed

+141
-10
lines changed

src/zenml/orchestrators/step_launcher.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,10 @@ def _bypass() -> None:
292292
artifacts=step_run.outputs,
293293
model_version=model_version,
294294
)
295+
step_run_utils.cascade_tags_for_output_artifacts(
296+
artifacts=step_run.outputs,
297+
tags=pipeline_run.config.tags,
298+
)
295299

296300
except: # noqa: E722
297301
logger.error(f"Pipeline run `{pipeline_run.name}` failed.")

src/zenml/orchestrators/step_run_utils.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
# permissions and limitations under the License.
1414
"""Utilities for creating step runs."""
1515

16-
from typing import Dict, List, Optional, Set, Tuple
16+
from typing import Dict, List, Optional, Set, Tuple, Union
1717

18+
from zenml import Tag, add_tags
1819
from zenml.client import Client
1920
from zenml.config.step_configurations import Step
2021
from zenml.constants import CODE_HASH_PARAMETER_NAME, TEXT_FIELD_MAX_LENGTH
@@ -333,6 +334,11 @@ def create_cached_step_runs(
333334
model_version=model_version,
334335
)
335336

337+
cascade_tags_for_output_artifacts(
338+
artifacts=step_run.outputs,
339+
tags=pipeline_run.config.tags,
340+
)
341+
336342
logger.info("Using cached version of step `%s`.", invocation_id)
337343
cached_invocations.add(invocation_id)
338344

@@ -382,3 +388,26 @@ def link_output_artifacts_to_model_version(
382388
artifact_version=output_artifact,
383389
model_version=model_version,
384390
)
391+
392+
393+
def cascade_tags_for_output_artifacts(
394+
artifacts: Dict[str, List[ArtifactVersionResponse]],
395+
tags: Optional[List[Union[str, Tag]]] = None,
396+
) -> None:
397+
"""Tag the outputs of a step run.
398+
399+
Args:
400+
artifacts: The step output artifacts.
401+
tags: The tags to add to the artifacts.
402+
"""
403+
if tags is None:
404+
return
405+
406+
cascade_tags = [t for t in tags if isinstance(t, Tag) and t.cascade]
407+
408+
for output_artifacts in artifacts.values():
409+
for output_artifact in output_artifacts:
410+
add_tags(
411+
tags=[t.name for t in cascade_tags],
412+
artifact_version_id=output_artifact.id,
413+
)

src/zenml/utils/tag_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -348,10 +348,10 @@ def add_tags(
348348
if isinstance(tag, Tag):
349349
tag_model = client.get_tag(tag.name)
350350

351-
if tag.exclusive != tag_model.exclusive:
351+
if bool(tag.exclusive) != tag_model.exclusive:
352352
raise ValueError(
353-
f"The tag `{tag.name}` is an "
354-
f"{'exclusive' if tag_model.exclusive else 'non-exclusive'} "
353+
f"The tag `{tag.name}` is "
354+
f"{'an exclusive' if tag_model.exclusive else 'a non-exclusive'} "
355355
"tag. Please update it before attaching it to a resource."
356356
)
357357
if tag.cascade is not None:

src/zenml/zen_stores/sql_zen_store.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11354,13 +11354,10 @@ def _attach_tags_to_resources(
1135411354
except EntityExistsError:
1135511355
if isinstance(tag, tag_utils.Tag):
1135611356
tag_schema = self._get_tag_schema(tag.name, session)
11357-
if (
11358-
tag.exclusive is not None
11359-
and tag.exclusive != tag_schema.exclusive
11360-
):
11357+
if bool(tag.exclusive) != tag_schema.exclusive:
1136111358
raise ValueError(
11362-
f"Tag `{tag_schema.name}` has been defined as a "
11363-
f"{'exclusive' if tag_schema.exclusive else 'non-exclusive'} "
11359+
f"Tag `{tag_schema.name}` has been defined as "
11360+
f"{'an exclusive' if tag_schema.exclusive else 'a non-exclusive'} "
1136411361
"tag. Please update it before attaching it to resources."
1136511362
)
1136611363
else:

tests/integration/functional/utils/test_tag_utils.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@
1313
# permissions and limitations under the License.
1414

1515

16+
import os
1617
from typing import Annotated, Tuple
1718

1819
import pytest
1920

2021
from zenml import ArtifactConfig, Tag, add_tags, pipeline, remove_tags, step
22+
from zenml.constants import ENV_ZENML_PREVENT_CLIENT_SIDE_CACHING
23+
from zenml.enums import ExecutionStatus
2124

2225

2326
@step
@@ -139,3 +142,101 @@ def test_tag_utils(clean_client):
139142
clean_client.update_tag(
140143
tag_name_or_id=non_exclusive_tag.id, exclusive=True
141144
)
145+
146+
147+
@pipeline(
148+
tags=[Tag(name="cascade_tag", cascade=True)],
149+
enable_cache=False,
150+
)
151+
def pipeline_with_cascade_tag():
152+
"""Pipeline definition to test the tag utils."""
153+
_ = step_single_output()
154+
155+
156+
def test_cascade_tags_for_output_artifacts_of_cached_pipeline_run(
157+
clean_client,
158+
):
159+
"""Test that the cascade tags are added to the output artifacts of a cached step."""
160+
# Run the pipeline once without caching
161+
pipeline_with_cascade_tag()
162+
163+
pipeline_runs = clean_client.list_pipeline_runs(sort_by="created")
164+
assert len(pipeline_runs.items) == 1
165+
assert (
166+
pipeline_runs.items[0].steps["step_single_output"].status
167+
== ExecutionStatus.COMPLETED
168+
)
169+
assert "cascade_tag" in [
170+
t.name
171+
for t in pipeline_runs.items[0]
172+
.steps["step_single_output"]
173+
.outputs["single"][0]
174+
.tags
175+
]
176+
177+
# Run it once again with caching
178+
pipeline_with_cascade_tag.configure(enable_cache=True)
179+
pipeline_with_cascade_tag()
180+
pipeline_runs = clean_client.list_pipeline_runs(sort_by="created")
181+
assert len(pipeline_runs.items) == 2
182+
assert (
183+
pipeline_runs.items[1].steps["step_single_output"].status
184+
== ExecutionStatus.CACHED
185+
)
186+
187+
# Run it once again with caching and a new cascade tag
188+
pipeline_with_cascade_tag.configure(
189+
tags=[Tag(name="second_cascade_tag", cascade=True)]
190+
)
191+
pipeline_with_cascade_tag()
192+
pipeline_runs = clean_client.list_pipeline_runs(sort_by="created")
193+
assert len(pipeline_runs.items) == 3
194+
assert (
195+
pipeline_runs.items[2].steps["step_single_output"].status
196+
== ExecutionStatus.CACHED
197+
)
198+
199+
assert "second_cascade_tag" in [
200+
t.name
201+
for t in pipeline_runs.items[0]
202+
.steps["step_single_output"]
203+
.outputs["single"][0]
204+
.tags
205+
]
206+
assert "second_cascade_tag" in [
207+
t.name
208+
for t in pipeline_runs.items[2]
209+
.steps["step_single_output"]
210+
.outputs["single"][0]
211+
.tags
212+
]
213+
214+
# Run it once again with caching (preventing client side caching) and a new cascade tag
215+
os.environ[ENV_ZENML_PREVENT_CLIENT_SIDE_CACHING] = "true"
216+
pipeline_with_cascade_tag.configure(
217+
tags=[Tag(name="third_cascade_tag", cascade=True)]
218+
)
219+
pipeline_with_cascade_tag()
220+
221+
pipeline_runs = clean_client.list_pipeline_runs(sort_by="created")
222+
assert len(pipeline_runs.items) == 4
223+
assert (
224+
pipeline_runs.items[3].steps["step_single_output"].status
225+
== ExecutionStatus.CACHED
226+
)
227+
228+
assert "third_cascade_tag" in [
229+
t.name
230+
for t in pipeline_runs.items[0]
231+
.steps["step_single_output"]
232+
.outputs["single"][0]
233+
.tags
234+
]
235+
assert "third_cascade_tag" in [
236+
t.name
237+
for t in pipeline_runs.items[3]
238+
.steps["step_single_output"]
239+
.outputs["single"][0]
240+
.tags
241+
]
242+
os.environ.pop(ENV_ZENML_PREVENT_CLIENT_SIDE_CACHING, None)

0 commit comments

Comments
 (0)