Skip to content

Commit cb7f33d

Browse files
un-defr4victor
andauthored
Allow multi-node tasks on idle shared instances (#2651)
Fixes: #2650 Co-authored-by: Victor Skvortsov <[email protected]>
1 parent a5b9b93 commit cb7f33d

File tree

4 files changed

+120
-22
lines changed

4 files changed

+120
-22
lines changed

src/dstack/_internal/server/background/tasks/process_submitted_jobs.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -369,16 +369,16 @@ async def _assign_job_to_pool_instance(
369369
(instance, common_utils.get_or_error(get_instance_offer(instance)))
370370
for instance in nonshared_instances
371371
]
372-
if not multinode:
373-
shared_instances_with_offers = get_shared_pool_instances_with_offers(
374-
pool_instances=pool_instances,
375-
profile=profile,
376-
requirements=job.job_spec.requirements,
377-
idle_only=True,
378-
fleet_model=fleet_model,
379-
volumes=volumes,
380-
)
381-
instances_with_offers.extend(shared_instances_with_offers)
372+
shared_instances_with_offers = get_shared_pool_instances_with_offers(
373+
pool_instances=pool_instances,
374+
profile=profile,
375+
requirements=job.job_spec.requirements,
376+
idle_only=True,
377+
fleet_model=fleet_model,
378+
multinode=multinode,
379+
volumes=volumes,
380+
)
381+
instances_with_offers.extend(shared_instances_with_offers)
382382

383383
if len(instances_with_offers) == 0:
384384
return None
@@ -581,7 +581,7 @@ def _create_instance_model_for_job(
581581

582582

583583
def _prepare_job_runtime_data(offer: InstanceOfferWithAvailability) -> JobRuntimeData:
584-
if offer.total_blocks == 1:
584+
if offer.blocks == offer.total_blocks:
585585
if env_utils.get_bool("DSTACK_FORCE_BRIDGE_NETWORK"):
586586
network_mode = NetworkMode.BRIDGE
587587
else:

src/dstack/_internal/server/services/instances.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ def get_shared_pool_instances_with_offers(
235235
*,
236236
idle_only: bool = False,
237237
fleet_model: Optional[FleetModel] = None,
238+
multinode: bool = False,
238239
volumes: Optional[List[List[Volume]]] = None,
239240
) -> list[tuple[InstanceModel, InstanceOfferWithAvailability]]:
240241
instances_with_offers: list[tuple[InstanceModel, InstanceOfferWithAvailability]] = []
@@ -243,19 +244,22 @@ def get_shared_pool_instances_with_offers(
243244
pool_instances=pool_instances,
244245
profile=profile,
245246
fleet_model=fleet_model,
246-
multinode=False,
247+
multinode=multinode,
247248
volumes=volumes,
248249
shared=True,
249250
)
250251
for instance in filtered_instances:
251252
if idle_only and instance.status not in [InstanceStatus.IDLE, InstanceStatus.BUSY]:
252253
continue
254+
if multinode and instance.busy_blocks > 0:
255+
continue
253256
offer = get_instance_offer(instance)
254257
if offer is None:
255258
continue
256259
total_blocks = common_utils.get_or_error(instance.total_blocks)
257260
idle_blocks = total_blocks - instance.busy_blocks
258-
for blocks in range(1, total_blocks + 1):
261+
min_blocks = total_blocks if multinode else 1
262+
for blocks in range(min_blocks, total_blocks + 1):
259263
shared_offer = generate_shared_offer(offer, blocks, total_blocks)
260264
catalog_item = offer_to_catalog_item(shared_offer)
261265
if gpuhunt.matches(catalog_item, query_filter):

src/dstack/_internal/server/services/runs.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -727,15 +727,15 @@ async def _get_pool_offers(
727727
pool_instances = [i for i in pool_instances if i.id not in detaching_instances_ids]
728728
multinode = job.job_spec.jobs_per_replica > 1
729729

730-
if not multinode:
731-
shared_instances_with_offers = get_shared_pool_instances_with_offers(
732-
pool_instances=pool_instances,
733-
profile=run_spec.merged_profile,
734-
requirements=job.job_spec.requirements,
735-
volumes=volumes,
736-
)
737-
for _, offer in shared_instances_with_offers:
738-
pool_offers.append(offer)
730+
shared_instances_with_offers = get_shared_pool_instances_with_offers(
731+
pool_instances=pool_instances,
732+
profile=run_spec.merged_profile,
733+
requirements=job.job_spec.requirements,
734+
volumes=volumes,
735+
multinode=multinode,
736+
)
737+
for _, offer in shared_instances_with_offers:
738+
pool_offers.append(offer)
739739

740740
nonshared_instances = filter_pool_instances(
741741
pool_instances=pool_instances,

src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sqlalchemy.orm import joinedload
88

99
from dstack._internal.core.models.backends.base import BackendType
10+
from dstack._internal.core.models.configurations import TaskConfiguration
1011
from dstack._internal.core.models.instances import (
1112
InstanceAvailability,
1213
InstanceOfferWithAvailability,
@@ -536,6 +537,99 @@ async def test_assigns_job_to_shared_instance(self, test_db, session: AsyncSessi
536537
assert instance.total_blocks == 4
537538
assert instance.busy_blocks == 2
538539

540+
@pytest.mark.asyncio
541+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
542+
async def test_assigns_multi_node_job_to_shared_instance(self, test_db, session: AsyncSession):
543+
project = await create_project(session)
544+
user = await create_user(session)
545+
repo = await create_repo(
546+
session=session,
547+
project_id=project.id,
548+
)
549+
offer = get_instance_offer_with_availability(gpu_count=8, cpu_count=64, memory_gib=128)
550+
instance = await create_instance(
551+
session=session,
552+
project=project,
553+
status=InstanceStatus.IDLE,
554+
backend=BackendType.AWS,
555+
offer=offer,
556+
total_blocks=4,
557+
busy_blocks=0,
558+
)
559+
configuration = TaskConfiguration(image="debian", nodes=2)
560+
run_spec = get_run_spec(run_name="run", repo_id=repo.name, configuration=configuration)
561+
run = await create_run(
562+
session=session,
563+
run_name="run",
564+
project=project,
565+
repo=repo,
566+
user=user,
567+
run_spec=run_spec,
568+
)
569+
job = await create_job(
570+
session=session,
571+
run=run,
572+
instance_assigned=False,
573+
)
574+
await process_submitted_jobs()
575+
await session.refresh(job)
576+
await session.refresh(instance)
577+
res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
578+
job = res.unique().scalar_one()
579+
assert job.status == JobStatus.SUBMITTED
580+
assert job.instance_assigned
581+
assert job.instance is not None
582+
assert job.instance.id == instance.id
583+
assert instance.total_blocks == 4
584+
assert instance.busy_blocks == 4
585+
586+
@pytest.mark.asyncio
587+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
588+
async def test_cannot_assign_multi_node_job_to_partially_busy_shared_instance(
589+
self, test_db, session: AsyncSession
590+
):
591+
project = await create_project(session)
592+
user = await create_user(session)
593+
repo = await create_repo(
594+
session=session,
595+
project_id=project.id,
596+
)
597+
offer = get_instance_offer_with_availability(gpu_count=8, cpu_count=64, memory_gib=128)
598+
instance = await create_instance(
599+
session=session,
600+
project=project,
601+
status=InstanceStatus.IDLE,
602+
backend=BackendType.AWS,
603+
offer=offer,
604+
total_blocks=4,
605+
busy_blocks=1,
606+
)
607+
configuration = TaskConfiguration(image="debian", nodes=2)
608+
run_spec = get_run_spec(run_name="run", repo_id=repo.name, configuration=configuration)
609+
run = await create_run(
610+
session=session,
611+
run_name="run",
612+
project=project,
613+
repo=repo,
614+
user=user,
615+
run_spec=run_spec,
616+
)
617+
job = await create_job(
618+
session=session,
619+
run=run,
620+
instance_assigned=False,
621+
)
622+
await process_submitted_jobs()
623+
await session.refresh(job)
624+
await session.refresh(instance)
625+
res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
626+
job = res.unique().scalar_one()
627+
assert job.status == JobStatus.SUBMITTED
628+
assert job.instance_assigned
629+
assert job.instance is None
630+
assert instance.total_blocks == 4
631+
assert instance.busy_blocks == 1
632+
539633
@pytest.mark.asyncio
540634
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
541635
async def test_assigns_job_to_specific_fleet(self, test_db, session: AsyncSession):

0 commit comments

Comments
 (0)