From 3158183f2a658e5462600de172f58855b5af0b12 Mon Sep 17 00:00:00 2001
From: Dmitry Meyer
Date: Fri, 16 May 2025 14:38:30 +0000
Subject: [PATCH] Allow multi-node tasks on idle shared instances

Fixes: https://github.com/dstackai/dstack/issues/2650
---
 .../tasks/process_submitted_jobs.py                | 22 ++---
 src/dstack/_internal/server/services/instances.py  |  8 +-
 src/dstack/_internal/server/services/runs.py       | 18 ++--
 .../tasks/test_process_submitted_jobs.py           | 94 +++++++++++++++++++
 4 files changed, 120 insertions(+), 22 deletions(-)

diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
index efa8600a4..048a93818 100644
--- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
+++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
@@ -360,16 +360,16 @@ async def _assign_job_to_pool_instance(
         (instance, common_utils.get_or_error(get_instance_offer(instance)))
         for instance in nonshared_instances
     ]
-    if not multinode:
-        shared_instances_with_offers = get_shared_pool_instances_with_offers(
-            pool_instances=pool_instances,
-            profile=profile,
-            requirements=job.job_spec.requirements,
-            idle_only=True,
-            fleet_model=fleet_model,
-            volumes=volumes,
-        )
-        instances_with_offers.extend(shared_instances_with_offers)
+    shared_instances_with_offers = get_shared_pool_instances_with_offers(
+        pool_instances=pool_instances,
+        profile=profile,
+        requirements=job.job_spec.requirements,
+        idle_only=True,
+        fleet_model=fleet_model,
+        multinode=multinode,
+        volumes=volumes,
+    )
+    instances_with_offers.extend(shared_instances_with_offers)
 
     if len(instances_with_offers) == 0:
         return None
@@ -572,7 +572,7 @@ def _create_instance_model_for_job(
 
 
 def _prepare_job_runtime_data(offer: InstanceOfferWithAvailability) -> JobRuntimeData:
-    if offer.total_blocks == 1:
+    if offer.blocks == offer.total_blocks:
         if env_utils.get_bool("DSTACK_FORCE_BRIDGE_NETWORK"):
             network_mode = NetworkMode.BRIDGE
         else:
diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py
index 9d0cb5299..53ca55396 100644
--- a/src/dstack/_internal/server/services/instances.py
+++ b/src/dstack/_internal/server/services/instances.py
@@ -235,6 +235,7 @@ def get_shared_pool_instances_with_offers(
     *,
     idle_only: bool = False,
     fleet_model: Optional[FleetModel] = None,
+    multinode: bool = False,
     volumes: Optional[List[List[Volume]]] = None,
 ) -> list[tuple[InstanceModel, InstanceOfferWithAvailability]]:
     instances_with_offers: list[tuple[InstanceModel, InstanceOfferWithAvailability]] = []
@@ -243,19 +244,22 @@ def get_shared_pool_instances_with_offers(
         pool_instances=pool_instances,
         profile=profile,
         fleet_model=fleet_model,
-        multinode=False,
+        multinode=multinode,
         volumes=volumes,
         shared=True,
     )
     for instance in filtered_instances:
         if idle_only and instance.status not in [InstanceStatus.IDLE, InstanceStatus.BUSY]:
             continue
+        if multinode and instance.busy_blocks > 0:
+            continue
         offer = get_instance_offer(instance)
         if offer is None:
             continue
         total_blocks = common_utils.get_or_error(instance.total_blocks)
         idle_blocks = total_blocks - instance.busy_blocks
-        for blocks in range(1, total_blocks + 1):
+        min_blocks = total_blocks if multinode else 1
+        for blocks in range(min_blocks, total_blocks + 1):
             shared_offer = generate_shared_offer(offer, blocks, total_blocks)
             catalog_item = offer_to_catalog_item(shared_offer)
             if gpuhunt.matches(catalog_item, query_filter):
diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py
index 806606331..8a6adf208 100644
--- a/src/dstack/_internal/server/services/runs.py
+++ b/src/dstack/_internal/server/services/runs.py
@@ -721,15 +721,15 @@ async def _get_pool_offers(
     pool_instances = [i for i in pool_instances if i.id not in detaching_instances_ids]
 
     multinode = job.job_spec.jobs_per_replica > 1
-    if not multinode:
-        shared_instances_with_offers = get_shared_pool_instances_with_offers(
-            pool_instances=pool_instances,
-            profile=run_spec.merged_profile,
-            requirements=job.job_spec.requirements,
-            volumes=volumes,
-        )
-        for _, offer in shared_instances_with_offers:
-            pool_offers.append(offer)
+    shared_instances_with_offers = get_shared_pool_instances_with_offers(
+        pool_instances=pool_instances,
+        profile=run_spec.merged_profile,
+        requirements=job.job_spec.requirements,
+        volumes=volumes,
+        multinode=multinode,
+    )
+    for _, offer in shared_instances_with_offers:
+        pool_offers.append(offer)
 
     nonshared_instances = filter_pool_instances(
         pool_instances=pool_instances,
diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py
index 24cb19daa..b065a41b3 100644
--- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py
+++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py
@@ -7,6 +7,7 @@
 from sqlalchemy.orm import joinedload
 
 from dstack._internal.core.models.backends.base import BackendType
+from dstack._internal.core.models.configurations import TaskConfiguration
 from dstack._internal.core.models.instances import (
     InstanceAvailability,
     InstanceOfferWithAvailability,
@@ -536,6 +537,99 @@ async def test_assigns_job_to_shared_instance(self, test_db, session: AsyncSession):
         assert instance.total_blocks == 4
         assert instance.busy_blocks == 2
 
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+    async def test_assigns_multi_node_job_to_shared_instance(self, test_db, session: AsyncSession):
+        project = await create_project(session)
+        user = await create_user(session)
+        repo = await create_repo(
+            session=session,
+            project_id=project.id,
+        )
+        offer = get_instance_offer_with_availability(gpu_count=8, cpu_count=64, memory_gib=128)
+        instance = await create_instance(
+            session=session,
+            project=project,
+            status=InstanceStatus.IDLE,
+            backend=BackendType.AWS,
+            offer=offer,
+            total_blocks=4,
+            busy_blocks=0,
+        )
+        configuration = TaskConfiguration(image="debian", nodes=2)
+        run_spec = get_run_spec(run_name="run", repo_id=repo.name, configuration=configuration)
+        run = await create_run(
+            session=session,
+            run_name="run",
+            project=project,
+            repo=repo,
+            user=user,
+            run_spec=run_spec,
+        )
+        job = await create_job(
+            session=session,
+            run=run,
+            instance_assigned=False,
+        )
+        await process_submitted_jobs()
+        await session.refresh(job)
+        await session.refresh(instance)
+        res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
+        job = res.unique().scalar_one()
+        assert job.status == JobStatus.SUBMITTED
+        assert job.instance_assigned
+        assert job.instance is not None
+        assert job.instance.id == instance.id
+        assert instance.total_blocks == 4
+        assert instance.busy_blocks == 4
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+    async def test_cannot_assign_multi_node_job_to_partially_busy_shared_instance(
+        self, test_db, session: AsyncSession
+    ):
+        project = await create_project(session)
+        user = await create_user(session)
+        repo = await create_repo(
+            session=session,
+            project_id=project.id,
+        )
+        offer = get_instance_offer_with_availability(gpu_count=8, cpu_count=64, memory_gib=128)
+        instance = await create_instance(
+            session=session,
+            project=project,
+            status=InstanceStatus.IDLE,
+            backend=BackendType.AWS,
+            offer=offer,
+            total_blocks=4,
+            busy_blocks=1,
+        )
+        configuration = TaskConfiguration(image="debian", nodes=2)
+        run_spec = get_run_spec(run_name="run", repo_id=repo.name, configuration=configuration)
+        run = await create_run(
+            session=session,
+            run_name="run",
+            project=project,
+            repo=repo,
+            user=user,
+            run_spec=run_spec,
+        )
+        job = await create_job(
+            session=session,
+            run=run,
+            instance_assigned=False,
+        )
+        await process_submitted_jobs()
+        await session.refresh(job)
+        await session.refresh(instance)
+        res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
+        job = res.unique().scalar_one()
+        assert job.status == JobStatus.SUBMITTED
+        assert job.instance_assigned
+        assert job.instance is None
+        assert instance.total_blocks == 4
+        assert instance.busy_blocks == 1
+
     @pytest.mark.asyncio
     @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
     async def test_assigns_job_to_specific_fleet(self, test_db, session: AsyncSession):
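
For context, a minimal sketch of the block-selection rule this patch introduces. The
helper below is hypothetical (not part of the dstack codebase) and is simplified: it
omits how generated offers are marked available versus busy.

    # Hypothetical illustration of the new rule: a multi-node job takes a shared
    # instance only as a whole host, and only when no block is busy; a single-node
    # job can still take any block count up to the currently idle capacity.
    def candidate_block_counts(total_blocks: int, busy_blocks: int, multinode: bool) -> list[int]:
        if multinode:
            if busy_blocks > 0:
                return []  # partially busy hosts are skipped entirely
            return [total_blocks]  # only the all-blocks offer is generated
        idle_blocks = total_blocks - busy_blocks
        return [b for b in range(1, total_blocks + 1) if b <= idle_blocks]

    assert candidate_block_counts(4, 0, multinode=True) == [4]
    assert candidate_block_counts(4, 1, multinode=True) == []
    assert candidate_block_counts(4, 1, multinode=False) == [1, 2, 3]

This mirrors the `min_blocks = total_blocks if multinode else 1` lower bound and the
`busy_blocks > 0` skip added to get_shared_pool_instances_with_offers above.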