Skip to content

[chore]: Drop is_core_model_instance #2536

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/dstack/_internal/cli/services/configurators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from dstack._internal.cli.services.args import env_var
from dstack._internal.core.errors import ConfigurationError
from dstack._internal.core.models.common import is_core_model_instance
from dstack._internal.core.models.configurations import (
AnyApplyConfiguration,
ApplyConfigurationType,
Expand Down Expand Up @@ -100,7 +99,7 @@ def apply_env_vars(self, env: Env, configurator_args: argparse.Namespace) -> Non
for k, v in cast(List[EnvVarTuple], configurator_args.env_vars):
env[k] = v
for k, v in env.items():
if is_core_model_instance(v, EnvSentinel):
if isinstance(v, EnvSentinel):
try:
env[k] = v.from_env(os.environ)
except ValueError as e:
Expand Down
3 changes: 1 addition & 2 deletions src/dstack/_internal/cli/utils/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from rich.table import Table

from dstack._internal.cli.utils.common import NO_OFFERS_WARNING, add_row_from_dict, console
from dstack._internal.core.models.common import is_core_model_instance
from dstack._internal.core.models.configurations import DevEnvironmentConfiguration
from dstack._internal.core.models.instances import InstanceAvailability
from dstack._internal.core.models.profiles import (
Expand Down Expand Up @@ -41,7 +40,7 @@ def print_run_plan(run_plan: RunPlan, offers_limit: int = 3):
else "-"
)
inactivity_duration = None
if is_core_model_instance(run_plan.run_spec.configuration, DevEnvironmentConfiguration):
if isinstance(run_plan.run_spec.configuration, DevEnvironmentConfiguration):
inactivity_duration = "-"
if isinstance(run_plan.run_spec.configuration.inactivity_duration, int):
inactivity_duration = format_pretty_duration(
Expand Down
3 changes: 1 addition & 2 deletions src/dstack/_internal/core/backends/aws/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from dstack._internal.core.backends.aws.models import AnyAWSCreds, AWSAccessKeyCreds
from dstack._internal.core.errors import BackendAuthError
from dstack._internal.core.models.common import is_core_model_instance


def authenticate(creds: AnyAWSCreds, region: str) -> Session:
Expand All @@ -14,7 +13,7 @@ def authenticate(creds: AnyAWSCreds, region: str) -> Session:


def get_session(creds: AnyAWSCreds, region: str) -> Session:
if is_core_model_instance(creds, AWSAccessKeyCreds):
if isinstance(creds, AWSAccessKeyCreds):
return boto3.session.Session(
region_name=region,
aws_access_key_id=creds.access_key,
Expand Down
4 changes: 2 additions & 2 deletions src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from dstack._internal.core.backends.base.offers import get_catalog_offers
from dstack._internal.core.errors import ComputeError, NoCapacityError, PlacementGroupInUseError
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel, is_core_model_instance
from dstack._internal.core.models.common import CoreModel
from dstack._internal.core.models.gateways import (
GatewayComputeConfiguration,
GatewayProvisioningData,
Expand Down Expand Up @@ -79,7 +79,7 @@ class AWSCompute(
def __init__(self, config: AWSConfig):
super().__init__()
self.config = config
if is_core_model_instance(config.creds, AWSAccessKeyCreds):
if isinstance(config.creds, AWSAccessKeyCreds):
self.session = boto3.Session(
aws_access_key_id=config.creds.access_key,
aws_secret_access_key=config.creds.secret_key,
Expand Down
5 changes: 2 additions & 3 deletions src/dstack/_internal/core/backends/aws/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from dstack._internal.core.models.backends.base import (
BackendType,
)
from dstack._internal.core.models.common import is_core_model_instance
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -58,12 +57,12 @@ class AWSConfigurator(Configurator):
BACKEND_CLASS = AWSBackend

def validate_config(self, config: AWSBackendConfigWithCreds, default_creds_enabled: bool):
if is_core_model_instance(config.creds, AWSDefaultCreds) and not default_creds_enabled:
if isinstance(config.creds, AWSDefaultCreds) and not default_creds_enabled:
raise_invalid_credentials_error(fields=[["creds"]])
try:
session = auth.authenticate(creds=config.creds, region=MAIN_REGION)
except Exception:
if is_core_model_instance(config.creds, AWSAccessKeyCreds):
if isinstance(config.creds, AWSAccessKeyCreds):
raise_invalid_credentials_error(
fields=[
["creds", "access_key"],
Expand Down
3 changes: 1 addition & 2 deletions src/dstack/_internal/core/backends/azure/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
AzureClientCreds,
)
from dstack._internal.core.errors import BackendAuthError
from dstack._internal.core.models.common import is_core_model_instance

AzureCredential = Union[ClientSecretCredential, DefaultAzureCredential]

Expand All @@ -21,7 +20,7 @@ def authenticate(creds: AnyAzureCreds) -> Tuple[AzureCredential, str]:


def get_credential(creds: AnyAzureCreds) -> Tuple[AzureCredential, str]:
if is_core_model_instance(creds, AzureClientCreds):
if isinstance(creds, AzureClientCreds):
credential = ClientSecretCredential(
tenant_id=creds.tenant_id,
client_id=creds.client_id,
Expand Down
9 changes: 4 additions & 5 deletions src/dstack/_internal/core/backends/azure/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
from dstack._internal.core.models.backends.base import (
BackendType,
)
from dstack._internal.core.models.common import is_core_model_instance

LOCATIONS = [
("(US) Central US", "centralus"),
Expand Down Expand Up @@ -76,14 +75,14 @@ class AzureConfigurator(Configurator):
BACKEND_CLASS = AzureBackend

def validate_config(self, config: AzureBackendConfigWithCreds, default_creds_enabled: bool):
if is_core_model_instance(config.creds, AzureDefaultCreds) and not default_creds_enabled:
if isinstance(config.creds, AzureDefaultCreds) and not default_creds_enabled:
raise_invalid_credentials_error(fields=[["creds"]])
if is_core_model_instance(config.creds, AzureClientCreds):
if isinstance(config.creds, AzureClientCreds):
self._set_client_creds_tenant_id(config.creds, config.tenant_id)
try:
credential, _ = auth.authenticate(config.creds)
except BackendAuthError:
if is_core_model_instance(config.creds, AzureClientCreds):
if isinstance(config.creds, AzureClientCreds):
raise_invalid_credentials_error(
fields=[
["creds", "tenant_id"],
Expand All @@ -105,7 +104,7 @@ def create_backend(
) -> BackendRecord:
if config.regions is None:
config.regions = DEFAULT_LOCATIONS
if is_core_model_instance(config.creds, AzureClientCreds):
if isinstance(config.creds, AzureClientCreds):
self._set_client_creds_tenant_id(config.creds, config.tenant_id)
credential, _ = auth.authenticate(config.creds)
if config.resource_group is None:
Expand Down
3 changes: 1 addition & 2 deletions src/dstack/_internal/core/backends/gcp/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
GCPServiceAccountCreds,
)
from dstack._internal.core.errors import BackendAuthError
from dstack._internal.core.models.common import is_core_model_instance


def authenticate(creds: AnyGCPCreds, project_id: Optional[str] = None) -> Tuple[Credentials, str]:
Expand All @@ -30,7 +29,7 @@ def authenticate(creds: AnyGCPCreds, project_id: Optional[str] = None) -> Tuple[


def get_credentials(creds: AnyGCPCreds) -> Tuple[Credentials, Optional[str]]:
if is_core_model_instance(creds, GCPServiceAccountCreds):
if isinstance(creds, GCPServiceAccountCreds):
try:
service_account_info = json.loads(creds.data)
credentials = service_account.Credentials.from_service_account_info(
Expand Down
5 changes: 2 additions & 3 deletions src/dstack/_internal/core/backends/gcp/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from dstack._internal.core.models.backends.base import (
BackendType,
)
from dstack._internal.core.models.common import is_core_model_instance

LOCATIONS = [
{
Expand Down Expand Up @@ -115,15 +114,15 @@ class GCPConfigurator(Configurator):
BACKEND_CLASS = GCPBackend

def validate_config(self, config: GCPBackendConfigWithCreds, default_creds_enabled: bool):
if is_core_model_instance(config.creds, GCPDefaultCreds) and not default_creds_enabled:
if isinstance(config.creds, GCPDefaultCreds) and not default_creds_enabled:
raise_invalid_credentials_error(fields=[["creds"]])
try:
credentials, _ = auth.authenticate(creds=config.creds, project_id=config.project_id)
except BackendAuthError as e:
details = None
if len(e.args) > 0:
details = e.args[0]
if is_core_model_instance(config.creds, GCPServiceAccountCreds):
if isinstance(config.creds, GCPServiceAccountCreds):
raise_invalid_credentials_error(fields=[["creds", "data"]], details=details)
else:
raise_invalid_credentials_error(fields=[["creds"]], details=details)
Expand Down
3 changes: 1 addition & 2 deletions src/dstack/_internal/core/backends/oci/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@

from dstack._internal.core.backends.oci.exceptions import any_oci_exception
from dstack._internal.core.backends.oci.models import AnyOCICreds, OCIDefaultCreds
from dstack._internal.core.models.common import is_core_model_instance


def get_client_config(creds: AnyOCICreds) -> Mapping[str, Any]:
if is_core_model_instance(creds, OCIDefaultCreds):
if isinstance(creds, OCIDefaultCreds):
return oci.config.from_file(file_location=creds.file, profile_name=creds.profile)
return creds.dict(exclude={"type"})

Expand Down
3 changes: 1 addition & 2 deletions src/dstack/_internal/core/backends/oci/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from dstack._internal.core.models.backends.base import (
BackendType,
)
from dstack._internal.core.models.common import is_core_model_instance

# where dstack images are published
SUPPORTED_REGIONS = frozenset(
Expand All @@ -48,7 +47,7 @@ class OCIConfigurator(Configurator):
BACKEND_CLASS = OCIBackend

def validate_config(self, config: OCIBackendConfigWithCreds, default_creds_enabled: bool):
if is_core_model_instance(config.creds, OCIDefaultCreds) and not default_creds_enabled:
if isinstance(config.creds, OCIDefaultCreds) and not default_creds_enabled:
raise_invalid_credentials_error(
fields=[["creds"]],
details="Default credentials are forbidden by dstack settings",
Expand Down
16 changes: 2 additions & 14 deletions src/dstack/_internal/core/models/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import re
from enum import Enum
from typing import Any, Type, TypeVar, Union
from typing import Union

from pydantic import Field
from pydantic_duality import DualBaseModel
from typing_extensions import Annotated, TypeGuard
from typing_extensions import Annotated


# DualBaseModel creates two classes for the model:
Expand Down Expand Up @@ -74,15 +74,3 @@ class ApplyAction(str, Enum):
class NetworkMode(str, Enum):
HOST = "host"
BRIDGE = "bridge"


_CM = TypeVar("_CM", bound=CoreModel)


def is_core_model_instance(instance: Any, class_: Type[_CM]) -> TypeGuard[_CM]:
"""
Implements isinstance check for CoreModel such that
models parsed with MyModel.__response__ pass the check against MyModel.
See https://github.com/dstackai/dstack/issues/1124
"""
return isinstance(instance, class_) or isinstance(instance, class_.__response__)
4 changes: 2 additions & 2 deletions src/dstack/_internal/core/models/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic import BaseModel, Field, validator
from typing_extensions import Annotated, Self

from dstack._internal.core.models.common import CoreModel, is_core_model_instance
from dstack._internal.core.models.common import CoreModel

# VAR_NAME=VALUE, VAR_NAME=, or VAR_NAME
_ENV_STRING_REGEX = r"^([a-zA-Z_][a-zA-Z0-9_]*)(=.*$|$)"
Expand Down Expand Up @@ -118,7 +118,7 @@ def as_dict(self) -> Dict[str, str]:
unresolved: List[str] = []
dct: Dict[str, str] = {}
for k, v in self.items():
if is_core_model_instance(v, EnvSentinel):
if isinstance(v, EnvSentinel):
unresolved.append(k)
else:
# cast is required since TypeGuard is for positive cases only
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT
from dstack._internal.core.errors import GatewayError
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import NetworkMode, RegistryAuth, is_core_model_instance
from dstack._internal.core.models.common import NetworkMode, RegistryAuth
from dstack._internal.core.models.configurations import DevEnvironmentConfiguration
from dstack._internal.core.models.instances import (
InstanceStatus,
Expand Down Expand Up @@ -422,9 +422,9 @@ def _process_provisioning_with_shim(
volume_mounts: List[VolumeMountPoint] = []
instance_mounts: List[InstanceMountPoint] = []
for mount in run.run_spec.configuration.volumes:
if is_core_model_instance(mount, VolumeMountPoint):
if isinstance(mount, VolumeMountPoint):
volume_mounts.append(mount.copy())
elif is_core_model_instance(mount, InstanceMountPoint):
elif isinstance(mount, InstanceMountPoint):
instance_mounts.append(mount)
else:
assert False, f"unexpected mount point: {mount!r}"
Expand Down Expand Up @@ -657,7 +657,7 @@ def _terminate_if_inactivity_duration_exceeded(
run_model: RunModel, job_model: JobModel, no_connections_secs: Optional[int]
) -> None:
conf = RunSpec.__response__.parse_raw(run_model.run_spec).configuration
if not is_core_model_instance(conf, DevEnvironmentConfiguration) or not isinstance(
if not isinstance(conf, DevEnvironmentConfiguration) or not isinstance(
conf.inactivity_duration, int
):
# reset in case inactivity_duration was disabled via in-place update
Expand Down
3 changes: 1 addition & 2 deletions src/dstack/_internal/server/services/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
ResourceExistsError,
ServerClientError,
)
from dstack._internal.core.models.common import is_core_model_instance
from dstack._internal.core.models.envs import Env
from dstack._internal.core.models.fleets import (
Fleet,
Expand Down Expand Up @@ -630,7 +629,7 @@ def _validate_fleet_spec(spec: FleetSpec):
if spec.configuration.ssh_config.ssh_key is not None:
_validate_ssh_key(spec.configuration.ssh_config.ssh_key)
for host in spec.configuration.ssh_config.hosts:
if is_core_model_instance(host, SSHHostParams) and host.ssh_key is not None:
if isinstance(host, SSHHostParams) and host.ssh_key is not None:
_validate_ssh_key(host.ssh_key)
_validate_internal_ips(spec.configuration.ssh_config)

Expand Down
3 changes: 1 addition & 2 deletions src/dstack/_internal/server/services/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
SSHError,
)
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import is_core_model_instance
from dstack._internal.core.models.configurations import RunConfigurationType
from dstack._internal.core.models.instances import InstanceStatus
from dstack._internal.core.models.runs import (
Expand Down Expand Up @@ -585,7 +584,7 @@ async def get_job_configured_volume_models(
job_volumes = interpolate_job_volumes(run_spec.configuration.volumes, job_num)
volume_models = []
for mount_point in job_volumes:
if not is_core_model_instance(mount_point, VolumeMountPoint):
if not isinstance(mount_point, VolumeMountPoint):
continue
if isinstance(mount_point.name, str):
names = [mount_point.name]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import dstack.version as version
from dstack._internal.core.errors import DockerRegistryError, ServerClientError
from dstack._internal.core.models.common import RegistryAuth, is_core_model_instance
from dstack._internal.core.models.common import RegistryAuth
from dstack._internal.core.models.configurations import (
PortMapping,
PythonVersion,
Expand Down Expand Up @@ -274,7 +274,7 @@ def interpolate_job_volumes(
if isinstance(mount_point, str):
# pydantic validator ensures strings are converted to MountPoint
continue
if not is_core_model_instance(mount_point, VolumeMountPoint):
if not isinstance(mount_point, VolumeMountPoint):
job_volumes.append(mount_point.copy())
continue
if isinstance(mount_point.name, str):
Expand Down
3 changes: 1 addition & 2 deletions src/dstack/_internal/server/services/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
)
from dstack._internal.core.backends.models import BackendInfo
from dstack._internal.core.errors import ForbiddenError, ResourceExistsError, ServerClientError
from dstack._internal.core.models.common import is_core_model_instance
from dstack._internal.core.models.projects import Member, MemberPermissions, Project
from dstack._internal.core.models.users import GlobalRole, ProjectRole
from dstack._internal.server.models import MemberModel, ProjectModel, UserModel
Expand Down Expand Up @@ -386,7 +385,7 @@ def project_model_to_project(
backend_config = get_backend_config_from_backend_model(
configurator, b, include_creds=False
)
if is_core_model_instance(backend_config, DstackBackendConfig):
if isinstance(backend_config, DstackBackendConfig):
for backend_type in backend_config.base_backends:
backends.append(
BackendInfo(
Expand Down
3 changes: 1 addition & 2 deletions src/dstack/_internal/server/services/proxy/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import dstack._internal.server.services.jobs as jobs_services
from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
from dstack._internal.core.models.common import is_core_model_instance
from dstack._internal.core.models.configurations import ServiceConfiguration
from dstack._internal.core.models.instances import RemoteConnectionInfo, SSHConnectionParams
from dstack._internal.core.models.runs import (
Expand Down Expand Up @@ -64,7 +63,7 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
return None
run = jobs[0].run
run_spec = RunSpec.__response__.parse_raw(run.run_spec)
if not is_core_model_instance(run_spec.configuration, ServiceConfiguration):
if not isinstance(run_spec.configuration, ServiceConfiguration):
return None
replicas = []
for job in jobs:
Expand Down
4 changes: 2 additions & 2 deletions src/dstack/_internal/server/services/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
ResourceNotExistsError,
ServerClientError,
)
from dstack._internal.core.models.common import ApplyAction, is_core_model_instance
from dstack._internal.core.models.common import ApplyAction
from dstack._internal.core.models.configurations import AnyRunConfiguration
from dstack._internal.core.models.instances import (
InstanceAvailability,
Expand Down Expand Up @@ -748,7 +748,7 @@ async def _generate_run_name(

def check_run_spec_requires_instance_mounts(run_spec: RunSpec) -> bool:
return any(
is_core_model_instance(mp, InstanceMountPoint) and not mp.optional
isinstance(mp, InstanceMountPoint) and not mp.optional
for mp in run_spec.configuration.volumes
)

Expand Down
Loading
Loading