diff --git a/docs/docs/concepts/plugins.md b/docs/docs/concepts/plugins.md
new file mode 100644
index 000000000..ae8c287b9
--- /dev/null
+++ b/docs/docs/concepts/plugins.md
@@ -0,0 +1,116 @@
+# Plugins
+
+The `dstack` plugin system allows extending `dstack` server functionality using external Python packages.
+
+!!! info "Experimental"
+ Plugins are currently an _experimental_ feature.
+ Backward compatibility is not guaranteed across releases.
+
+## Enable plugins
+
+To enable a plugin, list it under `plugins` in [`server/config.yml`](../reference/server/config.yml.md):
+
+
+
+```yaml
+plugins:
+ - my_dstack_plugin
+ - some_other_plugin
+projects:
+- name: main
+```
+
+
+
+On the next server restart, you should see a log message indicating that the plugin is loaded.
+
+## Create plugins
+
+To create a plugin, create a Python package that implements a subclass of
+`dstack.plugins.Plugin` and exports this subclass as a "dstack.plugins" entry point.
+
+1. Init the plugin package:
+
+
+
+ ```shell
+ $ uv init --library
+ ```
+
+
+
+2. Define `ApplyPolicy` and `Plugin` subclasses:
+
+
+
+ ```python
+ from dstack.plugins import ApplyPolicy, Plugin, RunSpec, get_plugin_logger
+
+ logger = get_plugin_logger(__name__)
+
+ class ExamplePolicy(ApplyPolicy):
+ def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec:
+ # ...
+ return spec
+
+ class ExamplePlugin(Plugin):
+
+ def get_apply_policies(self) -> list[ApplyPolicy]:
+ return [ExamplePolicy()]
+ ```
+
+
+
+3. Specify a "dstack.plugins" entry point in `pyproject.toml`:
+
+
+
+ ```toml
+ [project.entry-points."dstack.plugins"]
+ example_plugin = "example_plugin:ExamplePlugin"
+ ```
+
+
+
+Then you can install the plugin package into your Python environment and enable it via `server/config.yml`.
+
+??? info "Plugins in Docker"
+ If you deploy `dstack` using a Docker image, you can add plugins either
+ by including them in your custom image built upon the `dstack` server image,
+ or by mounting installed plugins as volumes.
+
+## Apply policies
+
+Currently, the only plugin functionality is apply policies.
+Apply policies allow modifying specs of runs, fleets, volumes, and gateways submitted on `dstack apply`.
+Subclass `dstack.plugins.ApplyPolicy` to implement them.
+
+Here's an example of how to enforce certain rules using apply policies:
+
+
+
+```python
+class ExamplePolicy(ApplyPolicy):
+ def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec:
+ # Forcing some limits
+ spec.configuration.max_price = 2.0
+ spec.configuration.max_duration = "1d"
+ # Setting some extra tags
+ if spec.configuration.tags is None:
+ spec.configuration.tags = {}
+ spec.configuration.tags |= {
+ "team": "my_team",
+ }
+ # Forbid something
+ if spec.configuration.privileged:
+ logger.warning("User %s tries to run privileged containers", user)
+ raise ValueError("Running privileged containers is forbidden")
+ # Set some service-specific properties
+ if isinstance(spec.configuration, Service):
+ spec.configuration.https = True
+ return spec
+```
+
+
+
+For more information on plugin development, see the [plugin example](https://github.com/dstackai/dstack/tree/master/examples/plugins/example_plugin).
diff --git a/examples/plugins/example_plugin/.python-version b/examples/plugins/example_plugin/.python-version
new file mode 100644
index 000000000..2c0733315
--- /dev/null
+++ b/examples/plugins/example_plugin/.python-version
@@ -0,0 +1 @@
+3.11
diff --git a/examples/plugins/example_plugin/README.md b/examples/plugins/example_plugin/README.md
new file mode 100644
index 000000000..112cb3a29
--- /dev/null
+++ b/examples/plugins/example_plugin/README.md
@@ -0,0 +1,52 @@
+## Overview
+
+This is a basic `dstack` plugin example.
+You can use it as a reference point when implementing new `dstack` plugins.
+
+## Steps
+
+1. Init the plugin package:
+
+ ```
+ uv init --library
+ ```
+
+2. Define `ApplyPolicy` and `Plugin` subclasses:
+
+ ```python
+ from dstack.plugins import ApplyPolicy, Plugin, RunSpec, get_plugin_logger
+
+
+ logger = get_plugin_logger(__name__)
+
+
+ class ExamplePolicy(ApplyPolicy):
+
+ def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec:
+ # ...
+ return spec
+
+
+ class ExamplePlugin(Plugin):
+
+ def get_apply_policies(self) -> list[ApplyPolicy]:
+ return [ExamplePolicy()]
+
+ ```
+
+3. Specify a "dstack.plugins" entry point in `pyproject.toml`:
+
+ ```toml
+ [project.entry-points."dstack.plugins"]
+ example_plugin = "example_plugin:ExamplePlugin"
+ ```
+
+4. Make sure to install the plugin and enable it in the `server/config.yml`:
+
+ ```yaml
+ plugins:
+ - example_plugin
+ projects:
+ - name: main
+ # ...
+ ```
diff --git a/examples/plugins/example_plugin/pyproject.toml b/examples/plugins/example_plugin/pyproject.toml
new file mode 100644
index 000000000..bc83d509a
--- /dev/null
+++ b/examples/plugins/example_plugin/pyproject.toml
@@ -0,0 +1,17 @@
+[project]
+name = "example-plugin"
+version = "0.1.0"
+description = "A dstack plugin example"
+readme = "README.md"
+authors = [
+ { name = "Victor Skvortsov", email = "victor@dstack.ai" }
+]
+requires-python = ">=3.9"
+dependencies = []
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project.entry-points."dstack.plugins"]
+example_plugin = "example_plugin:ExamplePlugin"
diff --git a/examples/plugins/example_plugin/src/example_plugin/__init__.py b/examples/plugins/example_plugin/src/example_plugin/__init__.py
new file mode 100644
index 000000000..f431e5c28
--- /dev/null
+++ b/examples/plugins/example_plugin/src/example_plugin/__init__.py
@@ -0,0 +1,34 @@
+from dstack.api import Service
+from dstack.plugins import ApplyPolicy, GatewaySpec, Plugin, RunSpec, get_plugin_logger
+
+logger = get_plugin_logger(__name__)
+
+
+class ExamplePolicy(ApplyPolicy):
+ def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec:
+ # Forcing some limits
+ spec.configuration.max_price = 2.0
+ spec.configuration.max_duration = "1d"
+ # Setting some extra tags
+ if spec.configuration.tags is None:
+ spec.configuration.tags = {}
+ spec.configuration.tags |= {
+ "team": "my_team",
+ }
+ # Forbid something
+ if spec.configuration.privileged:
+ logger.warning("User %s tries to run privileged containers", user)
+ raise ValueError("Running privileged containers is forbidden")
+ # Set some service-specific properties
+ if isinstance(spec.configuration, Service):
+ spec.configuration.https = True
+ return spec
+
+ def on_gateway_apply(self, user: str, project: str, spec: GatewaySpec) -> GatewaySpec:
+ # Forbid creating new gateways altogether
+ raise ValueError("Creating gateways is forbidden")
+
+
+class ExamplePlugin(Plugin):
+ def get_apply_policies(self) -> list[ApplyPolicy]:
+ return [ExamplePolicy()]
diff --git a/examples/plugins/example_plugin/src/example_plugin/py.typed b/examples/plugins/example_plugin/src/example_plugin/py.typed
new file mode 100644
index 000000000..e69de29bb
diff --git a/mkdocs.yml b/mkdocs.yml
index c71075435..bcb74d76b 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -206,6 +206,7 @@ nav:
- Backends: docs/concepts/backends.md
- Projects: docs/concepts/projects.md
- Gateways: docs/concepts/gateways.md
+ - Plugins: docs/concepts/plugins.md
- Guides:
- Protips: docs/guides/protips.md
- Server deployment: docs/guides/server-deployment.md
diff --git a/pyproject.toml b/pyproject.toml
index 8ec49b0ea..a501a880d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -125,6 +125,7 @@ server = [
"python-json-logger>=3.1.0",
"prometheus-client",
"grpcio>=1.50",
+ "backports.entry-points-selectable",
]
aws = [
"boto3",
diff --git a/src/dstack/_internal/core/models/fleets.py b/src/dstack/_internal/core/models/fleets.py
index 044ea05ed..0e5580309 100644
--- a/src/dstack/_internal/core/models/fleets.py
+++ b/src/dstack/_internal/core/models/fleets.py
@@ -269,6 +269,8 @@ class FleetSpec(CoreModel):
configuration_path: Optional[str] = None
profile: Profile
autocreated: bool = False
+ # merged_profile stores profile parameters merged from profile and configuration.
+ # Read profile parameters from merged_profile instead of profile directly.
# TODO: make merged_profile a computed field after migrating to pydanticV2
merged_profile: Annotated[Profile, Field(exclude=True)] = None
diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py
index 14a776cc9..fd675dc87 100644
--- a/src/dstack/_internal/core/models/runs.py
+++ b/src/dstack/_internal/core/models/runs.py
@@ -357,6 +357,8 @@ class RunSpec(CoreModel):
description="The contents of the SSH public key that will be used to connect to the run."
),
]
+ # merged_profile stores profile parameters merged from profile and configuration.
+ # Read profile parameters from merged_profile instead of profile directly.
# TODO: make merged_profile a computed field after migrating to pydanticV2
merged_profile: Annotated[Profile, Field(exclude=True)] = None
diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
index 6215a0353..efa8600a4 100644
--- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
+++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
@@ -197,7 +197,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
pool_instances = list(res.unique().scalars().all())
instances_ids = sorted([i.id for i in pool_instances])
if get_db().dialect_name == "sqlite":
- # Start new transaction to see commited changes after lock
+ # Start new transaction to see committed changes after lock
await session.commit()
async with get_locker().lock_ctx(InstanceModel.__tablename__, instances_ids):
# If another job freed the instance but is still trying to detach volumes,
diff --git a/src/dstack/_internal/server/routers/gateways.py b/src/dstack/_internal/server/routers/gateways.py
index ce0d94d60..604519af0 100644
--- a/src/dstack/_internal/server/routers/gateways.py
+++ b/src/dstack/_internal/server/routers/gateways.py
@@ -47,9 +47,10 @@ async def create_gateway(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
) -> models.Gateway:
- _, project = user_project
+ user, project = user_project
return await gateways.create_gateway(
session=session,
+ user=user,
project=project,
configuration=body.configuration,
)
diff --git a/src/dstack/_internal/server/services/config.py b/src/dstack/_internal/server/services/config.py
index f56303f05..1b7268b39 100644
--- a/src/dstack/_internal/server/services/config.py
+++ b/src/dstack/_internal/server/services/config.py
@@ -29,6 +29,7 @@
DefaultPermissions,
set_default_permissions,
)
+from dstack._internal.server.services.plugins import load_plugins
from dstack._internal.utils.logging import get_logger
logger = get_logger(__name__)
@@ -38,7 +39,7 @@
# If a collection has nested collections, it will be assigned the block style. Otherwise it will have the flow style.
#
# We want mapping to always be displayed in block-style but lists without nested objects in flow-style.
-# So we define a custom representeter
+# So we define a custom representer.
def seq_representer(dumper, sequence):
@@ -75,7 +76,10 @@ class ServerConfig(CoreModel):
] = None
default_permissions: Annotated[
Optional[DefaultPermissions], Field(description="The default user permissions")
- ]
+ ] = None
+ plugins: Annotated[
+ Optional[List[str]], Field(description="The server-side plugins to enable")
+ ] = None
class ServerConfigManager:
@@ -112,6 +116,7 @@ async def apply_config(self, session: AsyncSession, owner: UserModel):
await self._apply_project_config(
session=session, owner=owner, project_config=project_config
)
+ load_plugins(enabled_plugins=self.config.plugins or [])
async def _apply_project_config(
self,
diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py
index f2263de1b..d0e17cbb5 100644
--- a/src/dstack/_internal/server/services/fleets.py
+++ b/src/dstack/_internal/server/services/fleets.py
@@ -55,6 +55,7 @@
get_locker,
string_to_lock_id,
)
+from dstack._internal.server.services.plugins import apply_plugin_policies
from dstack._internal.server.services.projects import (
get_member,
get_member_permissions,
@@ -234,7 +235,14 @@ async def get_plan(
user: UserModel,
spec: FleetSpec,
) -> FleetPlan:
+ # Spec must be copied by parsing to calculate merged_profile
effective_spec = FleetSpec.parse_obj(spec.dict())
+ effective_spec = apply_plugin_policies(
+ user=user.name,
+ project=project.name,
+ spec=effective_spec,
+ )
+ effective_spec = FleetSpec.parse_obj(effective_spec.dict())
current_fleet: Optional[Fleet] = None
current_fleet_id: Optional[uuid.UUID] = None
if effective_spec.configuration.name is not None:
@@ -330,6 +338,13 @@ async def create_fleet(
user: UserModel,
spec: FleetSpec,
) -> Fleet:
+ # Spec must be copied by parsing to calculate merged_profile
+ spec = apply_plugin_policies(
+ user=user.name,
+ project=project.name,
+ spec=spec,
+ )
+ spec = FleetSpec.parse_obj(spec.dict())
_validate_fleet_spec(spec)
if spec.configuration.ssh_config is not None:
diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py
index 3a0cbfa87..d271d9fd7 100644
--- a/src/dstack/_internal/server/services/gateways/__init__.py
+++ b/src/dstack/_internal/server/services/gateways/__init__.py
@@ -31,13 +31,19 @@
Gateway,
GatewayComputeConfiguration,
GatewayConfiguration,
+ GatewaySpec,
GatewayStatus,
LetsEncryptGatewayCertificate,
)
from dstack._internal.core.services import validate_dstack_resource_name
from dstack._internal.server import settings
from dstack._internal.server.db import get_db
-from dstack._internal.server.models import GatewayComputeModel, GatewayModel, ProjectModel
+from dstack._internal.server.models import (
+ GatewayComputeModel,
+ GatewayModel,
+ ProjectModel,
+ UserModel,
+)
from dstack._internal.server.services.backends import (
check_backend_type_available,
get_project_backend_by_type_or_error,
@@ -50,6 +56,7 @@
get_locker,
string_to_lock_id,
)
+from dstack._internal.server.services.plugins import apply_plugin_policies
from dstack._internal.server.utils.common import gather_map_async
from dstack._internal.utils.common import get_current_datetime, run_async
from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes
@@ -129,9 +136,17 @@ async def create_gateway_compute(
async def create_gateway(
session: AsyncSession,
+ user: UserModel,
project: ProjectModel,
configuration: GatewayConfiguration,
) -> Gateway:
+ spec = apply_plugin_policies(
+ user=user.name,
+ project=project.name,
+ # Create pseudo spec until the gateway API is updated to accept spec
+ spec=GatewaySpec(configuration=configuration),
+ )
+ configuration = spec.configuration
_validate_gateway_configuration(configuration)
backend_model, _ = await get_project_backend_with_model_by_type_or_error(
@@ -140,7 +155,7 @@ async def create_gateway(
lock_namespace = f"gateway_names_{project.name}"
if get_db().dialect_name == "sqlite":
- # Start new transaction to see commited changes after lock
+ # Start new transaction to see committed changes after lock
await session.commit()
elif get_db().dialect_name == "postgresql":
await session.execute(
diff --git a/src/dstack/_internal/server/services/plugins.py b/src/dstack/_internal/server/services/plugins.py
new file mode 100644
index 000000000..a8e5be8a0
--- /dev/null
+++ b/src/dstack/_internal/server/services/plugins.py
@@ -0,0 +1,77 @@
+import itertools
+from importlib import import_module
+
+from backports.entry_points_selectable import entry_points # backport for Python 3.9
+
+from dstack._internal.core.errors import ServerClientError
+from dstack._internal.utils.logging import get_logger
+from dstack.plugins import ApplyPolicy, ApplySpec, Plugin
+
+logger = get_logger(__name__)
+
+
+_PLUGINS: list[Plugin] = []
+
+
+def load_plugins(enabled_plugins: list[str]):
+ _PLUGINS.clear()
+ plugins_entrypoints = entry_points(group="dstack.plugins")
+ plugins_to_load = enabled_plugins.copy()
+ for entrypoint in plugins_entrypoints:
+ if entrypoint.name not in enabled_plugins:
+ logger.info(
+ ("Found not enabled plugin %s. Plugin will not be loaded."),
+ entrypoint.name,
+ )
+ continue
+ try:
+ module_path, _, class_name = entrypoint.value.partition(":")
+ module = import_module(module_path)
+ except ImportError:
+ logger.warning(
+ (
+ "Failed to load plugin %s when importing %s."
+ " Ensure the module is on the import path."
+ ),
+ entrypoint.name,
+ entrypoint.value,
+ )
+ continue
+ plugin_class = getattr(module, class_name, None)
+ if plugin_class is None:
+ logger.warning(
+ ("Failed to load plugin %s: plugin class %s not found in module %s."),
+ entrypoint.name,
+ class_name,
+ module_path,
+ )
+ continue
+ if not issubclass(plugin_class, Plugin):
+ logger.warning(
+ ("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."),
+ entrypoint.name,
+ class_name,
+ )
+ continue
+ plugins_to_load.remove(entrypoint.name)
+ _PLUGINS.append(plugin_class())
+ logger.info("Loaded plugin %s", entrypoint.name)
+ if plugins_to_load:
+ logger.warning("Enabled plugins not found: %s", plugins_to_load)
+
+
+def apply_plugin_policies(user: str, project: str, spec: ApplySpec) -> ApplySpec:
+ policies = _get_apply_policies()
+ for policy in policies:
+ try:
+ spec = policy.on_apply(user=user, project=project, spec=spec)
+ except ValueError as e:
+ msg = None
+ if len(e.args) > 0:
+ msg = e.args[0]
+ raise ServerClientError(msg)
+ return spec
+
+
+def _get_apply_policies() -> list[ApplyPolicy]:
+ return list(itertools.chain(*[p.get_apply_policies() for p in _PLUGINS]))
diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py
index 9c591f9db..354151401 100644
--- a/src/dstack/_internal/server/services/runs.py
+++ b/src/dstack/_internal/server/services/runs.py
@@ -79,6 +79,7 @@
from dstack._internal.server.services.locking import get_locker, string_to_lock_id
from dstack._internal.server.services.logging import fmt
from dstack._internal.server.services.offers import get_offers_by_requirements
+from dstack._internal.server.services.plugins import apply_plugin_policies
from dstack._internal.server.services.projects import list_project_models, list_user_project_models
from dstack._internal.server.services.users import get_user_model_by_name
from dstack._internal.utils.logging import get_logger
@@ -279,7 +280,14 @@ async def get_plan(
run_spec: RunSpec,
max_offers: Optional[int],
) -> RunPlan:
+ # Spec must be copied by parsing to calculate merged_profile
effective_run_spec = RunSpec.parse_obj(run_spec.dict())
+ effective_run_spec = apply_plugin_policies(
+ user=user.name,
+ project=project.name,
+ spec=effective_run_spec,
+ )
+ effective_run_spec = RunSpec.parse_obj(effective_run_spec.dict())
_validate_run_spec_and_set_defaults(effective_run_spec)
profile = effective_run_spec.merged_profile
@@ -370,28 +378,36 @@ async def apply_plan(
plan: ApplyRunPlanInput,
force: bool,
) -> Run:
- _validate_run_spec_and_set_defaults(plan.run_spec)
- if plan.run_spec.run_name is None:
+ run_spec = plan.run_spec
+ run_spec = apply_plugin_policies(
+ user=user.name,
+ project=project.name,
+ spec=run_spec,
+ )
+ # Spec must be copied by parsing to calculate merged_profile
+ run_spec = RunSpec.parse_obj(run_spec.dict())
+ _validate_run_spec_and_set_defaults(run_spec)
+ if run_spec.run_name is None:
return await submit_run(
session=session,
user=user,
project=project,
- run_spec=plan.run_spec,
+ run_spec=run_spec,
)
current_resource = await get_run_by_name(
session=session,
project=project,
- run_name=plan.run_spec.run_name,
+ run_name=run_spec.run_name,
)
if current_resource is None or current_resource.status.is_finished():
return await submit_run(
session=session,
user=user,
project=project,
- run_spec=plan.run_spec,
+ run_spec=run_spec,
)
try:
- _check_can_update_run_spec(current_resource.run_spec, plan.run_spec)
+ _check_can_update_run_spec(current_resource.run_spec, run_spec)
except ServerClientError:
# The except is only needed to raise an appropriate error if run is active
if not current_resource.status.is_finished():
@@ -409,14 +425,12 @@ async def apply_plan(
# FIXME: potentially long write transaction
# Avoid getting run_model after update
await session.execute(
- update(RunModel)
- .where(RunModel.id == current_resource.id)
- .values(run_spec=plan.run_spec.json())
+ update(RunModel).where(RunModel.id == current_resource.id).values(run_spec=run_spec.json())
)
run = await get_run_by_name(
session=session,
project=project,
- run_name=plan.run_spec.run_name,
+ run_name=run_spec.run_name,
)
return common_utils.get_or_error(run)
@@ -436,7 +450,7 @@ async def submit_run(
lock_namespace = f"run_names_{project.name}"
if get_db().dialect_name == "sqlite":
- # Start new transaction to see commited changes after lock
+ # Start new transaction to see committed changes after lock
await session.commit()
elif get_db().dialect_name == "postgresql":
await session.execute(
diff --git a/src/dstack/_internal/server/services/volumes.py b/src/dstack/_internal/server/services/volumes.py
index 228343e1e..689520620 100644
--- a/src/dstack/_internal/server/services/volumes.py
+++ b/src/dstack/_internal/server/services/volumes.py
@@ -21,6 +21,7 @@
VolumeConfiguration,
VolumeInstance,
VolumeProvisioningData,
+ VolumeSpec,
VolumeStatus,
)
from dstack._internal.core.services import validate_dstack_resource_name
@@ -38,6 +39,7 @@
get_locker,
string_to_lock_id,
)
+from dstack._internal.server.services.plugins import apply_plugin_policies
from dstack._internal.server.services.projects import list_project_models, list_user_project_models
from dstack._internal.utils import common, random_names
from dstack._internal.utils.logging import get_logger
@@ -203,11 +205,18 @@ async def create_volume(
user: UserModel,
configuration: VolumeConfiguration,
) -> Volume:
+ spec = apply_plugin_policies(
+ user=user.name,
+ project=project.name,
+ # Create pseudo spec until the volume API is updated to accept spec
+ spec=VolumeSpec(configuration=configuration),
+ )
+ configuration = spec.configuration
_validate_volume_configuration(configuration)
lock_namespace = f"volume_names_{project.name}"
if get_db().dialect_name == "sqlite":
- # Start new transaction to see commited changes after lock
+ # Start new transaction to see committed changes after lock
await session.commit()
elif get_db().dialect_name == "postgresql":
await session.execute(
diff --git a/src/dstack/plugins/__init__.py b/src/dstack/plugins/__init__.py
new file mode 100644
index 000000000..93970043a
--- /dev/null
+++ b/src/dstack/plugins/__init__.py
@@ -0,0 +1,8 @@
+# ruff: noqa: F401
+from dstack._internal.core.models.fleets import FleetSpec
+from dstack._internal.core.models.gateways import GatewaySpec
+from dstack._internal.core.models.runs import RunSpec
+from dstack._internal.core.models.volumes import VolumeSpec
+from dstack.plugins._base import ApplyPolicy, Plugin
+from dstack.plugins._models import ApplySpec
+from dstack.plugins._utils import get_plugin_logger
diff --git a/src/dstack/plugins/_base.py b/src/dstack/plugins/_base.py
new file mode 100644
index 000000000..a30ae0c33
--- /dev/null
+++ b/src/dstack/plugins/_base.py
@@ -0,0 +1,72 @@
+from dstack._internal.core.models.fleets import FleetSpec
+from dstack._internal.core.models.gateways import GatewaySpec
+from dstack._internal.core.models.runs import RunSpec
+from dstack._internal.core.models.volumes import VolumeSpec
+from dstack.plugins._models import ApplySpec
+
+
+class ApplyPolicy:
+ """
+ A base apply policy class to modify specs on `dstack apply`.
+ Subclass it and return the subclass instance in `Plugin.get_apply_policies()`.
+ """
+
+ def on_apply(self, user: str, project: str, spec: ApplySpec) -> ApplySpec:
+ """
+ Modify `spec` before it's applied.
+ Raise `ValueError` for `spec` to be rejected as invalid.
+
+ This method can be called twice:
+ * first when a user gets a plan
+ * second when a user applies a plan
+
+ In both cases, the original spec is passed, so the method does not
+ need to check if it modified the spec before.
+
+ It's safe to modify and return `spec` without copying.
+ """
+ if isinstance(spec, RunSpec):
+ return self.on_run_apply(user=user, project=project, spec=spec)
+ if isinstance(spec, FleetSpec):
+ return self.on_fleet_apply(user=user, project=project, spec=spec)
+ if isinstance(spec, VolumeSpec):
+ return self.on_volume_apply(user=user, project=project, spec=spec)
+ if isinstance(spec, GatewaySpec):
+ return self.on_gateway_apply(user=user, project=project, spec=spec)
+ raise ValueError(f"Unknown spec type {type(spec)}")
+
+ def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec:
+ """
+ Called by the default `on_apply()` implementation for runs.
+ """
+ return spec
+
+ def on_fleet_apply(self, user: str, project: str, spec: FleetSpec) -> FleetSpec:
+ """
+ Called by the default `on_apply()` implementation for fleets.
+ """
+ return spec
+
+ def on_volume_apply(self, user: str, project: str, spec: VolumeSpec) -> VolumeSpec:
+ """
+ Called by the default `on_apply()` implementation for volumes.
+ """
+ return spec
+
+ def on_gateway_apply(self, user: str, project: str, spec: GatewaySpec) -> GatewaySpec:
+ """
+ Called by the default `on_apply()` implementation for gateways.
+ """
+ return spec
+
+
+class Plugin:
+ """
+ A base plugin class.
+ Plugins must subclass it, implement public methods,
+ and register the subclass as an entry point of the package
+ (https://packaging.python.org/en/latest/specifications/entry-points/).
+ """
+
+ def get_apply_policies(self) -> list[ApplyPolicy]:
+ return []
diff --git a/src/dstack/plugins/_models.py b/src/dstack/plugins/_models.py
new file mode 100644
index 000000000..124e0e593
--- /dev/null
+++ b/src/dstack/plugins/_models.py
@@ -0,0 +1,8 @@
+from typing import TypeVar
+
+from dstack._internal.core.models.fleets import FleetSpec
+from dstack._internal.core.models.gateways import GatewaySpec
+from dstack._internal.core.models.runs import RunSpec
+from dstack._internal.core.models.volumes import VolumeSpec
+
+ApplySpec = TypeVar("ApplySpec", RunSpec, FleetSpec, VolumeSpec, GatewaySpec)
diff --git a/src/dstack/plugins/_utils.py b/src/dstack/plugins/_utils.py
new file mode 100644
index 000000000..9de3ff260
--- /dev/null
+++ b/src/dstack/plugins/_utils.py
@@ -0,0 +1,19 @@
+import logging
+
+from dstack._internal.utils.logging import get_logger
+
+
+def get_plugin_logger(name: str) -> logging.Logger:
+ """
+ Use this function to set up loggers in plugins.
+
+ Put at the top of the plugin modules:
+
+ ```
+ from dstack.plugins import get_plugin_logger
+
+ logger = get_plugin_logger(__name__)
+ ```
+
+ """
+ return get_logger(f"dstack.plugins.{name}")
diff --git a/src/tests/_internal/server/services/test_plugins.py b/src/tests/_internal/server/services/test_plugins.py
new file mode 100644
index 000000000..460764c4d
--- /dev/null
+++ b/src/tests/_internal/server/services/test_plugins.py
@@ -0,0 +1,245 @@
+import logging
+from importlib.metadata import EntryPoint
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from dstack._internal.server.services.plugins import _PLUGINS, load_plugins
+from dstack.plugins import Plugin
+
+
+class DummyPlugin1(Plugin):
+ pass
+
+
+class DummyPlugin2(Plugin):
+ pass
+
+
+class NotAPlugin:
+ pass
+
+
+@pytest.fixture(autouse=True)
+def clear_plugins():
+ _PLUGINS.clear()
+ yield
+ _PLUGINS.clear()
+
+
+class TestLoadPlugins:
+ @patch("dstack._internal.server.services.plugins.entry_points")
+ @patch("dstack._internal.server.services.plugins.import_module")
+ def test_load_single_plugin(self, mock_import_module, mock_entry_points, caplog):
+ mock_entry_points.return_value = [
+ EntryPoint(
+ name="plugin1",
+ value="dummy.plugins:DummyPlugin1",
+ group="dstack.plugins",
+ )
+ ]
+ mock_module = MagicMock()
+ mock_module.DummyPlugin1 = DummyPlugin1
+ mock_import_module.return_value = mock_module
+
+ with caplog.at_level(logging.INFO):
+ load_plugins(["plugin1"])
+
+ assert len(_PLUGINS) == 1
+ assert isinstance(_PLUGINS[0], DummyPlugin1)
+ mock_entry_points.assert_called_once_with(group="dstack.plugins")
+ mock_import_module.assert_called_once_with("dummy.plugins")
+ assert "Loaded plugin plugin1" in caplog.text
+
+ @patch("dstack._internal.server.services.plugins.entry_points")
+ @patch("dstack._internal.server.services.plugins.import_module")
+ def test_load_multiple_plugins(self, mock_import_module, mock_entry_points, caplog):
+ mock_entry_points.return_value = [
+ EntryPoint(
+ name="plugin1",
+ value="dummy.plugins:DummyPlugin1",
+ group="dstack.plugins",
+ ),
+ EntryPoint(
+ name="plugin2",
+ value="dummy.plugins:DummyPlugin2",
+ group="dstack.plugins",
+ ),
+ ]
+ mock_module = MagicMock()
+ mock_module.DummyPlugin1 = DummyPlugin1
+ mock_module.DummyPlugin2 = DummyPlugin2
+ mock_import_module.return_value = mock_module
+
+ with caplog.at_level(logging.INFO):
+ load_plugins(["plugin1", "plugin2"])
+
+ assert len(_PLUGINS) == 2
+ assert isinstance(_PLUGINS[0], DummyPlugin1)
+ assert isinstance(_PLUGINS[1], DummyPlugin2)
+ assert "Loaded plugin plugin1" in caplog.text
+ assert "Loaded plugin plugin2" in caplog.text
+
+ @patch("dstack._internal.server.services.plugins.entry_points")
+ @patch("dstack._internal.server.services.plugins.import_module")
+ def test_plugin_not_enabled(self, mock_import_module, mock_entry_points, caplog):
+ mock_entry_points.return_value = [
+ EntryPoint(
+ name="plugin1",
+ value="dummy.plugins:DummyPlugin1",
+ group="dstack.plugins",
+ )
+ ]
+
+ with caplog.at_level(logging.INFO):
+ load_plugins([]) # Enable no plugins
+
+ assert len(_PLUGINS) == 0
+ mock_import_module.assert_not_called()
+ assert "Found not enabled plugin plugin1" in caplog.text
+
+ @patch("dstack._internal.server.services.plugins.entry_points")
+ @patch("dstack._internal.server.services.plugins.import_module")
+ def test_enabled_plugin_not_found(self, mock_import_module, mock_entry_points, caplog):
+ mock_entry_points.return_value = [
+ EntryPoint(
+ name="plugin1",
+ value="dummy.plugins:DummyPlugin1",
+ group="dstack.plugins",
+ )
+ ]
+
+ with caplog.at_level(logging.INFO):
+ load_plugins(["plugin2"]) # Enable a plugin that doesn't have an entry point
+
+ assert len(_PLUGINS) == 0
+ mock_import_module.assert_not_called()
+ assert "Found not enabled plugin plugin1" in caplog.text
+ assert "Enabled plugins not found: ['plugin2']" in caplog.text
+
+ @patch("dstack._internal.server.services.plugins.entry_points")
+ @patch(
+ "dstack._internal.server.services.plugins.import_module",
+ side_effect=ImportError("Module not found"),
+ )
+ def test_import_error(self, mock_import_module, mock_entry_points, caplog):
+ mock_entry_points.return_value = [
+ EntryPoint(
+ name="plugin1",
+ value="dummy.plugins:DummyPlugin1",
+ group="dstack.plugins",
+ )
+ ]
+
+ with caplog.at_level(logging.INFO):
+ load_plugins(["plugin1"])
+
+ assert len(_PLUGINS) == 0
+ assert (
+ "Failed to load plugin plugin1 when importing dummy.plugins:DummyPlugin1"
+ in caplog.text
+ )
+ assert "Enabled plugins not found: ['plugin1']" in caplog.text # Because loading failed
+
+ @patch("dstack._internal.server.services.plugins.entry_points")
+ @patch("dstack._internal.server.services.plugins.import_module")
+ def test_class_not_found(self, mock_import_module, mock_entry_points, caplog):
+ mock_entry_points.return_value = [
+ EntryPoint(
+ name="plugin1",
+ value="dummy.plugins:NonExistentClass",
+ group="dstack.plugins",
+ )
+ ]
+ mock_module = MagicMock()
+ # Simulate the class not being present
+ del mock_module.NonExistentClass
+ mock_import_module.return_value = mock_module
+
+ with caplog.at_level(logging.INFO):
+ load_plugins(["plugin1"])
+
+ assert len(_PLUGINS) == 0
+ assert (
+ "Failed to load plugin plugin1: plugin class NonExistentClass not found" in caplog.text
+ )
+ assert "Enabled plugins not found: ['plugin1']" in caplog.text
+
+ @patch("dstack._internal.server.services.plugins.entry_points")
+ @patch("dstack._internal.server.services.plugins.import_module")
+ def test_not_a_plugin_subclass(self, mock_import_module, mock_entry_points, caplog):
+ mock_entry_points.return_value = [
+ EntryPoint(
+ name="plugin1",
+ value="dummy.plugins:NotAPlugin",
+ group="dstack.plugins",
+ )
+ ]
+ mock_module = MagicMock()
+ mock_module.NotAPlugin = NotAPlugin
+ mock_import_module.return_value = mock_module
+
+ with caplog.at_level(logging.INFO):
+ load_plugins(["plugin1"])
+
+ assert len(_PLUGINS) == 0
+ assert (
+ "Failed to load plugin plugin1: plugin class NotAPlugin is not a subclass of Plugin"
+ in caplog.text
+ )
+ assert "Enabled plugins not found: ['plugin1']" in caplog.text
+
+ @patch("dstack._internal.server.services.plugins.entry_points")
+ @patch("dstack._internal.server.services.plugins.import_module")
+ def test_clears_existing_plugins(self, mock_import_module, mock_entry_points):
+ # Pre-populate _PLUGINS
+ _PLUGINS.append(DummyPlugin1())
+
+ mock_entry_points.return_value = [
+ EntryPoint(
+ name="plugin2",
+ value="dummy.plugins:DummyPlugin2",
+ group="dstack.plugins",
+ )
+ ]
+ mock_module = MagicMock()
+ mock_module.DummyPlugin2 = DummyPlugin2
+ mock_import_module.return_value = mock_module
+
+ load_plugins(["plugin2"])
+
+ assert len(_PLUGINS) == 1 # Should only contain plugin2
+ assert isinstance(_PLUGINS[0], DummyPlugin2)
+
+ @patch("dstack._internal.server.services.plugins.entry_points")
+ @patch("dstack._internal.server.services.plugins.import_module")
+ def test_load_no_plugins_found(self, mock_import_module, mock_entry_points, caplog):
+ mock_entry_points.return_value = [] # No entry points found
+
+ with caplog.at_level(logging.INFO):
+ load_plugins(["plugin1"]) # Try to enable one
+
+ assert len(_PLUGINS) == 0
+ mock_import_module.assert_not_called()
+ assert "Enabled plugins not found: ['plugin1']" in caplog.text
+
+ @patch("dstack._internal.server.services.plugins.entry_points")
+ @patch("dstack._internal.server.services.plugins.import_module")
+ def test_load_no_plugins_enabled(self, mock_import_module, mock_entry_points, caplog):
+ mock_entry_points.return_value = [
+ EntryPoint(
+ name="plugin1",
+ value="dummy.plugins:DummyPlugin1",
+ group="dstack.plugins",
+ )
+ ]
+
+ with caplog.at_level(logging.INFO):
+ load_plugins([]) # Enable none
+
+ assert len(_PLUGINS) == 0
+ mock_import_module.assert_not_called()
+ assert "Found not enabled plugin plugin1" in caplog.text
+ assert (
+ "Enabled plugins not found" not in caplog.text
+ ) # Should not warn if none were enabled