Skip to content

Commit e174db4

Browse files
authored
Add Caching in Kubernetes orchestrator entrypoint (#3703)
* Cache step in kubernetes orchestrator pod
* Add caching in kubernetes orchestrator entrypoint
* Linting
* Add option to disable orchestrator pod caching
1 parent 26430a9 commit e174db4

File tree

4 files changed

+96
-46
lines changed

4 files changed

+96
-46
lines changed

src/zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ class KubernetesOrchestratorSettings(BaseSettings):
6767
ttl_seconds_after_finished: The amount of seconds to keep finished jobs
6868
before deleting them. This only applies to jobs created when
6969
scheduling a pipeline.
70+
prevent_orchestrator_pod_caching: If `True`, the orchestrator pod will
71+
not try to compute cached steps before starting the step pods.
7072
"""
7173

7274
synchronous: bool = True
@@ -85,6 +87,7 @@ class KubernetesOrchestratorSettings(BaseSettings):
8587
successful_jobs_history_limit: Optional[NonNegativeInt] = None
8688
failed_jobs_history_limit: Optional[NonNegativeInt] = None
8789
ttl_seconds_after_finished: Optional[NonNegativeInt] = None
90+
prevent_orchestrator_pod_caching: bool = False
8891

8992

9093
class KubernetesOrchestratorConfig(

src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -74,23 +74,37 @@ def main() -> None:
7474
orchestrator_pod_name = socket.gethostname()
7575

7676
client = Client()
77+
active_stack = client.active_stack
78+
orchestrator = active_stack.orchestrator
79+
assert isinstance(orchestrator, KubernetesOrchestrator)
7780

78-
deployment_config = client.get_deployment(args.deployment_id)
81+
deployment = client.get_deployment(args.deployment_id)
82+
pipeline_settings = cast(
83+
KubernetesOrchestratorSettings,
84+
orchestrator.get_settings(deployment),
85+
)
7986

80-
pipeline_dag = {
81-
step_name: step.spec.upstream_steps
82-
for step_name, step in deployment_config.step_configurations.items()
83-
}
8487
step_command = StepEntrypointConfiguration.get_entrypoint_command()
8588

86-
active_stack = client.active_stack
89+
if args.run_id and not pipeline_settings.prevent_orchestrator_pod_caching:
90+
from zenml.orchestrators import cache_utils
91+
92+
run_required = (
93+
cache_utils.create_cached_step_runs_and_prune_deployment(
94+
deployment=deployment,
95+
pipeline_run=client.get_pipeline_run(args.run_id),
96+
stack=active_stack,
97+
)
98+
)
99+
100+
if not run_required:
101+
return
102+
87103
mount_local_stores = active_stack.orchestrator.config.is_local
88104

89105
# Get a Kubernetes client from the active Kubernetes orchestrator, but
90106
# override the `incluster` setting to `True` since we are running inside
91107
# the Kubernetes cluster.
92-
orchestrator = active_stack.orchestrator
93-
assert isinstance(orchestrator, KubernetesOrchestrator)
94108
kube_client = orchestrator.get_kube_client(incluster=True)
95109
core_api = k8s_client.CoreV1Api(kube_client)
96110

@@ -121,7 +135,7 @@ def run_step_on_kubernetes(step_name: str) -> None:
121135
Raises:
122136
Exception: If the pod fails to start.
123137
"""
124-
step_config = deployment_config.step_configurations[step_name].config
138+
step_config = deployment.step_configurations[step_name].config
125139
settings = step_config.settings.get("orchestrator.kubernetes", None)
126140
settings = KubernetesOrchestratorSettings.model_validate(
127141
settings.model_dump() if settings else {}
@@ -147,10 +161,10 @@ def run_step_on_kubernetes(step_name: str) -> None:
147161
)
148162

149163
image = KubernetesOrchestrator.get_image(
150-
deployment=deployment_config, step_name=step_name
164+
deployment=deployment, step_name=step_name
151165
)
152166
step_args = StepEntrypointConfiguration.get_entrypoint_arguments(
153-
step_name=step_name, deployment_id=deployment_config.id
167+
step_name=step_name, deployment_id=deployment.id
154168
)
155169

156170
# We set some default minimum memory resource requests for the step pod
@@ -165,9 +179,7 @@ def run_step_on_kubernetes(step_name: str) -> None:
165179

166180
if orchestrator.config.pass_zenml_token_as_secret:
167181
env.pop("ZENML_STORE_API_TOKEN", None)
168-
secret_name = orchestrator.get_token_secret_name(
169-
deployment_config.id
170-
)
182+
secret_name = orchestrator.get_token_secret_name(deployment.id)
171183
pod_settings.env.append(
172184
{
173185
"name": "ZENML_STORE_API_TOKEN",
@@ -184,7 +196,7 @@ def run_step_on_kubernetes(step_name: str) -> None:
184196
pod_manifest = build_pod_manifest(
185197
pod_name=pod_name,
186198
run_name=args.run_name,
187-
pipeline_name=deployment_config.pipeline_configuration.name,
199+
pipeline_name=deployment.pipeline_configuration.name,
188200
image_name=image,
189201
command=step_command,
190202
args=step_args,
@@ -251,8 +263,8 @@ def finalize_run(node_states: Dict[str, NodeStatus]) -> None:
251263

252264
pipeline_runs = client.list_pipeline_runs(
253265
hydrate=True,
254-
project=deployment_config.project_id,
255-
deployment_id=deployment_config.id,
266+
project=deployment.project_id,
267+
deployment_id=deployment.id,
256268
**list_args,
257269
)
258270
if not len(pipeline_runs):
@@ -298,27 +310,26 @@ def finalize_run(node_states: Dict[str, NodeStatus]) -> None:
298310
parallel_node_startup_waiting_period = (
299311
orchestrator.config.parallel_step_startup_waiting_period or 0.0
300312
)
301-
settings = cast(
302-
KubernetesOrchestratorSettings,
303-
orchestrator.get_settings(deployment_config),
304-
)
313+
314+
pipeline_dag = {
315+
step_name: step.spec.upstream_steps
316+
for step_name, step in deployment.step_configurations.items()
317+
}
305318
try:
306319
ThreadedDagRunner(
307320
dag=pipeline_dag,
308321
run_fn=run_step_on_kubernetes,
309322
finalize_fn=finalize_run,
310323
parallel_node_startup_waiting_period=parallel_node_startup_waiting_period,
311-
max_parallelism=settings.max_parallelism,
324+
max_parallelism=pipeline_settings.max_parallelism,
312325
).run()
313326
logger.info("Orchestration pod completed.")
314327
finally:
315328
if (
316329
orchestrator.config.pass_zenml_token_as_secret
317-
and deployment_config.schedule is None
330+
and deployment.schedule is None
318331
):
319-
secret_name = orchestrator.get_token_secret_name(
320-
deployment_config.id
321-
)
332+
secret_name = orchestrator.get_token_secret_name(deployment.id)
322333
try:
323334
kube_utils.delete_secret(
324335
core_api=core_api,

src/zenml/orchestrators/base_orchestrator.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -220,29 +220,18 @@ def run(
220220
and not deployment.schedule
221221
and not prevent_client_side_caching
222222
):
223-
from zenml.orchestrators import step_run_utils
223+
from zenml.orchestrators import cache_utils
224224

225-
cached_invocations = step_run_utils.create_cached_step_runs(
226-
deployment=deployment,
227-
pipeline_run=placeholder_run,
228-
stack=stack,
225+
run_required = (
226+
cache_utils.create_cached_step_runs_and_prune_deployment(
227+
deployment=deployment,
228+
pipeline_run=placeholder_run,
229+
stack=stack,
230+
)
229231
)
230232

231-
for invocation_id in cached_invocations:
232-
# Remove the cached step invocations from the deployment so
233-
# the orchestrator does not try to run them
234-
deployment.step_configurations.pop(invocation_id)
235-
236-
for step in deployment.step_configurations.values():
237-
for invocation_id in cached_invocations:
238-
if invocation_id in step.spec.upstream_steps:
239-
step.spec.upstream_steps.remove(invocation_id)
240-
241-
if len(deployment.step_configurations) == 0:
242-
# All steps were cached, we update the pipeline run status and
243-
# don't actually use the orchestrator to run the pipeline
233+
if not run_required:
244234
self._cleanup_run()
245-
logger.info("All steps of the pipeline run were cached.")
246235
return
247236
else:
248237
logger.debug("Skipping client-side caching.")

src/zenml/orchestrators/cache_utils.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,20 @@
1919
from zenml.client import Client
2020
from zenml.enums import ExecutionStatus, SorterOps
2121
from zenml.logger import get_logger
22+
from zenml.orchestrators import step_run_utils
2223

2324
if TYPE_CHECKING:
2425
from uuid import UUID
2526

2627
from zenml.artifact_stores import BaseArtifactStore
2728
from zenml.config.step_configurations import Step
28-
from zenml.models import StepRunResponse
29+
from zenml.models import (
30+
PipelineDeploymentResponse,
31+
PipelineRunResponse,
32+
StepRunResponse,
33+
)
34+
from zenml.stack import Stack
35+
2936

3037
logger = get_logger(__name__)
3138

@@ -127,3 +134,43 @@ def get_cached_step_run(cache_key: str) -> Optional["StepRunResponse"]:
127134
if cache_candidates:
128135
return cache_candidates[0]
129136
return None
137+
138+
139+
def create_cached_step_runs_and_prune_deployment(
    deployment: "PipelineDeploymentResponse",
    pipeline_run: "PipelineRunResponse",
    stack: "Stack",
) -> bool:
    """Create step runs for cached steps and strip them from the deployment.

    The deployment is modified in place: every cached invocation is removed
    from its step configurations and from the upstream steps of the steps
    that remain.

    Args:
        deployment: The deployment of the pipeline run.
        pipeline_run: The pipeline run for which to create the step runs.
        stack: The stack on which the pipeline run is happening.

    Returns:
        `True` if step configurations remain after pruning (an actual
        pipeline run is still required), `False` otherwise.
    """
    cached_invocations = step_run_utils.create_cached_step_runs(
        deployment=deployment,
        pipeline_run=pipeline_run,
        stack=stack,
    )

    # Drop the cached invocations so the orchestrator does not try to run
    # them.
    for cached_id in cached_invocations:
        deployment.step_configurations.pop(cached_id)

    # The steps that are left must no longer list a cached invocation as an
    # upstream step.
    for remaining_step in deployment.step_configurations.values():
        upstream_steps = remaining_step.spec.upstream_steps
        for cached_id in cached_invocations:
            if cached_id in upstream_steps:
                upstream_steps.remove(cached_id)

    if deployment.step_configurations:
        return True

    # Nothing is left to execute, so the caller can skip running the
    # pipeline on the orchestrator.
    logger.info("All steps of the pipeline run were cached.")
    return False

0 commit comments

Comments
 (0)