Skip to content

Commit 8b8a6af

Browse files
authored
Bugfix for artifacts coming from a different artifact store (#2928)
* first draft of the artifact store solution * fixes in error message * review changes * renaming the context manage * review comments * fixing the utils * fixing the test fixture * removed unused import * added a small test checking the register calls
1 parent 35813b1 commit 8b8a6af

File tree

5 files changed

+163
-15
lines changed

5 files changed

+163
-15
lines changed

src/zenml/artifacts/utils.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -710,12 +710,19 @@ def _get_artifact_store_from_response_or_from_active_stack(
710710
"BaseArtifactStore",
711711
StackComponent.from_model(artifact_store_model),
712712
)
713-
except (KeyError, ImportError):
714-
logger.warning(
715-
"Unable to restore artifact store while trying to load artifact "
716-
"`%s`. If this artifact is stored in a remote artifact store, "
717-
"this might lead to issues when trying to load the artifact.",
718-
artifact.id,
713+
except KeyError:
714+
raise RuntimeError(
715+
"Unable to fetch the artifact store with id: "
716+
f"'{artifact.artifact_store_id}'. Check whether the artifact "
717+
"store still exists and you have the right permissions to "
718+
"access it."
719+
)
720+
except ImportError:
721+
raise RuntimeError(
722+
"Unable to load the implementation of the artifact store with"
723+
f"id: '{artifact.artifact_store_id}'. Please make sure that "
724+
"the environment that you are loading this artifact from "
725+
"has the right dependencies."
719726
)
720727
return Client().active_stack.artifact_store
721728

src/zenml/orchestrators/step_runner.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -445,14 +445,24 @@ def _load_input_artifact(
445445
# we use the datatype of the stored artifact
446446
data_type = source_utils.load(artifact.data_type)
447447

448+
from zenml.orchestrators.utils import (
449+
register_artifact_store_filesystem,
450+
)
451+
448452
materializer_class: Type[BaseMaterializer] = (
449453
source_utils.load_and_validate_class(
450454
artifact.materializer, expected_class=BaseMaterializer
451455
)
452456
)
453-
materializer: BaseMaterializer = materializer_class(artifact.uri)
454-
materializer.validate_type_compatibility(data_type)
455-
return materializer.load(data_type=data_type)
457+
458+
with register_artifact_store_filesystem(
459+
artifact.artifact_store_id
460+
) as target_artifact_store:
461+
materializer: BaseMaterializer = materializer_class(
462+
uri=artifact.uri, artifact_store=target_artifact_store
463+
)
464+
materializer.validate_type_compatibility(data_type)
465+
return materializer.load(data_type=data_type)
456466

457467
def _validate_outputs(
458468
self,

src/zenml/orchestrators/utils.py

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
# permissions and limitations under the License.
1414
"""Utility functions for the orchestrator."""
1515

16+
import os
1617
import random
17-
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
18+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, cast
1819
from uuid import UUID
1920

2021
from zenml.client import Client
@@ -25,17 +26,21 @@
2526
from zenml.constants import (
2627
ENV_ZENML_ACTIVE_STACK_ID,
2728
ENV_ZENML_ACTIVE_WORKSPACE_ID,
29+
ENV_ZENML_SERVER,
2830
ENV_ZENML_STORE_PREFIX,
2931
PIPELINE_API_TOKEN_EXPIRES_MINUTES,
3032
)
31-
from zenml.enums import StoreType
33+
from zenml.enums import StackComponentType, StoreType
3234
from zenml.exceptions import StepContextError
35+
from zenml.logger import get_logger
3336
from zenml.model.utils import link_artifact_config_to_model
3437
from zenml.models.v2.core.step_run import StepRunRequest
3538
from zenml.new.steps.step_context import get_step_context
39+
from zenml.stack import StackComponent
3640
from zenml.utils.string_utils import format_name_template
3741

3842
if TYPE_CHECKING:
43+
from zenml.artifact_stores.base_artifact_store import BaseArtifactStore
3944
from zenml.artifacts.external_artifact_config import (
4045
ExternalArtifactConfiguration,
4146
)
@@ -302,3 +307,101 @@ def _get_model_versions_from_artifacts(
302307
else:
303308
break
304309
return models
310+
311+
312+
class register_artifact_store_filesystem:
313+
"""Context manager for the artifact_store/filesystem_registry dependency.
314+
315+
Even though it is rare, sometimes we bump into cases where we are trying to
316+
load artifacts that belong to an artifact store which is different from
317+
the active artifact store.
318+
319+
In cases like this, we will try to instantiate the target artifact store
320+
by creating the corresponding artifact store Python object, which ends up
321+
registering the right filesystem in the filesystem registry.
322+
323+
The problem is, the keys in the filesystem registry are schemes (such as
324+
"s3://" or "gcs://"). If we have two artifact stores with the same set of
325+
supported schemes, we might end up overwriting the filesystem that belongs
326+
to the active artifact store (and its authentication). That's why we have
327+
to re-instantiate the active artifact store again, so the correct filesystem
328+
will be restored.
329+
"""
330+
331+
def __init__(self, target_artifact_store_id: Optional[UUID]) -> None:
332+
"""Initialization of the context manager.
333+
334+
Args:
335+
target_artifact_store_id: the ID of the artifact store to load.
336+
"""
337+
self.target_artifact_store_id = target_artifact_store_id
338+
339+
def __enter__(self) -> "BaseArtifactStore":
340+
"""Entering the context manager.
341+
342+
It creates an instance of the target artifact store to register the
343+
correct filesystem in the registry.
344+
345+
Returns:
346+
The target artifact store object.
347+
348+
Raises:
349+
RuntimeError: If the target artifact store can not be fetched or
350+
initiated due to missing dependencies.
351+
"""
352+
try:
353+
if self.target_artifact_store_id is not None:
354+
if (
355+
Client().active_stack.artifact_store.id
356+
!= self.target_artifact_store_id
357+
):
358+
get_logger(__name__).debug(
359+
f"Trying to use the artifact store with ID:"
360+
f"'{self.target_artifact_store_id}'"
361+
f"which is currently not the active artifact store."
362+
)
363+
364+
artifact_store_model_response = Client().get_stack_component(
365+
component_type=StackComponentType.ARTIFACT_STORE,
366+
name_id_or_prefix=self.target_artifact_store_id,
367+
)
368+
return cast(
369+
"BaseArtifactStore",
370+
StackComponent.from_model(artifact_store_model_response),
371+
)
372+
else:
373+
return Client().active_stack.artifact_store
374+
375+
except KeyError:
376+
raise RuntimeError(
377+
"Unable to fetch the artifact store with id: "
378+
f"'{self.target_artifact_store_id}'. Check whether the "
379+
"artifact store still exists and you have the right "
380+
"permissions to access it."
381+
)
382+
except ImportError:
383+
raise RuntimeError(
384+
"Unable to load the implementation of the artifact store with"
385+
f"id: '{self.target_artifact_store_id}'. Please make sure that "
386+
"the environment that you are loading this artifact from "
387+
"has the right dependencies."
388+
)
389+
390+
def __exit__(
391+
self,
392+
exc_type: Optional[Any],
393+
exc_value: Optional[Any],
394+
traceback: Optional[Any],
395+
) -> None:
396+
"""Set it back to the original state.
397+
398+
Args:
399+
exc_type: The class of the exception
400+
exc_value: The instance of the exception
401+
traceback: The traceback of the exception
402+
"""
403+
if ENV_ZENML_SERVER not in os.environ:
404+
# As we exit the handler, we have to re-register the filesystem
405+
# that belongs to the active artifact store as it may have been
406+
# overwritten.
407+
Client().active_stack.artifact_store._register()

tests/unit/artifacts/test_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import os
1515
import shutil
1616
import tempfile
17-
from uuid import uuid4
1817

1918
import numpy as np
2019
import pytest
@@ -33,7 +32,7 @@
3332

3433

3534
@pytest.fixture
36-
def model_artifact(mocker):
35+
def model_artifact(mocker, clean_client: "Client"):
3736
return mocker.Mock(
3837
spec=ArtifactVersionResponse,
3938
id="123",
@@ -45,7 +44,7 @@ def model_artifact(mocker):
4544
uri="gs://my-bucket/model.joblib",
4645
data_type="path/to/model/class",
4746
materializer="path/to/materializer/class",
48-
artifact_store_id=uuid4(),
47+
artifact_store_id=clean_client.active_stack.artifact_store.id,
4948
)
5049

5150

tests/unit/orchestrators/test_utils.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,13 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
1212
# or implied. See the License for the specific language governing
1313
# permissions and limitations under the License.
14-
from zenml.orchestrators.utils import is_setting_enabled
14+
from unittest import mock
15+
16+
from zenml.enums import StackComponentType
17+
from zenml.orchestrators.utils import (
18+
is_setting_enabled,
19+
register_artifact_store_filesystem,
20+
)
1521

1622

1723
def test_is_setting_enabled():
@@ -97,3 +103,26 @@ def test_is_setting_enabled():
97103
)
98104
is False
99105
)
106+
107+
108+
def test_register_artifact_store_filesystem(clean_client):
109+
"""Tests if a new filesystem gets registered with the context manager."""
110+
with mock.patch(
111+
"zenml.artifact_stores.base_artifact_store.BaseArtifactStore._register"
112+
) as register:
113+
# Calling the active artifact store will call register once
114+
_ = clean_client.active_stack.artifact_store
115+
assert register.call_count == 1
116+
117+
new_artifact_store_model = clean_client.create_stack_component(
118+
name="new_local_artifact_store",
119+
flavor="local",
120+
component_type=StackComponentType.ARTIFACT_STORE,
121+
configuration={"path": ""},
122+
)
123+
with register_artifact_store_filesystem(new_artifact_store_model.id):
124+
# Entering the context manager will register the new filesystem
125+
assert register.call_count == 2
126+
127+
# Exiting the context manager will set it back by calling register again
128+
assert register.call_count == 3

0 commit comments

Comments
 (0)