Skip to content

Commit c6e5459

Browse files
committed
Upload notebook code to artifact store instead
1 parent 0524dae commit c6e5459

File tree

11 files changed

+227
-348
lines changed

11 files changed

+227
-348
lines changed

src/zenml/config/source.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -234,17 +234,20 @@ class NotebookSource(Source):
234234
"""Source representing an object defined in a notebook.
235235
236236
Attributes:
237-
notebook_path: Path of the notebook (relative to the source root) in
238-
which the object is defined.
239-
cell_id: ID of the notebook cell in which the object is defined. This
240-
will only be set for objects which explicitly store this by calling
241-
`zenml.utils.notebook_utils.save_notebook_cell_id()`.
237+
code_path: Path where the notebook cell code for this source is
238+
uploaded.
239+
replacement_module: Name of the module from which this source should
240+
be loaded in case the code is not running in a notebook.
242241
"""
243242

244-
notebook_path: Optional[str] = None
245-
cell_id: Optional[str] = None
243+
code_path: Optional[str] = None
244+
replacement_module: Optional[str] = None
246245
type: SourceType = SourceType.NOTEBOOK
247246

247+
# Private attribute that is used to store the code but should not be
248+
# serialized
249+
_cell_code: Optional[str] = None
250+
248251
@field_validator("type")
249252
@classmethod
250253
def _validate_type(cls, value: SourceType) -> SourceType:

src/zenml/constants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
173173
)
174174
ENV_ZENML_IGNORE_FAILURE_HOOK = "ZENML_IGNORE_FAILURE_HOOK"
175175
ENV_ZENML_CUSTOM_SOURCE_ROOT = "ZENML_CUSTOM_SOURCE_ROOT"
176-
ENV_ZENML_NOTEBOOK_PATH = "ZENML_NOTEBOOK_PATH"
177176

178177
# ZenML Server environment variables
179178
ENV_ZENML_SERVER_PREFIX = "ZENML_SERVER_"

src/zenml/entrypoints/base_entrypoint_configuration.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import argparse
1717
import os
18-
import shutil
1918
import sys
2019
from abc import ABC, abstractmethod
2120
from typing import TYPE_CHECKING, Any, Dict, List, NoReturn, Set
@@ -27,9 +26,13 @@
2726
ENV_ZENML_REQUIRES_CODE_DOWNLOAD,
2827
handle_bool_env_var,
2928
)
30-
from zenml.io import fileio
3129
from zenml.logger import get_logger
32-
from zenml.utils import code_repository_utils, source_utils, uuid_utils
30+
from zenml.utils import (
31+
code_repository_utils,
32+
code_utils,
33+
source_utils,
34+
uuid_utils,
35+
)
3336

3437
if TYPE_CHECKING:
3538
from zenml.models import CodeReferenceResponse, PipelineDeploymentResponse
@@ -272,19 +275,14 @@ def download_code_from_artifact_store(self, code_path: str) -> None:
272275

273276
# Do not remove this line, we need to instantiate the artifact store to
274277
# register the filesystem needed for the file download
275-
artifact_store = Client().active_stack.artifact_store
276-
277-
if not code_path.startswith(artifact_store.path):
278-
raise RuntimeError("Code stored in different artifact store.")
278+
_ = Client().active_stack.artifact_store
279279

280280
extract_dir = os.path.abspath("code")
281281
os.makedirs(extract_dir)
282282

283-
download_path = os.path.basename(code_path)
284-
fileio.copy(code_path, download_path)
285-
286-
shutil.unpack_archive(filename=download_path, extract_dir=extract_dir)
287-
os.remove(download_path)
283+
code_utils.download_and_extract_code(
284+
code_path=code_path, extract_dir=extract_dir
285+
)
288286

289287
source_utils.set_custom_source_root(extract_dir)
290288
sys.path.insert(0, extract_dir)

src/zenml/materializers/base_materializer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def __new__(
100100

101101
from zenml.utils import notebook_utils
102102

103-
notebook_utils.try_to_save_notebook_cell_id(cls)
103+
notebook_utils.try_to_save_notebook_cell_code(cls)
104104

105105
return cls
106106

src/zenml/new/pipelines/build_utils.py

Lines changed: 1 addition & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414
"""Pipeline build utilities."""
1515

1616
import hashlib
17-
import os
1817
import platform
19-
import tempfile
2018
from typing import (
2119
TYPE_CHECKING,
2220
Dict,
@@ -29,7 +27,6 @@
2927
import zenml
3028
from zenml.client import Client
3129
from zenml.code_repositories import BaseCodeRepository
32-
from zenml.io import fileio
3330
from zenml.logger import get_logger
3431
from zenml.models import (
3532
BuildItem,
@@ -40,9 +37,8 @@
4037
PipelineDeploymentBase,
4138
StackResponse,
4239
)
43-
from zenml.new.pipelines.code_archive import CodeArchive
4440
from zenml.stack import Stack
45-
from zenml.utils import source_utils, string_utils
41+
from zenml.utils import source_utils
4642
from zenml.utils.pipeline_docker_image_builder import (
4743
PipelineDockerImageBuilder,
4844
)
@@ -724,56 +720,3 @@ def should_upload_code(
724720
return True
725721

726722
return False
727-
728-
729-
def upload_code_if_necessary() -> str:
730-
"""Upload code to the artifact store if necessary.
731-
732-
This function computes a hash of the code to be uploaded, and if an archive
733-
with the same hash already exists it will not re-upload but instead return
734-
the path to the existing archive.
735-
736-
Returns:
737-
The path where to archived code is uploaded.
738-
"""
739-
logger.info("Archiving code...")
740-
741-
code_archive = CodeArchive(root=source_utils.get_source_root())
742-
artifact_store = Client().active_stack.artifact_store
743-
744-
with tempfile.NamedTemporaryFile(
745-
mode="w+b", delete=False, suffix=".tar.gz"
746-
) as f:
747-
code_archive.write_archive(f)
748-
749-
hash_ = hashlib.sha1() # nosec
750-
751-
while True:
752-
data = f.read(64 * 1024)
753-
if not data:
754-
break
755-
hash_.update(data)
756-
757-
filename = f"{hash_.hexdigest()}.tar.gz"
758-
upload_dir = os.path.join(artifact_store.path, "code_uploads")
759-
fileio.makedirs(upload_dir)
760-
upload_path = os.path.join(upload_dir, filename)
761-
762-
if not fileio.exists(upload_path):
763-
archive_size = string_utils.get_human_readable_filesize(
764-
os.path.getsize(f.name)
765-
)
766-
logger.info(
767-
"Uploading code to `%s` (Size: %s).", upload_path, archive_size
768-
)
769-
fileio.copy(f.name, upload_path)
770-
logger.info("Code upload finished.")
771-
else:
772-
logger.info(
773-
"Code already exists in artifact store, skipping upload."
774-
)
775-
776-
if os.path.exists(f.name):
777-
os.remove(f.name)
778-
779-
return upload_path

src/zenml/new/pipelines/pipeline.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@
7272
from zenml.new.pipelines.run_utils import (
7373
create_placeholder_run,
7474
deploy_pipeline,
75-
fail_if_running_remotely_with_notebook_not_possible,
7675
prepare_model_versions,
76+
upload_notebook_cell_code_if_necessary,
7777
)
7878
from zenml.stack import Stack
7979
from zenml.steps import BaseStep
@@ -83,6 +83,7 @@
8383
from zenml.steps.step_invocation import StepInvocation
8484
from zenml.utils import (
8585
code_repository_utils,
86+
code_utils,
8687
dashboard_utils,
8788
dict_utils,
8889
pydantic_utils,
@@ -669,7 +670,7 @@ def _run(
669670

670671
stack = Client().active_stack
671672
stack.validate()
672-
fail_if_running_remotely_with_notebook_not_possible(
673+
upload_notebook_cell_code_if_necessary(
673674
deployment=deployment, stack=stack
674675
)
675676

@@ -719,7 +720,11 @@ def _run(
719720
build=build_model,
720721
code_reference=code_reference,
721722
):
722-
code_path = build_utils.upload_code_if_necessary()
723+
code_archive = code_utils.CodeArchive(
724+
root=source_utils.get_source_root()
725+
)
726+
logger.info("Archiving pipeline code...")
727+
code_path = code_utils.upload_code_if_necessary(code_archive)
723728

724729
deployment_request = PipelineDeploymentRequest(
725730
user=Client().active_user.id,

src/zenml/new/pipelines/run_utils.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Utility functions for running pipelines."""
22

3+
import hashlib
34
import time
45
from collections import defaultdict
56
from datetime import datetime
@@ -24,7 +25,7 @@
2425
from zenml.new.pipelines.model_utils import NewModelRequest
2526
from zenml.orchestrators.utils import get_run_name
2627
from zenml.stack import Flavor, Stack
27-
from zenml.utils import cloud_utils, notebook_utils
28+
from zenml.utils import cloud_utils, code_utils, notebook_utils
2829
from zenml.zen_stores.base_zen_store import BaseZenStore
2930

3031
if TYPE_CHECKING:
@@ -364,42 +365,63 @@ def validate_run_config_is_runnable_from_server(
364365
)
365366

366367

367-
def fail_if_running_remotely_with_notebook_not_possible(
368+
def upload_notebook_cell_code_if_necessary(
368369
deployment: "PipelineDeploymentBase", stack: "Stack"
369370
) -> None:
370-
"""Fail if running the deployment on the stack is not possible.
371+
"""Upload notebook cell code if necessary.
371372
372373
This function checks if any of the steps of the pipeline that will be
373374
executed in a different process are defined in a notebook. If that is the
374-
case, it will raise an exception if the active notebook path can't be
375-
determined.
375+
case, it will extract that notebook cell code into python files and upload
376+
an archive of all the necessary files to the artifact store.
376377
377378
Args:
378379
deployment: The deployment.
379380
stack: The stack on which the deployment will happen.
380381
381382
Raises:
382-
RuntimeError: If the active notebook can't be determined and steps that
383-
are defined in that notebook should be executed out of process.
383+
RuntimeError: If the code for one of the steps that will run out of
384+
process cannot be extracted into a python file.
384385
"""
386+
code_archive = code_utils.CodeArchive(root=None)
387+
should_upload = False
388+
sources_that_require_upload = []
389+
385390
for step in deployment.step_configurations.values():
386391
if step.spec.source.type == SourceType.NOTEBOOK:
387392
if (
388393
stack.orchestrator.flavor != "local"
389394
or step.config.step_operator
390395
):
396+
should_upload = True
397+
cell_code = getattr(step.spec.source, "_cell_code", None)
398+
391399
# Code does not run in-process, which means we need to
392-
# extract it from the notebook in the execution
393-
# environment -> verify that we're able to detect the
394-
# active notebook
395-
if not notebook_utils.get_active_notebook_path():
400+
# extract the step code into a python file
401+
if not cell_code:
396402
raise RuntimeError(
397403
f"Unable to run step {step.config.name}. This step is "
398404
"defined in a notebook and you're trying to run it "
399405
"in a remote environment, but ZenML was not able to "
400-
"detect the notebook that you're running in. To fix "
401-
"this error, set the "
402-
f"{constants.ENV_ZENML_NOTEBOOK_PATH} environment "
403-
"variable to the path of the active notebook or define "
404-
"your step in a python file instead of a notebook."
406+
"detect the step code in the notebook. To fix "
407+
"this error, define your step in a python file instead "
408+
"of a notebook."
405409
)
410+
411+
notebook_utils.warn_about_notebook_cell_magic_commands(
412+
cell_code=cell_code
413+
)
414+
415+
code_hash = hashlib.sha1(cell_code.encode()).hexdigest() # nosec
416+
module_name = f"extracted_notebook_code_{code_hash}"
417+
file_name = f"{module_name}.py"
418+
code_archive.add_file(source=cell_code, destination=file_name)
419+
420+
setattr(step.spec.source, "replacement_module", module_name)
421+
sources_that_require_upload.append(step.spec.source)
422+
423+
if should_upload:
424+
logger.info("Archiving notebook code...")
425+
code_path = code_utils.upload_code_if_necessary(code_archive)
426+
for source in sources_that_require_upload:
427+
setattr(source, "code_path", code_path)

src/zenml/steps/base_step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def __init__(
250250
)
251251
self._verify_and_apply_init_params(*args, **kwargs)
252252

253-
notebook_utils.try_to_save_notebook_cell_id(self.source_object)
253+
notebook_utils.try_to_save_notebook_cell_code(self.source_object)
254254

255255
@abstractmethod
256256
def entrypoint(self, *args: Any, **kwargs: Any) -> Any:

0 commit comments

Comments
 (0)