diff --git a/docs/book/how-to/run-remote-pipelines-from-notebooks/README.md b/docs/book/how-to/run-remote-pipelines-from-notebooks/README.md new file mode 100644 index 00000000000..af76f69ecec --- /dev/null +++ b/docs/book/how-to/run-remote-pipelines-from-notebooks/README.md @@ -0,0 +1,14 @@ +--- +description: Use Jupyter Notebooks to run remote steps or pipelines +--- + +# 📔 Run remote pipelines from notebooks + +ZenML steps and pipelines can be defined in a Jupyter notebook and executed remotely. To do so, ZenML extracts the code from your notebook cells and runs it as Python modules inside the Docker containers that execute your pipeline steps remotely. For this to work, the notebook cells in which you define your steps need to meet certain conditions. + +Learn more in the following sections: + +
Define steps in notebook cells (define-steps-in-notebook-cells.md)
Configure the notebook path
+ + +
ZenML Scarf
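As a rough sketch of what this enables (the step and pipeline below are illustrative placeholders, not code from this change): with a remote orchestrator or step operator in the active stack, defining and calling a pipeline from notebook cells triggers remote execution, and the extracted cell code runs inside the step containers.

```python
from zenml import pipeline, step


@step
def train() -> str:
    # Placeholder training logic
    return "trained-model"


@pipeline
def notebook_pipeline():
    train()


# With a remote stack active, calling the pipeline runs its steps remotely.
notebook_pipeline()
```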
diff --git a/docs/book/how-to/run-remote-pipelines-from-notebooks/define-steps-in-notebook-cells.md b/docs/book/how-to/run-remote-pipelines-from-notebooks/define-steps-in-notebook-cells.md new file mode 100644 index 00000000000..1cceabe9a9d --- /dev/null +++ b/docs/book/how-to/run-remote-pipelines-from-notebooks/define-steps-in-notebook-cells.md @@ -0,0 +1,10 @@ + +# Define steps in notebook cells + +If you want to run ZenML steps defined in notebook cells remotely (either with a remote [orchestrator](../../component-guide/orchestrators/orchestrators.md) or [step operator](../../component-guide/step-operators/step-operators.md)), the cells defining your steps must meet the following conditions (see the example sketch below): +- The cell may only contain Python code; it must not include Jupyter magic commands or shell commands starting with `%` or `!`. +- The cell **must not** call code defined in other notebook cells. Functions or classes imported from Python files are allowed. +- The cell **must not** rely on imports made in previous cells. The cell must perform all the imports it needs itself, including ZenML imports like `from zenml import step`. + + +
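For illustration, here is a minimal sketch of a cell that satisfies these conditions (the step name and body are placeholders):

```python
# This cell is self-contained: it performs its own imports and contains only
# Python code (no `%` or `!` commands), so ZenML can extract it and run the
# step remotely.
from zenml import step


@step
def load_data() -> dict:
    """Return a small, illustrative dataset."""
    return {"features": [1, 2, 3], "labels": [0, 1, 0]}
```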
ZenML Scarf
\ No newline at end of file diff --git a/docs/book/toc.md b/docs/book/toc.md index 5393e4c327f..35565a0230e 100644 --- a/docs/book/toc.md +++ b/docs/book/toc.md @@ -177,6 +177,8 @@ * [🔌 Connect to a server](how-to/connecting-to-zenml/README.md) * [Connect in with your User (interactive)](how-to/connecting-to-zenml/connect-in-with-your-user-interactive.md) * [Connect with a Service Account](how-to/connecting-to-zenml/connect-with-a-service-account.md) +* [📔 Run remote pipelines from notebooks](how-to/run-remote-pipelines-from-notebooks/README.md) + * [Define steps in notebook cells](how-to/run-remote-pipelines-from-notebooks/define-steps-in-notebook-cells.md) * [🔐 Interact with secrets](how-to/interact-with-secrets.md) * [🐞 Debug and solve issues](how-to/debug-and-solve-issues.md) diff --git a/docs/mocked_libs.json b/docs/mocked_libs.json index 9744212fede..714fd0c87d8 100644 --- a/docs/mocked_libs.json +++ b/docs/mocked_libs.json @@ -241,5 +241,10 @@ "databricks.sdk", "databricks.sdk.service.compute", "databricks.sdk.service.jobs", - "databricks.sdk.service.serving" + "databricks.sdk.service.serving", + "IPython", + "IPython.core", + "IPython.core.display", + "IPython.core.display_functions", + "ipywidgets" ] diff --git a/src/zenml/config/source.py b/src/zenml/config/source.py index 0e55f35ebbd..4e74d000d64 100644 --- a/src/zenml/config/source.py +++ b/src/zenml/config/source.py @@ -42,6 +42,7 @@ class SourceType(Enum): INTERNAL = "internal" DISTRIBUTION_PACKAGE = "distribution_package" CODE_REPOSITORY = "code_repository" + NOTEBOOK = "notebook" UNKNOWN = "unknown" @@ -229,6 +230,63 @@ def _validate_type(cls, value: SourceType) -> SourceType: return value +class NotebookSource(Source): + """Source representing an object defined in a notebook. + + Attributes: + code_path: Path where the notebook cell code for this source is + uploaded. + replacement_module: Name of the module from which this source should + be loaded in case the code is not running in a notebook. + """ + + code_path: Optional[str] = None + replacement_module: Optional[str] = None + type: SourceType = SourceType.NOTEBOOK + + # Private attribute that is used to store the code but should not be + # serialized + _cell_code: Optional[str] = None + + @field_validator("type") + @classmethod + def _validate_type(cls, value: SourceType) -> SourceType: + """Validate the source type. + + Args: + value: The source type. + + Raises: + ValueError: If the source type is not `NOTEBOOK`. + + Returns: + The source type. + """ + if value != SourceType.NOTEBOOK: + raise ValueError("Invalid source type.") + + return value + + @field_validator("module") + @classmethod + def _validate_module(cls, value: str) -> str: + """Validate the module. + + Args: + value: The module. + + Raises: + ValueError: If the module is not `__main__`. + + Returns: + The module. + """ + if value != "__main__": + raise ValueError("Invalid module for notebook source.") + + return value + + def convert_source(source: Any) -> Any: """Converts an old source string to a source object. 
diff --git a/src/zenml/entrypoints/base_entrypoint_configuration.py b/src/zenml/entrypoints/base_entrypoint_configuration.py index cb24d646fdb..e35ead1cedb 100644 --- a/src/zenml/entrypoints/base_entrypoint_configuration.py +++ b/src/zenml/entrypoints/base_entrypoint_configuration.py @@ -15,7 +15,6 @@ import argparse import os -import shutil import sys from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, NoReturn, Set @@ -27,9 +26,13 @@ ENV_ZENML_REQUIRES_CODE_DOWNLOAD, handle_bool_env_var, ) -from zenml.io import fileio from zenml.logger import get_logger -from zenml.utils import code_repository_utils, source_utils, uuid_utils +from zenml.utils import ( + code_repository_utils, + code_utils, + source_utils, + uuid_utils, +) if TYPE_CHECKING: from zenml.models import CodeReferenceResponse, PipelineDeploymentResponse @@ -261,10 +264,6 @@ def download_code_from_artifact_store(self, code_path: str) -> None: Args: code_path: Path where the code is stored. - - Raises: - RuntimeError: If the code is stored in an artifact store which is - not active. """ logger.info( "Downloading code from artifact store path `%s`.", code_path @@ -272,19 +271,14 @@ def download_code_from_artifact_store(self, code_path: str) -> None: # Do not remove this line, we need to instantiate the artifact store to # register the filesystem needed for the file download - artifact_store = Client().active_stack.artifact_store - - if not code_path.startswith(artifact_store.path): - raise RuntimeError("Code stored in different artifact store.") + _ = Client().active_stack.artifact_store extract_dir = os.path.abspath("code") os.makedirs(extract_dir) - download_path = os.path.basename(code_path) - fileio.copy(code_path, download_path) - - shutil.unpack_archive(filename=download_path, extract_dir=extract_dir) - os.remove(download_path) + code_utils.download_and_extract_code( + code_path=code_path, extract_dir=extract_dir + ) source_utils.set_custom_source_root(extract_dir) sys.path.insert(0, extract_dir) diff --git a/src/zenml/environment.py b/src/zenml/environment.py index f25c2c641d8..56bcc89dcb7 100644 --- a/src/zenml/environment.py +++ b/src/zenml/environment.py @@ -15,7 +15,6 @@ import os import platform -from importlib.util import find_spec from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, cast @@ -259,15 +258,17 @@ def in_notebook() -> bool: if Environment.in_google_colab(): return True - if find_spec("IPython") is not None: - from IPython import get_ipython + try: + ipython = get_ipython() # type: ignore[name-defined] + except NameError: + return False - if get_ipython().__class__.__name__ in [ - "TerminalInteractiveShell", - "ZMQInteractiveShell", - "DatabricksShell", - ]: - return True + if ipython.__class__.__name__ in [ + "TerminalInteractiveShell", + "ZMQInteractiveShell", + "DatabricksShell", + ]: + return True return False @staticmethod diff --git a/src/zenml/image_builders/build_context.py b/src/zenml/image_builders/build_context.py index 6c0146d8a60..e8284cfb446 100644 --- a/src/zenml/image_builders/build_context.py +++ b/src/zenml/image_builders/build_context.py @@ -45,9 +45,9 @@ def __init__( given, a file called `.dockerignore` in the build context root directory will be used instead if it exists. 
""" + super().__init__() self._root = root self._dockerignore_file = dockerignore_file - self._extra_files: Dict[str, str] = {} @property def dockerignore_file(self) -> Optional[str]: diff --git a/src/zenml/integrations/azure/step_operators/azureml_step_operator.py b/src/zenml/integrations/azure/step_operators/azureml_step_operator.py index c5882d1ee77..3e63dacd8fd 100644 --- a/src/zenml/integrations/azure/step_operators/azureml_step_operator.py +++ b/src/zenml/integrations/azure/step_operators/azureml_step_operator.py @@ -250,6 +250,9 @@ def launch( "apt_packages", "user", "source_files", + "allow_including_files_in_images", + "allow_download_from_code_repository", + "allow_download_from_artifact_store", ] docker_settings = info.config.docker_settings ignored_docker_fields = docker_settings.model_fields_set.intersection( diff --git a/src/zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py b/src/zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py index e2206ba1ca2..e91420956d7 100644 --- a/src/zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py +++ b/src/zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py @@ -93,6 +93,15 @@ def is_local(self) -> bool: """ return False + @property + def is_remote(self) -> bool: + """Checks if this stack component is running remotely. + + Returns: + True if this config is for a remote component, False otherwise. + """ + return True + class DatabricksOrchestratorFlavor(BaseOrchestratorFlavor): """Databricks orchestrator flavor.""" diff --git a/src/zenml/new/pipelines/build_utils.py b/src/zenml/new/pipelines/build_utils.py index ea9dbc7326f..eacbd1d07da 100644 --- a/src/zenml/new/pipelines/build_utils.py +++ b/src/zenml/new/pipelines/build_utils.py @@ -14,9 +14,7 @@ """Pipeline build utilities.""" import hashlib -import os import platform -import tempfile from typing import ( TYPE_CHECKING, Dict, @@ -29,7 +27,6 @@ import zenml from zenml.client import Client from zenml.code_repositories import BaseCodeRepository -from zenml.io import fileio from zenml.logger import get_logger from zenml.models import ( BuildItem, @@ -40,9 +37,8 @@ PipelineDeploymentBase, StackResponse, ) -from zenml.new.pipelines.code_archive import CodeArchive from zenml.stack import Stack -from zenml.utils import source_utils, string_utils +from zenml.utils import source_utils from zenml.utils.pipeline_docker_image_builder import ( PipelineDockerImageBuilder, ) @@ -487,8 +483,7 @@ def verify_local_repository_context( raise RuntimeError( "The `DockerSettings` of the pipeline or one of its " "steps specify that code should be downloaded from a " - "code repository " - "(`source_files=['download_from_code_repository']`), but " + "code repository, but " "there is no code repository active at your current source " f"root `{source_utils.get_source_root()}`." ) @@ -496,8 +491,7 @@ def verify_local_repository_context( raise RuntimeError( "The `DockerSettings` of the pipeline or one of its " "steps specify that code should be downloaded from a " - "code repository " - "(`source_files=['download_from_code_repository']`), but " + "code repository, but " "the code repository active at your current source root " f"`{source_utils.get_source_root()}` has uncommitted " "changes." 
@@ -506,8 +500,7 @@ def verify_local_repository_context( raise RuntimeError( "The `DockerSettings` of the pipeline or one of its " "steps specify that code should be downloaded from a " - "code repository " - "(`source_files=['download_from_code_repository']`), but " + "code repository, but " "the code repository active at your current source root " f"`{source_utils.get_source_root()}` has unpushed " "changes." @@ -578,7 +571,7 @@ def verify_custom_build( raise RuntimeError( "The `DockerSettings` of the pipeline or one of its " "steps specify that code should be included in the Docker " - "image (`source_files=['include']`), but the build you " + "image, but the build you " "specified requires code download. Either update your " "`DockerSettings` or specify a different build and try " "again." @@ -591,8 +584,7 @@ def verify_custom_build( raise RuntimeError( "The `DockerSettings` of the pipeline or one of its " "steps specify that code should be downloaded from a " - "code repository " - "(`source_files=['download_from_code_repository']`), but " + "code repository but " "there is no code repository active at your current source " f"root `{source_utils.get_source_root()}`." ) @@ -704,10 +696,10 @@ def should_upload_code( Whether the current code should be uploaded for the deployment. """ if not build: - # No build means all the code is getting executed locally, which means - # we don't need to download any code - # TODO: This does not apply to e.g. Databricks, figure out a solution - # here + # No build means we don't need to download code into a Docker container + # for step execution. In other remote orchestrators that don't use + # Docker containers but instead use e.g. Wheels to run, the code should + # already be included. return False for step in deployment.step_configurations.values(): @@ -724,56 +716,3 @@ def should_upload_code( return True return False - - -def upload_code_if_necessary() -> str: - """Upload code to the artifact store if necessary. - - This function computes a hash of the code to be uploaded, and if an archive - with the same hash already exists it will not re-upload but instead return - the path to the existing archive. - - Returns: - The path where to archived code is uploaded. - """ - logger.info("Archiving code...") - - code_archive = CodeArchive(root=source_utils.get_source_root()) - artifact_store = Client().active_stack.artifact_store - - with tempfile.NamedTemporaryFile( - mode="w+b", delete=False, suffix=".tar.gz" - ) as f: - code_archive.write_archive(f) - - hash_ = hashlib.sha1() # nosec - - while True: - data = f.read(64 * 1024) - if not data: - break - hash_.update(data) - - filename = f"{hash_.hexdigest()}.tar.gz" - upload_dir = os.path.join(artifact_store.path, "code_uploads") - fileio.makedirs(upload_dir) - upload_path = os.path.join(upload_dir, filename) - - if not fileio.exists(upload_path): - archive_size = string_utils.get_human_readable_filesize( - os.path.getsize(f.name) - ) - logger.info( - "Uploading code to `%s` (Size: %s).", upload_path, archive_size - ) - fileio.copy(f.name, upload_path) - logger.info("Code upload finished.") - else: - logger.info( - "Code already exists in artifact store, skipping upload." 
- ) - - if os.path.exists(f.name): - os.remove(f.name) - - return upload_path diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index 7621026750d..b411870bfe6 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -73,6 +73,7 @@ create_placeholder_run, deploy_pipeline, prepare_model_versions, + upload_notebook_cell_code_if_necessary, ) from zenml.stack import Stack from zenml.steps import BaseStep @@ -82,6 +83,7 @@ from zenml.steps.step_invocation import StepInvocation from zenml.utils import ( code_repository_utils, + code_utils, dashboard_utils, dict_utils, pydantic_utils, @@ -668,6 +670,9 @@ def _run( stack = Client().active_stack stack.validate() + upload_notebook_cell_code_if_necessary( + deployment=deployment, stack=stack + ) prepare_model_versions(deployment) @@ -715,7 +720,11 @@ def _run( build=build_model, code_reference=code_reference, ): - code_path = build_utils.upload_code_if_necessary() + code_archive = code_utils.CodeArchive( + root=source_utils.get_source_root() + ) + logger.info("Archiving pipeline code...") + code_path = code_utils.upload_code_if_necessary(code_archive) deployment_request = PipelineDeploymentRequest( user=Client().active_user.id, diff --git a/src/zenml/new/pipelines/run_utils.py b/src/zenml/new/pipelines/run_utils.py index babcbe26867..fc6f7c736ce 100644 --- a/src/zenml/new/pipelines/run_utils.py +++ b/src/zenml/new/pipelines/run_utils.py @@ -1,5 +1,6 @@ """Utility functions for running pipelines.""" +import hashlib import time from collections import defaultdict from datetime import datetime @@ -9,6 +10,7 @@ from zenml import constants from zenml.client import Client from zenml.config.pipeline_run_configuration import PipelineRunConfiguration +from zenml.config.source import SourceType from zenml.config.step_configurations import StepConfigurationUpdate from zenml.enums import ExecutionStatus, ModelStages from zenml.logger import get_logger @@ -23,7 +25,7 @@ from zenml.new.pipelines.model_utils import NewModelRequest from zenml.orchestrators.utils import get_run_name from zenml.stack import Flavor, Stack -from zenml.utils import cloud_utils +from zenml.utils import cloud_utils, code_utils, notebook_utils from zenml.zen_stores.base_zen_store import BaseZenStore if TYPE_CHECKING: @@ -361,3 +363,67 @@ def validate_run_config_is_runnable_from_server( raise ValueError( "Can't set DockerSettings when running pipeline via Rest API." ) + + +def upload_notebook_cell_code_if_necessary( + deployment: "PipelineDeploymentBase", stack: "Stack" +) -> None: + """Upload notebook cell code if necessary. + + This function checks if any of the steps of the pipeline that will be + executed in a different process are defined in a notebook. If that is the + case, it will extract that notebook cell code into python files and upload + an archive of all the necessary files to the artifact store. + + Args: + deployment: The deployment. + stack: The stack on which the deployment will happen. + + Raises: + RuntimeError: If the code for one of the steps that will run out of + process cannot be extracted into a python file. 
+ """ + code_archive = code_utils.CodeArchive(root=None) + should_upload = False + sources_that_require_upload = [] + + for step in deployment.step_configurations.values(): + source = step.spec.source + + if source.type == SourceType.NOTEBOOK: + if ( + stack.orchestrator.flavor != "local" + or step.config.step_operator + ): + should_upload = True + cell_code = getattr(step.spec.source, "_cell_code", None) + + # Code does not run in-process, which means we need to + # extract the step code into a python file + if not cell_code: + raise RuntimeError( + f"Unable to run step {step.config.name}. This step is " + "defined in a notebook and you're trying to run it " + "in a remote environment, but ZenML was not able to " + "detect the step code in the notebook. To fix " + "this error, define your step in a python file instead " + "of a notebook." + ) + + notebook_utils.warn_about_notebook_cell_magic_commands( + cell_code=cell_code + ) + + code_hash = hashlib.sha1(cell_code.encode()).hexdigest() # nosec + module_name = f"extracted_notebook_code_{code_hash}" + file_name = f"{module_name}.py" + code_archive.add_file(source=cell_code, destination=file_name) + + setattr(step.spec.source, "replacement_module", module_name) + sources_that_require_upload.append(source) + + if should_upload: + logger.info("Archiving notebook code...") + code_path = code_utils.upload_code_if_necessary(code_archive) + for source in sources_that_require_upload: + setattr(source, "code_path", code_path) diff --git a/src/zenml/steps/base_step.py b/src/zenml/steps/base_step.py index c6f32993e78..13e9930e6fc 100644 --- a/src/zenml/steps/base_step.py +++ b/src/zenml/steps/base_step.py @@ -54,6 +54,7 @@ ) from zenml.utils import ( dict_utils, + notebook_utils, pydantic_utils, settings_utils, source_code_utils, @@ -249,6 +250,8 @@ def __init__( ) self._verify_and_apply_init_params(*args, **kwargs) + notebook_utils.try_to_save_notebook_cell_code(self.source_object) + @abstractmethod def entrypoint(self, *args: Any, **kwargs: Any) -> Any: """Abstract method for core step logic. diff --git a/src/zenml/utils/code_utils.py b/src/zenml/utils/code_utils.py new file mode 100644 index 00000000000..fff033d044c --- /dev/null +++ b/src/zenml/utils/code_utils.py @@ -0,0 +1,244 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Code utilities.""" + +import hashlib +import os +import shutil +import tempfile +from pathlib import Path +from typing import IO, TYPE_CHECKING, Dict, Optional + +from zenml.client import Client +from zenml.io import fileio +from zenml.logger import get_logger +from zenml.utils import string_utils +from zenml.utils.archivable import Archivable + +if TYPE_CHECKING: + from git.repo.base import Repo + + +logger = get_logger(__name__) + + +class CodeArchive(Archivable): + """Code archive class. + + This class is used to archive user code before uploading it to the artifact + store. 
If the user code is stored in a Git repository, only files not + excluded by gitignores will be included in the archive. + """ + + def __init__(self, root: Optional[str] = None) -> None: + """Initialize the object. + + Args: + root: Root directory of the archive. + """ + super().__init__() + self._root = root + + @property + def git_repo(self) -> Optional["Repo"]: + """Git repository active at the code archive root. + + Returns: + The git repository if available. + """ + try: + # These imports fail when git is not installed on the machine + from git.exc import InvalidGitRepositoryError + from git.repo.base import Repo + except ImportError: + return None + + try: + git_repo = Repo(path=self._root, search_parent_directories=True) + except InvalidGitRepositoryError: + return None + + return git_repo + + def _get_all_files(self, archive_root: str) -> Dict[str, str]: + """Get all files inside the archive root. + + Args: + archive_root: The root directory from which to get all files. + + Returns: + All files inside the archive root. + """ + all_files = {} + for root, _, files in os.walk(archive_root): + for file in files: + file_path = os.path.join(root, file) + path_in_archive = os.path.relpath(file_path, archive_root) + all_files[path_in_archive] = file_path + + return all_files + + def get_files(self) -> Dict[str, str]: + """Gets all regular files that should be included in the archive. + + Raises: + RuntimeError: If the code archive would not include any files. + + Returns: + A dict {path_in_archive: path_on_filesystem} for all regular files + in the archive. + """ + if not self._root: + return {} + + all_files = {} + + if repo := self.git_repo: + try: + result = repo.git.ls_files( + "--cached", + "--others", + "--modified", + "--exclude-standard", + self._root, + ) + except Exception as e: + logger.warning( + "Failed to get non-ignored files from git: %s", str(e) + ) + all_files = self._get_all_files(archive_root=self._root) + else: + for file in result.split(): + file_path = os.path.join(repo.working_dir, file) + path_in_archive = os.path.relpath(file_path, self._root) + + if os.path.exists(file_path): + all_files[path_in_archive] = file_path + else: + all_files = self._get_all_files(archive_root=self._root) + + if not all_files: + raise RuntimeError( + "The code archive to be uploaded does not contain any files. " + "This is probably because all files in your source root " + f"`{self._root}` are ignored by a .gitignore file." + ) + + # Explicitly remove .zen directories as we write an updated version + # to disk everytime ZenML is called. This updates the mtime of the + # file, which invalidates the code upload caching. The values in + # the .zen directory are not needed anyway as we set them as + # environment variables. + all_files = { + path_in_archive: file_path + for path_in_archive, file_path in sorted(all_files.items()) + if ".zen" not in Path(path_in_archive).parts[:-1] + } + + return all_files + + def write_archive( + self, output_file: IO[bytes], use_gzip: bool = True + ) -> None: + """Writes an archive of the build context to the given file. + + Args: + output_file: The file to write the archive to. + use_gzip: Whether to use `gzip` to compress the file. + """ + super().write_archive(output_file=output_file, use_gzip=use_gzip) + archive_size = os.path.getsize(output_file.name) + if archive_size > 20 * 1024 * 1024: + logger.warning( + "Code archive size: `%s`. 
If you believe this is " + "unreasonably large, make sure to version your code in git and " + "ignore unnecessary files using a `.gitignore` file.", + string_utils.get_human_readable_filesize(archive_size), + ) + + +def upload_code_if_necessary(code_archive: CodeArchive) -> str: + """Upload code to the artifact store if necessary. + + This function computes a hash of the code to be uploaded, and if an archive + with the same hash already exists it will not re-upload but instead return + the path to the existing archive. + + Args: + code_archive: The code archive to upload. + + Returns: + The path where to archived code is uploaded. + """ + artifact_store = Client().active_stack.artifact_store + + with tempfile.NamedTemporaryFile( + mode="w+b", delete=False, suffix=".tar.gz" + ) as f: + code_archive.write_archive(f) + + hash_ = hashlib.sha1() # nosec + + while True: + data = f.read(64 * 1024) + if not data: + break + hash_.update(data) + + filename = f"{hash_.hexdigest()}.tar.gz" + upload_dir = os.path.join(artifact_store.path, "code_uploads") + fileio.makedirs(upload_dir) + upload_path = os.path.join(upload_dir, filename) + + if not fileio.exists(upload_path): + archive_size = string_utils.get_human_readable_filesize( + os.path.getsize(f.name) + ) + logger.info( + "Uploading code to `%s` (Size: %s).", upload_path, archive_size + ) + fileio.copy(f.name, upload_path) + logger.info("Code upload finished.") + else: + logger.info( + "Code already exists in artifact store, skipping upload." + ) + + if os.path.exists(f.name): + os.remove(f.name) + + return upload_path + + +def download_and_extract_code(code_path: str, extract_dir: str) -> None: + """Download and extract code. + + Args: + code_path: Path where the code is uploaded. + extract_dir: Directory where to code should be extracted to. + + Raises: + RuntimeError: If the code is stored in an artifact store which is + not active. + """ + artifact_store = Client().active_stack.artifact_store + + if not code_path.startswith(artifact_store.path): + raise RuntimeError("Code stored in different artifact store.") + + download_path = os.path.basename(code_path) + fileio.copy(code_path, download_path) + + shutil.unpack_archive(filename=download_path, extract_dir=extract_dir) + os.remove(download_path) diff --git a/src/zenml/utils/notebook_utils.py b/src/zenml/utils/notebook_utils.py new file mode 100644 index 00000000000..cf477718482 --- /dev/null +++ b/src/zenml/utils/notebook_utils.py @@ -0,0 +1,122 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Notebook utilities.""" + +from typing import Any, Callable, Optional, TypeVar, Union + +from zenml.environment import Environment +from zenml.logger import get_logger + +ZENML_NOTEBOOK_CELL_CODE_ATTRIBUTE_NAME = "__zenml_notebook_cell_code__" + +AnyObject = TypeVar("AnyObject", bound=Any) + +logger = get_logger(__name__) + + +def is_defined_in_notebook_cell(obj: Any) -> bool: + """Check whether an object is defined in a notebook cell. 
+ + Args: + obj: The object to check. + + Returns: + Whether the object is defined in a notebook cell. + """ + if not Environment.in_notebook(): + return False + + module_name = getattr(obj, "__module__", None) + return module_name == "__main__" + + +def enable_notebook_code_extraction( + _obj: Optional["AnyObject"] = None, +) -> Union["AnyObject", Callable[["AnyObject"], "AnyObject"]]: + """Decorator to enable code extraction from notebooks. + + Args: + _obj: The class or function for which to enable code extraction. + + Returns: + The decorated class or function. + """ + + def inner_decorator(obj: "AnyObject") -> "AnyObject": + try_to_save_notebook_cell_code(obj) + return obj + + if _obj is None: + return inner_decorator + else: + return inner_decorator(_obj) + + +def get_active_notebook_cell_code() -> Optional[str]: + """Get the code of the currently active notebook cell. + + Returns: + The code of the currently active notebook cell. + """ + cell_code = None + try: + ipython = get_ipython() # type: ignore[name-defined] + cell_code = ipython.get_parent()["content"]["code"] + except (NameError, KeyError) as e: + logger.warning("Unable to extract cell code: %s.", str(e)) + + return cell_code + + +def try_to_save_notebook_cell_code(obj: Any) -> None: + """Try to save the notebook cell code for an object. + + Args: + obj: The object for which to save the notebook cell code. + """ + if is_defined_in_notebook_cell(obj): + if cell_code := get_active_notebook_cell_code(): + setattr( + obj, + ZENML_NOTEBOOK_CELL_CODE_ATTRIBUTE_NAME, + cell_code, + ) + + +def load_notebook_cell_code(obj: Any) -> Optional[str]: + """Load the notebook cell code for an object. + + Args: + obj: The object for which to load the cell code. + + Returns: + The notebook cell code if it was saved. + """ + return getattr(obj, ZENML_NOTEBOOK_CELL_CODE_ATTRIBUTE_NAME, None) + + +def warn_about_notebook_cell_magic_commands(cell_code: str) -> None: + """Warn about magic commands in the cell code. + + Args: + cell_code: The cell code. + """ + if any(line.startswith(("%", "!")) for line in cell_code.splitlines()): + logger.warning( + "Some lines in your notebook cell start with a `!` or `%` " + "character. Running a ZenML step remotely from a notebook " + "only works if the cell only contains python code. If any " + "of these lines contain Jupyter notebook magic commands, " + "remove them and try again." + ) diff --git a/src/zenml/utils/source_utils.py b/src/zenml/utils/source_utils.py index d2aa36dd822..9fe76c551cd 100644 --- a/src/zenml/utils/source_utils.py +++ b/src/zenml/utils/source_utils.py @@ -35,14 +35,19 @@ from zenml.config.source import ( CodeRepositorySource, DistributionPackageSource, + NotebookSource, Source, SourceType, ) from zenml.constants import ENV_ZENML_CUSTOM_SOURCE_ROOT from zenml.environment import Environment from zenml.logger import get_logger +from zenml.utils import notebook_utils logger = get_logger(__name__) + +ZENML_SOURCE_ATTRIBUTE_NAME = "__zenml_source__" + NoneType = type(None) NoneTypeSource = Source( module=NoneType.__module__, attribute="NoneType", type=SourceType.BUILTIN @@ -58,10 +63,13 @@ type=SourceType.BUILTIN, ) + _CUSTOM_SOURCE_ROOT: Optional[str] = os.getenv( ENV_ZENML_CUSTOM_SOURCE_ROOT, None ) +_SHARED_TEMPDIR: Optional[str] = None + def load(source: Union[Source, str]) -> Any: """Load a source or import path. 
@@ -105,6 +113,14 @@ def load(source: Union[Source, str]) -> Any: source.version, source.import_path, ) + elif source.type == SourceType.NOTEBOOK: + if Environment.in_notebook(): + # If we're in a notebook, we don't need to do anything as the + # loading from the __main__ module should work just fine. + pass + else: + notebook_source = NotebookSource.model_validate(dict(source)) + return _try_to_load_notebook_source(notebook_source) elif source.type in {SourceType.USER, SourceType.UNKNOWN}: # Unknown source might also refer to a user file, include source # root in python path just to be sure @@ -152,6 +168,9 @@ def resolve( return FunctionTypeSource elif obj is BuiltinFunctionType: return BuiltinFunctionTypeSource + elif source := getattr(obj, ZENML_SOURCE_ATTRIBUTE_NAME, None): + assert isinstance(source, Source) + return source elif isinstance(obj, ModuleType): module = obj attribute_name = None @@ -216,6 +235,16 @@ def resolve( else: # Fallback to an unknown source if we can't find the package source_type = SourceType.UNKNOWN + elif source_type == SourceType.NOTEBOOK: + source = NotebookSource( + module=module_name, + attribute=attribute_name, + type=source_type, + ) + # Private attributes are ignored by pydantic if passed in the __init__ + # method, so we set this afterwards + source._cell_code = notebook_utils.load_notebook_cell_code(obj) + return source return Source( module=module_name, attribute=attribute_name, type=source_type @@ -362,7 +391,7 @@ def get_source_type(module: ModuleType) -> SourceType: file_path = inspect.getfile(module) except (TypeError, OSError): if module.__name__ == "__main__" and Environment.in_notebook(): - return SourceType.USER + return SourceType.NOTEBOOK return SourceType.BUILTIN @@ -529,6 +558,85 @@ def _load_module( return importlib.import_module(module_name) +def _get_shared_temp_dir() -> str: + """Get path to a shared temporary directory. + + Returns: + Path to a shared temporary directory. + """ + global _SHARED_TEMPDIR + + if not _SHARED_TEMPDIR: + import tempfile + + _SHARED_TEMPDIR = tempfile.mkdtemp() + + return _SHARED_TEMPDIR + + +def _try_to_load_notebook_source(source: NotebookSource) -> Any: + """Helper function to load a notebook source outside of a notebook. + + Args: + source: The source to load. + + Raises: + RuntimeError: If the source can't be loaded. + + Returns: + The loaded object. + """ + if not source.code_path or not source.replacement_module: + raise RuntimeError( + f"Failed to load {source.import_path}. This object was defined in " + "a notebook and you're trying to load it outside of a notebook. " + "This is currently only enabled for ZenML steps." + ) + + extract_dir = _get_shared_temp_dir() + file_path = os.path.join(extract_dir, f"{source.replacement_module}.py") + + if not os.path.exists(file_path): + from zenml.utils import code_utils + + logger.info( + "Downloading notebook cell content from `%s` to load `%s`.", + source.code_path, + source.import_path, + ) + + code_utils.download_and_extract_code( + code_path=source.code_path, extract_dir=extract_dir + ) + + try: + module = _load_module( + module_name=source.replacement_module, import_root=extract_dir + ) + except ImportError: + raise RuntimeError( + f"Unable to load {source.import_path}. This object was defined in " + "a notebook and you're trying to load it outside of a notebook. " + "To enable this, ZenML extracts the code of your cell into a " + "python file. 
This means your cell code needs to be " + "self-contained:\n" + " * All required imports must be done in this cell, even if the " + "same imports already happen in previous notebook cells.\n" + " * The cell can't use any code defined in other notebook cells." + ) + + if source.attribute: + obj = getattr(module, source.attribute) + else: + obj = module + + # Store the original notebook source so resolving this object works as + # expected + setattr(obj, ZENML_SOURCE_ATTRIBUTE_NAME, source) + + return obj + + def _get_package_for_module(module_name: str) -> Optional[str]: """Get the package name for a module.