-
Notifications
You must be signed in to change notification settings - Fork 516
Extend notebook source replacement code to other objects apart from ZenML steps #2919
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6cc2495
ce94ad6
6824680
d0d120a
cbf9553
a855ac5
a89e32c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,15 +22,8 @@ | |
from distutils.sysconfig import get_python_lib | ||
from pathlib import Path, PurePath | ||
from types import BuiltinFunctionType, FunctionType, ModuleType | ||
from typing import ( | ||
Any, | ||
Callable, | ||
Iterator, | ||
Optional, | ||
Type, | ||
Union, | ||
cast, | ||
) | ||
from typing import Any, Callable, Dict, Iterator, Optional, Type, Union, cast | ||
from uuid import UUID | ||
|
||
from zenml.config.source import ( | ||
CodeRepositorySource, | ||
|
@@ -69,6 +62,8 @@ | |
) | ||
|
||
_SHARED_TEMPDIR: Optional[str] = None | ||
_resolved_notebook_sources: Dict[str, str] = {} | ||
_notebook_modules: Dict[str, UUID] = {} | ||
|
||
|
||
def load(source: Union[Source, str]) -> Any: | ||
|
@@ -237,13 +232,23 @@ def resolve( | |
source_type = SourceType.UNKNOWN | ||
elif source_type == SourceType.NOTEBOOK: | ||
source = NotebookSource( | ||
module=module_name, | ||
module="__main__", | ||
attribute=attribute_name, | ||
type=source_type, | ||
) | ||
# Private attributes are ignored by pydantic if passed in the __init__ | ||
# method, so we set this afterwards | ||
source._cell_code = notebook_utils.load_notebook_cell_code(obj) | ||
|
||
if module_name in _notebook_modules: | ||
source.replacement_module = module_name | ||
source.artifact_store_id = _notebook_modules[module_name] | ||
elif cell_code := notebook_utils.load_notebook_cell_code(obj): | ||
replacement_module = ( | ||
notebook_utils.compute_cell_replacement_module_name( | ||
cell_code=cell_code | ||
) | ||
) | ||
source.replacement_module = replacement_module | ||
_resolved_notebook_sources[source.import_path] = cell_code | ||
|
||
return source | ||
|
||
return Source( | ||
|
@@ -387,6 +392,9 @@ def get_source_type(module: ModuleType) -> SourceType: | |
Returns: | ||
The source type. | ||
""" | ||
if module.__name__ in _notebook_modules: | ||
return SourceType.NOTEBOOK | ||
|
||
try: | ||
file_path = inspect.getfile(module) | ||
except (TypeError, OSError): | ||
|
@@ -582,33 +590,61 @@ def _try_to_load_notebook_source(source: NotebookSource) -> Any: | |
|
||
Raises: | ||
RuntimeError: If the source can't be loaded. | ||
FileNotFoundError: If the file containing the notebook cell code can't | ||
be found. | ||
|
||
Returns: | ||
The loaded object. | ||
""" | ||
if not source.code_path or not source.replacement_module: | ||
if not source.replacement_module: | ||
raise RuntimeError( | ||
f"Failed to load {source.import_path}. This object was defined in " | ||
"a notebook and you're trying to load it outside of a notebook. " | ||
"This is currently only enabled for ZenML steps." | ||
"This is currently only enabled for ZenML steps and materializers. " | ||
"To enable this for your custom classes or functions, use the " | ||
"`zenml.utils.notebook_utils.enable_notebook_code_extraction` " | ||
"decorator." | ||
) | ||
|
||
extract_dir = _get_shared_temp_dir() | ||
file_path = os.path.join(extract_dir, f"{source.replacement_module}.py") | ||
file_name = f"{source.replacement_module}.py" | ||
file_path = os.path.join(extract_dir, file_name) | ||
|
||
if not os.path.exists(file_path): | ||
from zenml.client import Client | ||
from zenml.utils import code_utils | ||
|
||
artifact_store = Client().active_stack.artifact_store | ||
|
||
if ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the notebook cell content is stored in a different artifact store than the active one we are failing here. Similar to what we did with the pipeline artifacts, could we not try to initialize this other artifact store and use it to load it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1, seems like the old bug we had There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would use active only in case |
||
source.artifact_store_id | ||
and source.artifact_store_id != artifact_store.id | ||
): | ||
raise RuntimeError( | ||
"Notebook cell code not stored in active artifact store." | ||
) | ||
|
||
logger.info( | ||
"Downloading notebook cell content from `%s` to load `%s`.", | ||
source.code_path, | ||
"Downloading notebook cell content to load `%s`.", | ||
source.import_path, | ||
) | ||
|
||
code_utils.download_and_extract_code( | ||
code_path=source.code_path, extract_dir=extract_dir | ||
) | ||
try: | ||
code_utils.download_notebook_code( | ||
artifact_store=artifact_store, | ||
file_name=file_name, | ||
download_path=file_path, | ||
) | ||
except FileNotFoundError: | ||
if not source.artifact_store_id: | ||
raise FileNotFoundError( | ||
"Unable to find notebook code file. This might be because " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With the error message above, this one is now a bit misleading as it will fail if the file is stored in a different artifact store. |
||
"the file is stored in a different artifact store." | ||
) | ||
|
||
raise | ||
else: | ||
_notebook_modules[source.replacement_module] = artifact_store.id | ||
try: | ||
module = _load_module( | ||
module_name=source.replacement_module, import_root=extract_dir | ||
|
@@ -734,3 +770,13 @@ def validate_source_class( | |
return True | ||
else: | ||
return False | ||
|
||
|
||
def get_resolved_notebook_sources() -> Dict[str, str]: | ||
"""Get all notebook sources that were resolved in this process. | ||
|
||
Returns: | ||
Dictionary mapping the import path of notebook sources to the code | ||
of their notebook cell. | ||
""" | ||
return _resolved_notebook_sources.copy() |
Uh oh!
There was an error while loading. Please reload this page.