Skip to content

Do not tag model via Model class on creation #3098

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions docs/book/how-to/handle-data-artifacts/tagging.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,19 @@ Note that [ZenML Pro](https://zenml.io/pro) users can tag artifacts directly in

Just like artifacts, you can also tag your models to organize them semantically. Here's how to use tags with models in the ZenML Python SDK and CLI (or in the [ZenML Pro Dashboard directly](https://zenml.io/pro)).

When creating a model using the `Model` object, you can specify tags as key-value pairs that will be attached to the model upon creation:
When creating a model version using the `Model` object, you can specify tags as key-value pairs that will be attached to the model version upon creation.
{% hint style="warning" %}
During pipeline run a model can be also implicitly created (if not exists), in such cases it will not get the `tags` from the `Model` class.
You can manipulate the model tags using SDK (see below) or the ZenML Pro UI.
{% endhint %}

```python
from zenml.models import Model

# Define tags to be added to the model
# Define tags to be added to the model version
tags = ["experiment", "v1", "classification-task"]

# Create a model with tags
# Create a model version with tags
model = Model(
name="iris_classifier",
version="1.0.0",
Expand Down
1 change: 0 additions & 1 deletion src/zenml/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,6 @@ def _get_or_create_model(self) -> "ModelResponse":
limitations=self.limitations,
trade_offs=self.trade_offs,
ethics=self.ethics,
tags=self.tags,
user=zenml_client.active_user.id,
workspace=zenml_client.active_workspace.id,
save_models_to_registry=self.save_models_to_registry,
Expand Down
1 change: 1 addition & 0 deletions src/zenml/models/v2/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class ModelRequest(WorkspaceScopedRequest):
)
tags: Optional[List[str]] = Field(
title="Tags associated with the model",
default=None,
)
save_models_to_registry: bool = Field(
title="Whether to save all ModelArtifacts to Model Registry",
Expand Down
23 changes: 11 additions & 12 deletions tests/integration/functional/model/test_model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def test_model_create_model_and_version(self):
assert mv.name == str(mv.number)
assert mv.model.name == mdl_name
assert {t.name for t in mv.tags} == {"tag1", "tag2"}
assert {t.name for t in mv.model.tags} == {"tag1", "tag2"}
assert len(mv.model.tags) == 0

def test_create_model_version_makes_proper_tagging(self):
"""Test if model versions get unique tags."""
Expand All @@ -208,14 +208,14 @@ def test_create_model_version_makes_proper_tagging(self):
assert mv.name == str(mv.number)
assert mv.model.name == mdl_name
assert {t.name for t in mv.tags} == {"tag1", "tag2"}
assert {t.name for t in mv.model.tags} == {"tag1", "tag2"}
assert len(mv.model.tags) == 0

mv = Model(name=mdl_name, tags=["tag3", "tag4"])
mv = mv._get_or_create_model_version()
assert mv.name == str(mv.number)
assert mv.model.name == mdl_name
assert {t.name for t in mv.tags} == {"tag3", "tag4"}
assert {t.name for t in mv.model.tags} == {"tag1", "tag2"}
assert len(mv.model.tags) == 0

def test_model_fetch_model_and_version_by_number(self):
"""Test model and model version retrieval by exact version number."""
Expand Down Expand Up @@ -301,15 +301,17 @@ def test_tags_properly_created(self):

# run 2 times to first create, next get
for _ in range(2):
model = mv._get_or_create_model()
model_version = mv._get_or_create_model_version()

assert len(model.tags) == 2
assert {t.name for t in model.tags} == {
assert len(model_version.tags) == 2
assert {t.name for t in model_version.tags} == {
green_tag,
new_tag,
}
assert {
t.color for t in model.tags if t.name == green_tag
t.color
for t in model_version.tags
if t.name == green_tag
} == {"green"}

def test_tags_properly_updated(self):
Expand All @@ -324,10 +326,8 @@ def test_tags_properly_updated(self):

client.update_model(model_id, add_tags=["tag1", "tag2"])
model = mv._get_or_create_model()
assert len(model.tags) == 4
assert len(model.tags) == 2
assert {t.name for t in model.tags} == {
"foo",
"bar",
"tag1",
"tag2",
}
Expand All @@ -346,8 +346,7 @@ def test_tags_properly_updated(self):

client.update_model(model_id, remove_tags=["tag1", "tag2"])
model = mv._get_or_create_model()
assert len(model.tags) == 2
assert {t.name for t in model.tags} == {"foo", "bar"}
assert len(model.tags) == 0

client.update_model_version(
model_id, "1", remove_tags=["tag3", "tag4"]
Expand Down
Loading