Skip to content

Introduce JOB_DISCONNECTED_RETRY_TIMEOUT #2627

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class JobTerminationReason(str, Enum):
# Set by the server
FAILED_TO_START_DUE_TO_NO_CAPACITY = "failed_to_start_due_to_no_capacity"
INTERRUPTED_BY_NO_CAPACITY = "interrupted_by_no_capacity"
INSTANCE_UNREACHABLE = "instance_unreachable"
WAITING_INSTANCE_LIMIT_EXCEEDED = "waiting_instance_limit_exceeded"
WAITING_RUNNER_LIMIT_EXCEEDED = "waiting_runner_limit_exceeded"
TERMINATED_BY_USER = "terminated_by_user"
Expand All @@ -126,6 +127,7 @@ def to_status(self) -> JobStatus:
mapping = {
self.FAILED_TO_START_DUE_TO_NO_CAPACITY: JobStatus.FAILED,
self.INTERRUPTED_BY_NO_CAPACITY: JobStatus.FAILED,
self.INSTANCE_UNREACHABLE: JobStatus.FAILED,
self.WAITING_INSTANCE_LIMIT_EXCEEDED: JobStatus.FAILED,
self.WAITING_RUNNER_LIMIT_EXCEEDED: JobStatus.FAILED,
self.TERMINATED_BY_USER: JobStatus.TERMINATED,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from collections.abc import Iterable
from datetime import timedelta
from datetime import timedelta, timezone
from typing import Dict, List, Optional

from sqlalchemy import select
Expand Down Expand Up @@ -71,6 +71,12 @@
logger = get_logger(__name__)


# Minimum time to wait before terminating an active job in case of connectivity issues.
# Should be sufficient to survive most problems caused by
# server network flickering and providers' glitches.
JOB_DISCONNECTED_RETRY_TIMEOUT = timedelta(minutes=2)


async def process_running_jobs(batch_size: int = 1):
tasks = []
for _ in range(batch_size):
Expand Down Expand Up @@ -202,7 +208,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
user_ssh_key = run.run_spec.ssh_key_pub.strip()
public_keys = [project.ssh_public_key.strip(), user_ssh_key]
if job_provisioning_data.backend == BackendType.LOCAL:
# No need to update ~/.ssh/authorized_keys when running shim localy
# No need to update ~/.ssh/authorized_keys when running shim locally
user_ssh_key = ""
success = await common_utils.run_async(
_process_provisioning_with_shim,
Expand Down Expand Up @@ -299,19 +305,38 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
run_model,
job_model,
)
if not success:
job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY

if not success: # kill the job
logger.warning(
"%s: failed because runner is not available or return an error, age=%s",
fmt(job_model),
job_submission.age,
)
job_model.status = JobStatus.TERMINATING
if not job_model.termination_reason:
job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY
# job will be terminated and instance will be emptied by process_terminating_jobs
if success:
job_model.disconnected_at = None
else:
if job_model.termination_reason:
logger.warning(
"%s: failed because shim/runner returned an error, age=%s",
fmt(job_model),
job_submission.age,
)
job_model.status = JobStatus.TERMINATING
# job will be terminated and instance will be emptied by process_terminating_jobs
else:
# No job_model.termination_reason set means ssh connection failed
if job_model.disconnected_at is None:
job_model.disconnected_at = common_utils.get_current_datetime()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that runner_ssh_tunnel also does 3 blocking retries, and each can take 15 seconds. So this disconnected_at can be imprecise, since we only set it after the runner_ssh_tunnel retries.

But I'm not sure we need to do anything about it. I thought about disabling the runner_ssh_tunnel retries in favor of the new disconnected_at retries, but having retry logic in both places may further improve stability.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing that out. Let's keep both! :)

if _should_terminate_job_due_to_disconnect(job_model):
logger.warning(
"%s: failed because instance is unreachable, age=%s",
fmt(job_model),
job_submission.age,
)
# TODO: Replace with JobTerminationReason.INSTANCE_UNREACHABLE in 0.20 or
# when CLI <= 0.19.8 is no longer supported
job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY
job_model.status = JobStatus.TERMINATING
else:
logger.warning(
"%s: is unreachable, waiting for the instance to become reachable again, age=%s",
fmt(job_model),
job_submission.age,
)

if (
initial_status != job_model.status
Expand Down Expand Up @@ -692,6 +717,15 @@ def _terminate_if_inactivity_duration_exceeded(
)


def _should_terminate_job_due_to_disconnect(job_model: JobModel) -> bool:
    """Tell whether a disconnected job has exhausted its reconnection grace period.

    Returns False when `job_model.disconnected_at` is None. Otherwise returns
    True only if the current time is strictly later than `disconnected_at`
    plus JOB_DISCONNECTED_RETRY_TIMEOUT.
    """
    disconnected_at = job_model.disconnected_at
    if disconnected_at is None:
        return False
    # The stored timestamp is stamped with UTC before the comparison with
    # the current datetime.
    deadline = disconnected_at.replace(tzinfo=timezone.utc) + JOB_DISCONNECTED_RETRY_TIMEOUT
    return common_utils.get_current_datetime() > deadline


async def _check_gpu_utilization(session: AsyncSession, job_model: JobModel, job: Job) -> None:
policy = job.job_spec.utilization_policy
if policy is None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Add JobModel.disconnected_at

Revision ID: 20166748b60c
Revises: 6c1a9d6530ee
Create Date: 2025-05-13 16:24:32.496578

"""

import sqlalchemy as sa
from alembic import op
from alembic_postgresql_enum import TableReference

import dstack._internal.server.models

# revision identifiers, used by Alembic.
revision = "20166748b60c"
down_revision = "6c1a9d6530ee"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Apply the migration.

    Adds the nullable `disconnected_at` column to the `jobs` table, then syncs
    the `jobterminationreason` enum in the `public` schema to a value list
    that includes `INSTANCE_UNREACHABLE`.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    disconnected_at_column = sa.Column(
        "disconnected_at",
        dstack._internal.server.models.NaiveDateTime(),
        nullable=True,
    )
    with op.batch_alter_table("jobs", schema=None) as batch_op:
        batch_op.add_column(disconnected_at_column)

    # Full list of enum values after the upgrade.
    termination_reasons = [
        "FAILED_TO_START_DUE_TO_NO_CAPACITY",
        "INTERRUPTED_BY_NO_CAPACITY",
        "INSTANCE_UNREACHABLE",
        "WAITING_INSTANCE_LIMIT_EXCEEDED",
        "WAITING_RUNNER_LIMIT_EXCEEDED",
        "TERMINATED_BY_USER",
        "VOLUME_ERROR",
        "GATEWAY_ERROR",
        "SCALED_DOWN",
        "DONE_BY_RUNNER",
        "ABORTED_BY_USER",
        "TERMINATED_BY_SERVER",
        "INACTIVITY_DURATION_EXCEEDED",
        "TERMINATED_DUE_TO_UTILIZATION_POLICY",
        "CONTAINER_EXITED_WITH_ERROR",
        "PORTS_BINDING_FAILED",
        "CREATING_CONTAINER_ERROR",
        "EXECUTOR_ERROR",
        "MAX_DURATION_EXCEEDED",
    ]
    # The only column declared as using this enum.
    termination_reason_column = TableReference(
        table_schema="public", table_name="jobs", column_name="termination_reason"
    )
    op.sync_enum_values(
        enum_schema="public",
        enum_name="jobterminationreason",
        new_values=termination_reasons,
        affected_columns=[termination_reason_column],
        enum_values_to_rename=[],
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    """Revert the migration.

    Syncs the `jobterminationreason` enum in the `public` schema back to the
    value list without `INSTANCE_UNREACHABLE`, then drops the
    `disconnected_at` column from the `jobs` table.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Full list of enum values after the downgrade.
    termination_reasons = [
        "FAILED_TO_START_DUE_TO_NO_CAPACITY",
        "INTERRUPTED_BY_NO_CAPACITY",
        "WAITING_INSTANCE_LIMIT_EXCEEDED",
        "WAITING_RUNNER_LIMIT_EXCEEDED",
        "TERMINATED_BY_USER",
        "VOLUME_ERROR",
        "GATEWAY_ERROR",
        "SCALED_DOWN",
        "DONE_BY_RUNNER",
        "ABORTED_BY_USER",
        "TERMINATED_BY_SERVER",
        "INACTIVITY_DURATION_EXCEEDED",
        "TERMINATED_DUE_TO_UTILIZATION_POLICY",
        "CONTAINER_EXITED_WITH_ERROR",
        "PORTS_BINDING_FAILED",
        "CREATING_CONTAINER_ERROR",
        "EXECUTOR_ERROR",
        "MAX_DURATION_EXCEEDED",
    ]
    # The only column declared as using this enum.
    termination_reason_column = TableReference(
        table_schema="public", table_name="jobs", column_name="termination_reason"
    )
    op.sync_enum_values(
        enum_schema="public",
        enum_name="jobterminationreason",
        new_values=termination_reasons,
        affected_columns=[termination_reason_column],
        enum_values_to_rename=[],
    )

    with op.batch_alter_table("jobs", schema=None) as batch_op:
        batch_op.drop_column("disconnected_at")

    # ### end Alembic commands ###
5 changes: 4 additions & 1 deletion src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,9 @@ class JobModel(BaseModel):
Enum(JobTerminationReason)
)
termination_reason_message: Mapped[Optional[str]] = mapped_column(Text)
# `disconnected_at` stores the time when connectivity issues with the instance
# were first detected. It is reset to None every time connectivity is restored.
disconnected_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
exit_status: Mapped[Optional[int]] = mapped_column(Integer)
job_spec_data: Mapped[str] = mapped_column(Text)
job_provisioning_data: Mapped[Optional[str]] = mapped_column(Text)
Expand All @@ -391,7 +394,7 @@ class JobModel(BaseModel):
remove_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
volumes_detached_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
# `instance_assigned` means instance assignment was done.
# if `instance_assigned` is True and `instance` is None, no instance was assiged.
# if `instance_assigned` is True and `instance` is None, no instance was assigned.
instance_assigned: Mapped[bool] = mapped_column(Boolean, default=False)
instance_id: Mapped[Optional[uuid.UUID]] = mapped_column(
ForeignKey("instances.id", ondelete="CASCADE")
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ async def create_job(
job_num: int = 0,
replica_num: int = 0,
instance_assigned: bool = False,
disconnected_at: Optional[datetime] = None,
) -> JobModel:
run_spec = RunSpec.parse_raw(run.run_spec)
job_spec = (await get_job_specs_from_run_spec(run_spec, replica_num=replica_num))[0]
Expand All @@ -323,6 +324,7 @@ async def create_job(
instance=instance,
instance_assigned=instance_assigned,
used_instance_id=instance.id if instance is not None else None,
disconnected_at=disconnected_at,
)
session.add(job)
await session.commit()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional
from unittest.mock import MagicMock, Mock, patch
Expand Down Expand Up @@ -490,6 +490,17 @@ async def test_pulling_shim_failed(self, test_db, session: AsyncSession):
assert SSHTunnelMock.call_count == 3
await session.refresh(job)
assert job is not None
assert job.disconnected_at is not None
assert job.status == JobStatus.PULLING
with (
patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock,
patch("dstack._internal.server.services.runner.ssh.time.sleep"),
freeze_time(job.disconnected_at + timedelta(minutes=5)),
):
SSHTunnelMock.side_effect = SSHError
await process_running_jobs()
assert SSHTunnelMock.call_count == 3
await session.refresh(job)
assert job.status == JobStatus.TERMINATING
assert job.termination_reason == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY
assert job.remove_at is None
Expand Down
Loading