Skip to content

Allow multi-node tasks on idle instances with blocks #2651

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -369,16 +369,16 @@ async def _assign_job_to_pool_instance(
(instance, common_utils.get_or_error(get_instance_offer(instance)))
for instance in nonshared_instances
]
if not multinode:
shared_instances_with_offers = get_shared_pool_instances_with_offers(
pool_instances=pool_instances,
profile=profile,
requirements=job.job_spec.requirements,
idle_only=True,
fleet_model=fleet_model,
volumes=volumes,
)
instances_with_offers.extend(shared_instances_with_offers)
shared_instances_with_offers = get_shared_pool_instances_with_offers(
pool_instances=pool_instances,
profile=profile,
requirements=job.job_spec.requirements,
idle_only=True,
fleet_model=fleet_model,
multinode=multinode,
volumes=volumes,
)
instances_with_offers.extend(shared_instances_with_offers)

if len(instances_with_offers) == 0:
return None
Expand Down Expand Up @@ -581,7 +581,7 @@ def _create_instance_model_for_job(


def _prepare_job_runtime_data(offer: InstanceOfferWithAvailability) -> JobRuntimeData:
if offer.total_blocks == 1:
if offer.blocks == offer.total_blocks:
if env_utils.get_bool("DSTACK_FORCE_BRIDGE_NETWORK"):
network_mode = NetworkMode.BRIDGE
else:
Expand Down
8 changes: 6 additions & 2 deletions src/dstack/_internal/server/services/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def get_shared_pool_instances_with_offers(
*,
idle_only: bool = False,
fleet_model: Optional[FleetModel] = None,
multinode: bool = False,
volumes: Optional[List[List[Volume]]] = None,
) -> list[tuple[InstanceModel, InstanceOfferWithAvailability]]:
instances_with_offers: list[tuple[InstanceModel, InstanceOfferWithAvailability]] = []
Expand All @@ -243,19 +244,22 @@ def get_shared_pool_instances_with_offers(
pool_instances=pool_instances,
profile=profile,
fleet_model=fleet_model,
multinode=False,
multinode=multinode,
volumes=volumes,
shared=True,
)
for instance in filtered_instances:
if idle_only and instance.status not in [InstanceStatus.IDLE, InstanceStatus.BUSY]:
continue
if multinode and instance.busy_blocks > 0:
continue
offer = get_instance_offer(instance)
if offer is None:
continue
total_blocks = common_utils.get_or_error(instance.total_blocks)
idle_blocks = total_blocks - instance.busy_blocks
for blocks in range(1, total_blocks + 1):
min_blocks = total_blocks if multinode else 1
for blocks in range(min_blocks, total_blocks + 1):
shared_offer = generate_shared_offer(offer, blocks, total_blocks)
catalog_item = offer_to_catalog_item(shared_offer)
if gpuhunt.matches(catalog_item, query_filter):
Expand Down
18 changes: 9 additions & 9 deletions src/dstack/_internal/server/services/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,15 +727,15 @@ async def _get_pool_offers(
pool_instances = [i for i in pool_instances if i.id not in detaching_instances_ids]
multinode = job.job_spec.jobs_per_replica > 1

if not multinode:
shared_instances_with_offers = get_shared_pool_instances_with_offers(
pool_instances=pool_instances,
profile=run_spec.merged_profile,
requirements=job.job_spec.requirements,
volumes=volumes,
)
for _, offer in shared_instances_with_offers:
pool_offers.append(offer)
shared_instances_with_offers = get_shared_pool_instances_with_offers(
pool_instances=pool_instances,
profile=run_spec.merged_profile,
requirements=job.job_spec.requirements,
volumes=volumes,
multinode=multinode,
)
for _, offer in shared_instances_with_offers:
pool_offers.append(offer)

nonshared_instances = filter_pool_instances(
pool_instances=pool_instances,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sqlalchemy.orm import joinedload

from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.configurations import TaskConfiguration
from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceOfferWithAvailability,
Expand Down Expand Up @@ -536,6 +537,99 @@ async def test_assigns_job_to_shared_instance(self, test_db, session: AsyncSessi
assert instance.total_blocks == 4
assert instance.busy_blocks == 2

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_assigns_multi_node_job_to_shared_instance(self, test_db, session: AsyncSession):
    """A multi-node task can be assigned to a fully idle shared instance, claiming all of its blocks."""
    proj = await create_project(session)
    owner = await create_user(session)
    repository = await create_repo(
        session=session,
        project_id=proj.id,
    )
    # A fully idle shared instance: 4 blocks total, none busy.
    instance_offer = get_instance_offer_with_availability(gpu_count=8, cpu_count=64, memory_gib=128)
    shared_instance = await create_instance(
        session=session,
        project=proj,
        status=InstanceStatus.IDLE,
        backend=BackendType.AWS,
        offer=instance_offer,
        total_blocks=4,
        busy_blocks=0,
    )
    # nodes=2 makes this a multi-node task.
    task_conf = TaskConfiguration(image="debian", nodes=2)
    spec = get_run_spec(run_name="run", repo_id=repository.name, configuration=task_conf)
    run_model = await create_run(
        session=session,
        run_name="run",
        project=proj,
        repo=repository,
        user=owner,
        run_spec=spec,
    )
    submitted_job = await create_job(
        session=session,
        run=run_model,
        instance_assigned=False,
    )
    await process_submitted_jobs()
    await session.refresh(submitted_job)
    await session.refresh(shared_instance)
    res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
    reloaded_job = res.unique().scalar_one()
    # The job must be assigned to the idle shared instance and occupy every block.
    assert reloaded_job.status == JobStatus.SUBMITTED
    assert reloaded_job.instance_assigned
    assert reloaded_job.instance is not None
    assert reloaded_job.instance.id == shared_instance.id
    assert shared_instance.total_blocks == 4
    assert shared_instance.busy_blocks == 4

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_cannot_assign_multi_node_job_to_partially_busy_shared_instance(
    self, test_db, session: AsyncSession
):
    """A multi-node task must not land on a shared instance that already has busy blocks."""
    proj = await create_project(session)
    owner = await create_user(session)
    repository = await create_repo(
        session=session,
        project_id=proj.id,
    )
    # Partially occupied shared instance: 1 of 4 blocks already busy.
    instance_offer = get_instance_offer_with_availability(gpu_count=8, cpu_count=64, memory_gib=128)
    shared_instance = await create_instance(
        session=session,
        project=proj,
        status=InstanceStatus.IDLE,
        backend=BackendType.AWS,
        offer=instance_offer,
        total_blocks=4,
        busy_blocks=1,
    )
    # nodes=2 makes this a multi-node task.
    task_conf = TaskConfiguration(image="debian", nodes=2)
    spec = get_run_spec(run_name="run", repo_id=repository.name, configuration=task_conf)
    run_model = await create_run(
        session=session,
        run_name="run",
        project=proj,
        repo=repository,
        user=owner,
        run_spec=spec,
    )
    submitted_job = await create_job(
        session=session,
        run=run_model,
        instance_assigned=False,
    )
    await process_submitted_jobs()
    await session.refresh(submitted_job)
    await session.refresh(shared_instance)
    res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
    reloaded_job = res.unique().scalar_one()
    # Assignment was attempted but no instance matched; occupancy is unchanged.
    assert reloaded_job.status == JobStatus.SUBMITTED
    assert reloaded_job.instance_assigned
    assert reloaded_job.instance is None
    assert shared_instance.total_blocks == 4
    assert shared_instance.busy_blocks == 1

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_assigns_job_to_specific_fleet(self, test_db, session: AsyncSession):
Expand Down
Loading