Skip to content

Commit fe3a6d2

Browse files
authored
Implement run priorities (#2635)
* Implement run priorities * Test runs priority * Document priorities for tasks * Exclude priority for backward compatibility
1 parent b4325b6 commit fe3a6d2

File tree

10 files changed

+178
-12
lines changed

10 files changed

+178
-12
lines changed

docs/docs/concepts/tasks.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,34 @@ retry:
426426
If one job of a multi-node task fails with retry enabled,
427427
`dstack` will stop all the jobs and resubmit the run.
428428

429+
### Priority
430+
431+
Be default, submitted runs are scheduled in the order they were submitted.
432+
When compute resources are limited, you may want to prioritize some runs over others.
433+
This can be done by specifying the [`priority`](../reference/dstack.yml/task.md) property in the run configuration:
434+
435+
<div editor-title=".dstack.yml">
436+
437+
```yaml
438+
type: task
439+
name: train
440+
441+
python: "3.10"
442+
443+
# Commands of the task
444+
commands:
445+
- pip install -r fine-tuning/qlora/requirements.txt
446+
- python fine-tuning/qlora/train.py
447+
448+
priority: 50
449+
```
450+
451+
</div>
452+
453+
`dstack` tries to provision runs with higher priority first.
454+
Note that if a high priority run cannot be scheduled,
455+
it does not block other runs with lower priority from scheduling.
456+
429457
--8<-- "docs/concepts/snippets/manage-fleets.ext"
430458

431459
--8<-- "docs/concepts/snippets/manage-runs.ext"

src/dstack/_internal/core/models/configurations.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
MAX_INT64 = 2**63 - 1
2424
SERVICE_HTTPS_DEFAULT = True
2525
STRIP_PREFIX_DEFAULT = True
26+
RUN_PRIOTIRY_MIN = 0
27+
RUN_PRIOTIRY_MAX = 100
28+
RUN_PRIORITY_DEFAULT = 0
2629

2730

2831
class RunConfigurationType(str, Enum):
@@ -221,14 +224,26 @@ class BaseRunConfiguration(CoreModel):
221224
)
222225
),
223226
] = None
224-
# deprecated since 0.18.31; task, service -- no effect; dev-environment -- executed right before `init`
225-
setup: CommandsList = []
226227
resources: Annotated[
227228
ResourcesSpec, Field(description="The resources requirements to run the configuration")
228229
] = ResourcesSpec()
230+
priority: Annotated[
231+
Optional[int],
232+
Field(
233+
ge=RUN_PRIOTIRY_MIN,
234+
le=RUN_PRIOTIRY_MAX,
235+
description=(
236+
f"The priority of the run, an integer between `{RUN_PRIOTIRY_MIN}` and `{RUN_PRIOTIRY_MAX}`."
237+
" `dstack` tries to provision runs with higher priority first."
238+
f" Defaults to `{RUN_PRIORITY_DEFAULT}`"
239+
),
240+
),
241+
] = None
229242
volumes: Annotated[
230243
List[Union[MountPoint, str]], Field(description="The volumes mount points")
231244
] = []
245+
# deprecated since 0.18.31; task, service -- no effect; dev-environment -- executed right before `init`
246+
setup: CommandsList = []
232247

233248
@validator("python", pre=True, always=True)
234249
def convert_python(cls, v, values) -> Optional[PythonVersion]:

src/dstack/_internal/server/background/tasks/process_submitted_jobs.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,20 @@ async def _process_next_submitted_job():
9393
async with lock:
9494
res = await session.execute(
9595
select(JobModel)
96+
.join(JobModel.run)
9697
.where(
9798
JobModel.status == JobStatus.SUBMITTED,
9899
JobModel.id.not_in(lockset),
99100
)
100-
.order_by(JobModel.last_processed_at.asc())
101+
# Jobs are process in FIFO sorted by priority globally,
102+
# thus runs from different project can "overtake" each other by using higher priorities.
103+
# That's not a big problem as long as projects do not compete for the same compute resources.
104+
# Jobs with lower priorities from other projects will be processed without major lag
105+
# as long as new higher priority runs are not constantly submitted.
106+
# TODO: Consider processing jobs from different projects fairly/round-robin
107+
# Fully fair processing can be tricky to implement via the current DB queue as
108+
# there can be many projects and we are limited by the max DB connections.
109+
.order_by(RunModel.priority.desc(), JobModel.last_processed_at.asc())
101110
.limit(1)
102111
.with_for_update(skip_locked=True)
103112
)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Add RunModel.priority
2+
3+
Revision ID: bca2fdf130bf
4+
Revises: 20166748b60c
5+
Create Date: 2025-05-14 15:24:21.269775
6+
7+
"""
8+
9+
import sqlalchemy as sa
10+
from alembic import op
11+
12+
# revision identifiers, used by Alembic.
13+
revision = "bca2fdf130bf"
14+
down_revision = "20166748b60c"
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade() -> None:
20+
# ### commands auto generated by Alembic - please adjust! ###
21+
with op.batch_alter_table("runs", schema=None) as batch_op:
22+
batch_op.add_column(sa.Column("priority", sa.Integer(), nullable=True))
23+
batch_op.execute("UPDATE runs SET priority = 0")
24+
with op.batch_alter_table("runs", schema=None) as batch_op:
25+
batch_op.alter_column("priority", nullable=False)
26+
# ### end Alembic commands ###
27+
28+
29+
def downgrade() -> None:
30+
# ### commands auto generated by Alembic - please adjust! ###
31+
with op.batch_alter_table("runs", schema=None) as batch_op:
32+
batch_op.drop_column("priority")
33+
34+
# ### end Alembic commands ###

src/dstack/_internal/server/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,7 @@ class RunModel(BaseModel):
348348
resubmission_attempt: Mapped[int] = mapped_column(Integer, default=0)
349349
run_spec: Mapped[str] = mapped_column(Text)
350350
service_spec: Mapped[Optional[str]] = mapped_column(Text)
351+
priority: Mapped[int] = mapped_column(Integer, default=0)
351352

352353
jobs: Mapped[List["JobModel"]] = relationship(
353354
back_populates="run", lazy="selectin", order_by="[JobModel.replica_num, JobModel.job_num]"

src/dstack/_internal/server/services/runs.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
ServerClientError,
1717
)
1818
from dstack._internal.core.models.common import ApplyAction
19-
from dstack._internal.core.models.configurations import AnyRunConfiguration
19+
from dstack._internal.core.models.configurations import RUN_PRIORITY_DEFAULT, AnyRunConfiguration
2020
from dstack._internal.core.models.instances import (
2121
InstanceAvailability,
2222
InstanceOfferWithAvailability,
@@ -434,7 +434,12 @@ async def apply_plan(
434434
# FIXME: potentially long write transaction
435435
# Avoid getting run_model after update
436436
await session.execute(
437-
update(RunModel).where(RunModel.id == current_resource.id).values(run_spec=run_spec.json())
437+
update(RunModel)
438+
.where(RunModel.id == current_resource.id)
439+
.values(
440+
run_spec=run_spec.json(),
441+
priority=run_spec.configuration.priority,
442+
)
438443
)
439444
run = await get_run_by_name(
440445
session=session,
@@ -495,6 +500,7 @@ async def submit_run(
495500
status=RunStatus.SUBMITTED,
496501
run_spec=run_spec.json(),
497502
last_processed_at=submitted_at,
503+
priority=run_spec.configuration.priority,
498504
)
499505
session.add(run_model)
500506

@@ -852,6 +858,13 @@ def _get_job_submission_cost(job_submission: JobSubmission) -> float:
852858

853859

854860
def _validate_run_spec_and_set_defaults(run_spec: RunSpec):
861+
# This function may set defaults for null run_spec values,
862+
# although most defaults are resolved when building job_spec
863+
# so that we can keep both the original user-supplied value (null in run_spec)
864+
# and the default in job_spec.
865+
# If a property is stored in job_spec - resolve the default there.
866+
# Server defaults are preferable over client defaults so that
867+
# the defaults depend on the server version, not the client version.
855868
if run_spec.run_name is not None:
856869
validate_dstack_resource_name(run_spec.run_name)
857870
for mount_point in run_spec.configuration.volumes:
@@ -875,11 +888,14 @@ def _validate_run_spec_and_set_defaults(run_spec: RunSpec):
875888
raise ServerClientError(
876889
f"Maximum utilization_policy.time_window is {settings.SERVER_METRICS_RUNNING_TTL_SECONDS}s"
877890
)
891+
if run_spec.configuration.priority is None:
892+
run_spec.configuration.priority = RUN_PRIORITY_DEFAULT
878893
set_resources_defaults(run_spec.configuration.resources)
879894

880895

881896
_UPDATABLE_SPEC_FIELDS = ["repo_code_hash", "configuration"]
882-
_CONF_TYPE_TO_UPDATABLE_FIELDS = {
897+
_CONF_UPDATABLE_FIELDS = ["priority"]
898+
_TYPE_SPECIFIC_CONF_UPDATABLE_FIELDS = {
883899
"dev-environment": ["inactivity_duration"],
884900
# Most service fields can be updated via replica redeployment.
885901
# TODO: Allow updating other fields when rolling deployment is supported.
@@ -915,12 +931,9 @@ def _check_can_update_configuration(
915931
raise ServerClientError(
916932
f"Configuration type changed from {current.type} to {new.type}, cannot update"
917933
)
918-
updatable_fields = _CONF_TYPE_TO_UPDATABLE_FIELDS.get(new.type)
919-
if updatable_fields is None:
920-
raise ServerClientError(
921-
f"Can only update {', '.join(_CONF_TYPE_TO_UPDATABLE_FIELDS)} configurations."
922-
f" Not {new.type}"
923-
)
934+
updatable_fields = _CONF_UPDATABLE_FIELDS + _TYPE_SPECIFIC_CONF_UPDATABLE_FIELDS.get(
935+
new.type, []
936+
)
924937
diff = diff_models(current, new)
925938
changed_fields = list(diff.keys())
926939
for key in changed_fields:

src/dstack/_internal/server/testing/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ async def create_run(
262262
run_spec: Optional[RunSpec] = None,
263263
run_id: Optional[UUID] = None,
264264
deleted: bool = False,
265+
priority: int = 0,
265266
) -> RunModel:
266267
if run_spec is None:
267268
run_spec = get_run_spec(
@@ -282,6 +283,7 @@ async def create_run(
282283
run_spec=run_spec.json(),
283284
last_processed_at=submitted_at,
284285
jobs=[],
286+
priority=priority,
285287
)
286288
session.add(run)
287289
await session.commit()

src/dstack/api/server/_runs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ def _get_run_spec_excludes(run_spec: RunSpec) -> Optional[Dict]:
186186
configuration_excludes["rate_limits"] = True
187187
if configuration.shell is None:
188188
configuration_excludes["shell"] = True
189+
if configuration.priority is None:
190+
configuration_excludes["priority"] = True
189191

190192
if configuration_excludes:
191193
spec_excludes["configuration"] = configuration_excludes

src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,3 +634,63 @@ async def test_creates_new_instance_in_existing_fleet(self, test_db, session: As
634634
assert job.instance is not None
635635
assert job.instance.instance_num == 1
636636
assert job.instance.fleet_id == fleet.id
637+
638+
@pytest.mark.asyncio
639+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
640+
async def test_picks_high_priority_jobs_first(self, test_db, session: AsyncSession):
641+
project = await create_project(session)
642+
user = await create_user(session)
643+
repo = await create_repo(
644+
session=session,
645+
project_id=project.id,
646+
)
647+
instance = await create_instance(
648+
session=session,
649+
project=project,
650+
status=InstanceStatus.IDLE,
651+
)
652+
run1 = await create_run(
653+
session=session,
654+
project=project,
655+
repo=repo,
656+
user=user,
657+
priority=10,
658+
)
659+
job1 = await create_job(
660+
session=session,
661+
run=run1,
662+
instance_assigned=True,
663+
instance=instance,
664+
)
665+
run2 = await create_run(
666+
session=session,
667+
project=project,
668+
repo=repo,
669+
user=user,
670+
priority=0,
671+
)
672+
job2 = await create_job(
673+
session=session, run=run2, instance_assigned=True, instance=instance
674+
)
675+
run3 = await create_run(
676+
session=session,
677+
project=project,
678+
repo=repo,
679+
user=user,
680+
priority=100,
681+
)
682+
job3 = await create_job(
683+
session=session,
684+
run=run3,
685+
instance_assigned=True,
686+
instance=instance,
687+
)
688+
await process_submitted_jobs()
689+
await session.refresh(job3)
690+
assert job3.status == JobStatus.PROVISIONING
691+
await process_submitted_jobs()
692+
await session.refresh(job1)
693+
assert job1.status == JobStatus.PROVISIONING
694+
await process_submitted_jobs()
695+
await session.refresh(job2)
696+
assert job2.status == JobStatus.PROVISIONING

src/tests/_internal/server/routers/test_runs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def get_dev_env_run_plan_dict(
124124
"reservation": None,
125125
"fleets": None,
126126
"tags": None,
127+
"priority": 0,
127128
},
128129
"configuration_path": "dstack.yaml",
129130
"profile": {
@@ -284,6 +285,7 @@ def get_dev_env_run_dict(
284285
"reservation": None,
285286
"fleets": None,
286287
"tags": None,
288+
"priority": 0,
287289
},
288290
"configuration_path": "dstack.yaml",
289291
"profile": {

0 commit comments

Comments
 (0)