diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index e975f684e..e216a6aef 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -1,10 +1,12 @@ import concurrent.futures import json +import threading from collections import defaultdict from typing import Callable, Dict, List, Literal, Optional, Tuple import google.api_core.exceptions import google.cloud.compute_v1 as compute_v1 +from cachetools import TTLCache, cachedmethod from google.cloud import tpu_v2 from gpuhunt import KNOWN_TPUS @@ -98,6 +100,8 @@ def __init__(self, config: GCPConfig): self.resource_policies_client = compute_v1.ResourcePoliciesClient( credentials=self.credentials ) + self._extra_subnets_cache_lock = threading.Lock() + self._extra_subnets_cache = TTLCache(maxsize=30, ttl=60) def get_offers( self, requirements: Optional[Requirements] = None @@ -193,9 +197,7 @@ def create_instance( config=self.config, region=instance_offer.region, ) - extra_subnets = _get_extra_subnets( - subnetworks_client=self.subnetworks_client, - config=self.config, + extra_subnets = self._get_extra_subnets( region=instance_offer.region, instance_type_name=instance_offer.instance.name, ) @@ -769,6 +771,38 @@ def detach_volume(self, volume: Volume, instance_id: str, force: bool = False): instance_id, ) + @cachedmethod( + cache=lambda self: self._extra_subnets_cache, + lock=lambda self: self._extra_subnets_cache_lock, + ) + def _get_extra_subnets( + self, + region: str, + instance_type_name: str, + ) -> List[Tuple[str, str]]: + if self.config.extra_vpcs is None: + return [] + if instance_type_name == "a3-megagpu-8g": + subnets_num = 8 + elif instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]: + subnets_num = 4 + else: + return [] + extra_subnets = [] + for vpc_name in self.config.extra_vpcs[:subnets_num]: + subnet = gcp_resources.get_vpc_subnet_or_error( + subnetworks_client=self.subnetworks_client, + vpc_project_id=self.config.vpc_project_id or self.config.project_id, + vpc_name=vpc_name, + region=region, + ) + vpc_resource_name = gcp_resources.vpc_name_to_vpc_resource_name( + project_id=self.config.vpc_project_id or self.config.project_id, + vpc_name=vpc_name, + ) + extra_subnets.append((vpc_resource_name, subnet)) + return extra_subnets + def _supported_instances_and_zones( regions: List[str], @@ -843,36 +877,6 @@ def _get_vpc_subnet( ) -def _get_extra_subnets( - subnetworks_client: compute_v1.SubnetworksClient, - config: GCPConfig, - region: str, - instance_type_name: str, -) -> List[Tuple[str, str]]: - if config.extra_vpcs is None: - return [] - if instance_type_name == "a3-megagpu-8g": - subnets_num = 8 - elif instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]: - subnets_num = 4 - else: - return [] - extra_subnets = [] - for vpc_name in config.extra_vpcs[:subnets_num]: - subnet = gcp_resources.get_vpc_subnet_or_error( - subnetworks_client=subnetworks_client, - vpc_project_id=config.vpc_project_id or config.project_id, - vpc_name=vpc_name, - region=region, - ) - vpc_resource_name = gcp_resources.vpc_name_to_vpc_resource_name( - project_id=config.vpc_project_id or config.project_id, - vpc_name=vpc_name, - ) - extra_subnets.append((vpc_resource_name, subnet)) - return extra_subnets - - def _get_image_id(instance_type_name: str, cuda: bool) -> str: if instance_type_name == "a3-megagpu-8g": image_name = "dstack-a3mega-5"