diff --git a/docs/docs/concepts/plugins.md b/docs/docs/concepts/plugins.md new file mode 100644 index 000000000..ae8c287b9 --- /dev/null +++ b/docs/docs/concepts/plugins.md @@ -0,0 +1,116 @@ +# Plugins + +The `dstack` plugin system allows extending `dstack` server functionality using external Python packages. + +!!! info "Experimental" + Plugins are currently an _experimental_ feature. + Backward compatibility is not guaranteed across releases. + +## Enable plugins + +To enable a plugin, list it under `plugins` in [`server/config.yml`](../reference/server/config.yml.md): + +
+ +```yaml +plugins: + - my_dstack_plugin + - some_other_plugin +projects: +- name: main +``` + +
+ +On the next server restart, you should see a log message indicating that the plugin is loaded. + +## Create plugins + +To create a plugin, create a Python package that implements a subclass of +`dstack.plugins.Plugin` and exports this subclass as a "dstack.plugins" entry point. + +1. Init the plugin package: + +
+ + ```shell + $ uv init --library + ``` + +
+ +2. Define `ApplyPolicy` and `Plugin` subclasses: + +
+ + ```python + from dstack.plugins import ApplyPolicy, Plugin, RunSpec, get_plugin_logger + + logger = get_plugin_logger(__name__) + + class ExamplePolicy(ApplyPolicy): + def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec: + # ... + return spec + + class ExamplePlugin(Plugin): + + def get_apply_policies(self) -> list[ApplyPolicy]: + return [ExamplePolicy()] + ``` + +
+ +3. Specify a "dstack.plugins" entry point in `pyproject.toml`: + +
+ + ```toml + [project.entry-points."dstack.plugins"] + example_plugin = "example_plugin:ExamplePlugin" + ``` + +
+ +Then you can install the plugin package into your Python environment and enable it via `server/config.yml`. + +??? info "Plugins in Docker" + If you deploy `dstack` using a Docker image you can add plugins either + by including them in your custom image built upon the `dstack` server image, + or by mounting installed plugins as volumes. + +## Apply policies + +Currently the only plugin functionality is apply policies. +Apply policies allow modifying specs of runs, fleets, volumes, and gateways submitted on `dstack apply`. +Subclass `dstack.plugins.ApplyPolicy` to implement them. + +Here's an example of how to enforce certain rules using apply policies: + +
+ +```python +class ExamplePolicy(ApplyPolicy): + def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec: + # Forcing some limits + spec.configuration.max_price = 2.0 + spec.configuration.max_duration = "1d" + # Setting some extra tags + if spec.configuration.tags is None: + spec.configuration.tags = {} + spec.configuration.tags |= { + "team": "my_team", + } + # Forbid something + if spec.configuration.privileged: + logger.warning("User %s tries to run privileged containers", user) + raise ValueError("Running privileged containers is forbidden") + # Set some service-specific properties + if isinstance(spec.configuration, Service): + spec.configuration.https = True + return spec +``` + +
+ +For more information on the plugin development, see the [plugin example](https://github.com/dstackai/dstack/tree/master/examples/plugins/example_plugin). diff --git a/examples/plugins/example_plugin/.python-version b/examples/plugins/example_plugin/.python-version new file mode 100644 index 000000000..2c0733315 --- /dev/null +++ b/examples/plugins/example_plugin/.python-version @@ -0,0 +1 @@ +3.11 diff --git a/examples/plugins/example_plugin/README.md b/examples/plugins/example_plugin/README.md new file mode 100644 index 000000000..112cb3a29 --- /dev/null +++ b/examples/plugins/example_plugin/README.md @@ -0,0 +1,52 @@ +## Overview + +This is a basic `dstack` plugin example. +You can use it as a reference point when implementing new `dstack` plugins. + +## Steps + +1. Init the plugin package: + + ``` + uv init --library + ``` + +2. Define `ApplyPolicy` and `Plugin` subclasses: + + ```python + from dstack.plugins import ApplyPolicy, Plugin, RunSpec, get_plugin_logger + + + logger = get_plugin_logger(__name__) + + + class ExamplePolicy(ApplyPolicy): + + def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec: + # ... + return spec + + + class ExamplePlugin(Plugin): + + def get_apply_policies(self) -> list[ApplyPolicy]: + return [ExamplePolicy()] + + ``` + +3. Specify a "dstack.plugins" entry point in `pyproject.toml`: + + ```toml + [project.entry-points."dstack.plugins"] + example_plugin = "example_plugin:ExamplePlugin" + ``` + +4. Make sure to install the plugin and enable it in the `server/config.yml`: + + ```yaml + plugins: + - example_plugin + projects: + - name: main + # ... + ``` diff --git a/examples/plugins/example_plugin/pyproject.toml b/examples/plugins/example_plugin/pyproject.toml new file mode 100644 index 000000000..bc83d509a --- /dev/null +++ b/examples/plugins/example_plugin/pyproject.toml @@ -0,0 +1,17 @@ +[project] +name = "example-plugin" +version = "0.1.0" +description = "A dstack plugin example" +readme = "README.md" +authors = [ + { name = "Victor Skvortsov", email = "victor@dstack.ai" } +] +requires-python = ">=3.9" +dependencies = [] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project.entry-points."dstack.plugins"] +example_plugin = "example_plugin:ExamplePlugin" diff --git a/examples/plugins/example_plugin/src/example_plugin/__init__.py b/examples/plugins/example_plugin/src/example_plugin/__init__.py new file mode 100644 index 000000000..f431e5c28 --- /dev/null +++ b/examples/plugins/example_plugin/src/example_plugin/__init__.py @@ -0,0 +1,34 @@ +from dstack.api import Service +from dstack.plugins import ApplyPolicy, GatewaySpec, Plugin, RunSpec, get_plugin_logger + +logger = get_plugin_logger(__name__) + + +class ExamplePolicy(ApplyPolicy): + def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec: + # Forcing some limits + spec.configuration.max_price = 2.0 + spec.configuration.max_duration = "1d" + # Setting some extra tags + if spec.configuration.tags is None: + spec.configuration.tags = {} + spec.configuration.tags |= { + "team": "my_team", + } + # Forbid something + if spec.configuration.privileged: + logger.warning("User %s tries to run privileged containers", user) + raise ValueError("Running privileged containers is forbidden") + # Set some service-specific properties + if isinstance(spec.configuration, Service): + spec.configuration.https = True + return spec + + def on_gateway_apply(self, user: str, project: str, spec: GatewaySpec) -> GatewaySpec: + # Forbid creating new gateways altogether + raise ValueError("Creating gateways is forbidden") + + +class ExamplePlugin(Plugin): + def get_apply_policies(self) -> list[ApplyPolicy]: + return [ExamplePolicy()] diff --git a/examples/plugins/example_plugin/src/example_plugin/py.typed b/examples/plugins/example_plugin/src/example_plugin/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/mkdocs.yml b/mkdocs.yml index c71075435..bcb74d76b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -206,6 +206,7 @@ nav: - Backends: docs/concepts/backends.md - Projects: docs/concepts/projects.md - Gateways: docs/concepts/gateways.md + - Plugins: docs/concepts/plugins.md - Guides: - Protips: docs/guides/protips.md - Server deployment: docs/guides/server-deployment.md diff --git a/pyproject.toml b/pyproject.toml index 8ec49b0ea..a501a880d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,6 +125,7 @@ server = [ "python-json-logger>=3.1.0", "prometheus-client", "grpcio>=1.50", + "backports.entry-points-selectable", ] aws = [ "boto3", diff --git a/src/dstack/_internal/core/models/fleets.py b/src/dstack/_internal/core/models/fleets.py index 044ea05ed..0e5580309 100644 --- a/src/dstack/_internal/core/models/fleets.py +++ b/src/dstack/_internal/core/models/fleets.py @@ -269,6 +269,8 @@ class FleetSpec(CoreModel): configuration_path: Optional[str] = None profile: Profile autocreated: bool = False + # merged_profile stores profile parameters merged from profile and configuration. + # Read profile parameters from merged_profile instead of profile directly. # TODO: make merged_profile a computed field after migrating to pydanticV2 merged_profile: Annotated[Profile, Field(exclude=True)] = None diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 14a776cc9..fd675dc87 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -357,6 +357,8 @@ class RunSpec(CoreModel): description="The contents of the SSH public key that will be used to connect to the run." ), ] + # merged_profile stores profile parameters merged from profile and configuration. + # Read profile parameters from merged_profile instead of profile directly. # TODO: make merged_profile a computed field after migrating to pydanticV2 merged_profile: Annotated[Profile, Field(exclude=True)] = None diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 6215a0353..efa8600a4 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -197,7 +197,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): pool_instances = list(res.unique().scalars().all()) instances_ids = sorted([i.id for i in pool_instances]) if get_db().dialect_name == "sqlite": - # Start new transaction to see commited changes after lock + # Start new transaction to see committed changes after lock await session.commit() async with get_locker().lock_ctx(InstanceModel.__tablename__, instances_ids): # If another job freed the instance but is still trying to detach volumes, diff --git a/src/dstack/_internal/server/routers/gateways.py b/src/dstack/_internal/server/routers/gateways.py index ce0d94d60..604519af0 100644 --- a/src/dstack/_internal/server/routers/gateways.py +++ b/src/dstack/_internal/server/routers/gateways.py @@ -47,9 +47,10 @@ async def create_gateway( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), ) -> models.Gateway: - _, project = user_project + user, project = user_project return await gateways.create_gateway( session=session, + user=user, project=project, configuration=body.configuration, ) diff --git a/src/dstack/_internal/server/services/config.py b/src/dstack/_internal/server/services/config.py index f56303f05..1b7268b39 100644 --- a/src/dstack/_internal/server/services/config.py +++ b/src/dstack/_internal/server/services/config.py @@ -29,6 +29,7 @@ DefaultPermissions, set_default_permissions, ) +from dstack._internal.server.services.plugins import load_plugins from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -38,7 +39,7 @@ # If a collection has nested collections, it will be assigned the block style. Otherwise it will have the flow style. # # We want mapping to always be displayed in block-style but lists without nested objects in flow-style. -# So we define a custom representeter +# So we define a custom representer. def seq_representer(dumper, sequence): @@ -75,7 +76,10 @@ class ServerConfig(CoreModel): ] = None default_permissions: Annotated[ Optional[DefaultPermissions], Field(description="The default user permissions") - ] + ] = None + plugins: Annotated[ + Optional[List[str]], Field(description="The server-side plugins to enable") + ] = None class ServerConfigManager: @@ -112,6 +116,7 @@ async def apply_config(self, session: AsyncSession, owner: UserModel): await self._apply_project_config( session=session, owner=owner, project_config=project_config ) + load_plugins(enabled_plugins=self.config.plugins or []) async def _apply_project_config( self, diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index f2263de1b..d0e17cbb5 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -55,6 +55,7 @@ get_locker, string_to_lock_id, ) +from dstack._internal.server.services.plugins import apply_plugin_policies from dstack._internal.server.services.projects import ( get_member, get_member_permissions, @@ -234,7 +235,14 @@ async def get_plan( user: UserModel, spec: FleetSpec, ) -> FleetPlan: + # Spec must be copied by parsing to calculate merged_profile effective_spec = FleetSpec.parse_obj(spec.dict()) + effective_spec = apply_plugin_policies( + user=user.name, + project=project.name, + spec=effective_spec, + ) + effective_spec = FleetSpec.parse_obj(effective_spec.dict()) current_fleet: Optional[Fleet] = None current_fleet_id: Optional[uuid.UUID] = None if effective_spec.configuration.name is not None: @@ -330,6 +338,13 @@ async def create_fleet( user: UserModel, spec: FleetSpec, ) -> Fleet: + # Spec must be copied by parsing to calculate merged_profile + spec = apply_plugin_policies( + user=user.name, + project=project.name, + spec=spec, + ) + spec = FleetSpec.parse_obj(spec.dict()) _validate_fleet_spec(spec) if spec.configuration.ssh_config is not None: diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 3a0cbfa87..d271d9fd7 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -31,13 +31,19 @@ Gateway, GatewayComputeConfiguration, GatewayConfiguration, + GatewaySpec, GatewayStatus, LetsEncryptGatewayCertificate, ) from dstack._internal.core.services import validate_dstack_resource_name from dstack._internal.server import settings from dstack._internal.server.db import get_db -from dstack._internal.server.models import GatewayComputeModel, GatewayModel, ProjectModel +from dstack._internal.server.models import ( + GatewayComputeModel, + GatewayModel, + ProjectModel, + UserModel, +) from dstack._internal.server.services.backends import ( check_backend_type_available, get_project_backend_by_type_or_error, @@ -50,6 +56,7 @@ get_locker, string_to_lock_id, ) +from dstack._internal.server.services.plugins import apply_plugin_policies from dstack._internal.server.utils.common import gather_map_async from dstack._internal.utils.common import get_current_datetime, run_async from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes @@ -129,9 +136,17 @@ async def create_gateway_compute( async def create_gateway( session: AsyncSession, + user: UserModel, project: ProjectModel, configuration: GatewayConfiguration, ) -> Gateway: + spec = apply_plugin_policies( + user=user.name, + project=project.name, + # Create pseudo spec until the gateway API is updated to accept spec + spec=GatewaySpec(configuration=configuration), + ) + configuration = spec.configuration _validate_gateway_configuration(configuration) backend_model, _ = await get_project_backend_with_model_by_type_or_error( @@ -140,7 +155,7 @@ async def create_gateway( lock_namespace = f"gateway_names_{project.name}" if get_db().dialect_name == "sqlite": - # Start new transaction to see commited changes after lock + # Start new transaction to see committed changes after lock await session.commit() elif get_db().dialect_name == "postgresql": await session.execute( diff --git a/src/dstack/_internal/server/services/plugins.py b/src/dstack/_internal/server/services/plugins.py new file mode 100644 index 000000000..a8e5be8a0 --- /dev/null +++ b/src/dstack/_internal/server/services/plugins.py @@ -0,0 +1,77 @@ +import itertools +from importlib import import_module + +from backports.entry_points_selectable import entry_points # backport for Python 3.9 + +from dstack._internal.core.errors import ServerClientError +from dstack._internal.utils.logging import get_logger +from dstack.plugins import ApplyPolicy, ApplySpec, Plugin + +logger = get_logger(__name__) + + +_PLUGINS: list[Plugin] = [] + + +def load_plugins(enabled_plugins: list[str]): + _PLUGINS.clear() + plugins_entrypoints = entry_points(group="dstack.plugins") + plugins_to_load = enabled_plugins.copy() + for entrypoint in plugins_entrypoints: + if entrypoint.name not in enabled_plugins: + logger.info( + ("Found not enabled plugin %s. Plugin will not be loaded."), + entrypoint.name, + ) + continue + try: + module_path, _, class_name = entrypoint.value.partition(":") + module = import_module(module_path) + except ImportError: + logger.warning( + ( + "Failed to load plugin %s when importing %s." + " Ensure the module is on the import path." + ), + entrypoint.name, + entrypoint.value, + ) + continue + plugin_class = getattr(module, class_name, None) + if plugin_class is None: + logger.warning( + ("Failed to load plugin %s: plugin class %s not found in module %s."), + entrypoint.name, + class_name, + module_path, + ) + continue + if not issubclass(plugin_class, Plugin): + logger.warning( + ("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."), + entrypoint.name, + class_name, + ) + continue + plugins_to_load.remove(entrypoint.name) + _PLUGINS.append(plugin_class()) + logger.info("Loaded plugin %s", entrypoint.name) + if plugins_to_load: + logger.warning("Enabled plugins not found: %s", plugins_to_load) + + +def apply_plugin_policies(user: str, project: str, spec: ApplySpec) -> ApplySpec: + policies = _get_apply_policies() + for policy in policies: + try: + spec = policy.on_apply(user=user, project=project, spec=spec) + except ValueError as e: + msg = None + if len(e.args) > 0: + msg = e.args[0] + raise ServerClientError(msg) + return spec + + +def _get_apply_policies() -> list[ApplyPolicy]: + return list(itertools.chain(*[p.get_apply_policies() for p in _PLUGINS])) diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 9c591f9db..354151401 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -79,6 +79,7 @@ from dstack._internal.server.services.locking import get_locker, string_to_lock_id from dstack._internal.server.services.logging import fmt from dstack._internal.server.services.offers import get_offers_by_requirements +from dstack._internal.server.services.plugins import apply_plugin_policies from dstack._internal.server.services.projects import list_project_models, list_user_project_models from dstack._internal.server.services.users import get_user_model_by_name from dstack._internal.utils.logging import get_logger @@ -279,7 +280,14 @@ async def get_plan( run_spec: RunSpec, max_offers: Optional[int], ) -> RunPlan: + # Spec must be copied by parsing to calculate merged_profile effective_run_spec = RunSpec.parse_obj(run_spec.dict()) + effective_run_spec = apply_plugin_policies( + user=user.name, + project=project.name, + spec=effective_run_spec, + ) + effective_run_spec = RunSpec.parse_obj(effective_run_spec.dict()) _validate_run_spec_and_set_defaults(effective_run_spec) profile = effective_run_spec.merged_profile @@ -370,28 +378,36 @@ async def apply_plan( plan: ApplyRunPlanInput, force: bool, ) -> Run: - _validate_run_spec_and_set_defaults(plan.run_spec) - if plan.run_spec.run_name is None: + run_spec = plan.run_spec + run_spec = apply_plugin_policies( + user=user.name, + project=project.name, + spec=run_spec, + ) + # Spec must be copied by parsing to calculate merged_profile + run_spec = RunSpec.parse_obj(run_spec.dict()) + _validate_run_spec_and_set_defaults(run_spec) + if run_spec.run_name is None: return await submit_run( session=session, user=user, project=project, - run_spec=plan.run_spec, + run_spec=run_spec, ) current_resource = await get_run_by_name( session=session, project=project, - run_name=plan.run_spec.run_name, + run_name=run_spec.run_name, ) if current_resource is None or current_resource.status.is_finished(): return await submit_run( session=session, user=user, project=project, - run_spec=plan.run_spec, + run_spec=run_spec, ) try: - _check_can_update_run_spec(current_resource.run_spec, plan.run_spec) + _check_can_update_run_spec(current_resource.run_spec, run_spec) except ServerClientError: # The except is only needed to raise an appropriate error if run is active if not current_resource.status.is_finished(): @@ -409,14 +425,12 @@ async def apply_plan( # FIXME: potentially long write transaction # Avoid getting run_model after update await session.execute( - update(RunModel) - .where(RunModel.id == current_resource.id) - .values(run_spec=plan.run_spec.json()) + update(RunModel).where(RunModel.id == current_resource.id).values(run_spec=run_spec.json()) ) run = await get_run_by_name( session=session, project=project, - run_name=plan.run_spec.run_name, + run_name=run_spec.run_name, ) return common_utils.get_or_error(run) @@ -436,7 +450,7 @@ async def submit_run( lock_namespace = f"run_names_{project.name}" if get_db().dialect_name == "sqlite": - # Start new transaction to see commited changes after lock + # Start new transaction to see committed changes after lock await session.commit() elif get_db().dialect_name == "postgresql": await session.execute( diff --git a/src/dstack/_internal/server/services/volumes.py b/src/dstack/_internal/server/services/volumes.py index 228343e1e..689520620 100644 --- a/src/dstack/_internal/server/services/volumes.py +++ b/src/dstack/_internal/server/services/volumes.py @@ -21,6 +21,7 @@ VolumeConfiguration, VolumeInstance, VolumeProvisioningData, + VolumeSpec, VolumeStatus, ) from dstack._internal.core.services import validate_dstack_resource_name @@ -38,6 +39,7 @@ get_locker, string_to_lock_id, ) +from dstack._internal.server.services.plugins import apply_plugin_policies from dstack._internal.server.services.projects import list_project_models, list_user_project_models from dstack._internal.utils import common, random_names from dstack._internal.utils.logging import get_logger @@ -203,11 +205,18 @@ async def create_volume( user: UserModel, configuration: VolumeConfiguration, ) -> Volume: + spec = apply_plugin_policies( + user=user.name, + project=project.name, + # Create pseudo spec until the volume API is updated to accept spec + spec=VolumeSpec(configuration=configuration), + ) + configuration = spec.configuration _validate_volume_configuration(configuration) lock_namespace = f"volume_names_{project.name}" if get_db().dialect_name == "sqlite": - # Start new transaction to see commited changes after lock + # Start new transaction to see committed changes after lock await session.commit() elif get_db().dialect_name == "postgresql": await session.execute( diff --git a/src/dstack/plugins/__init__.py b/src/dstack/plugins/__init__.py new file mode 100644 index 000000000..93970043a --- /dev/null +++ b/src/dstack/plugins/__init__.py @@ -0,0 +1,8 @@ +# ruff: noqa: F401 +from dstack._internal.core.models.fleets import FleetSpec +from dstack._internal.core.models.gateways import GatewaySpec +from dstack._internal.core.models.runs import RunSpec +from dstack._internal.core.models.volumes import VolumeSpec +from dstack.plugins._base import ApplyPolicy, Plugin +from dstack.plugins._models import ApplySpec +from dstack.plugins._utils import get_plugin_logger diff --git a/src/dstack/plugins/_base.py b/src/dstack/plugins/_base.py new file mode 100644 index 000000000..a30ae0c33 --- /dev/null +++ b/src/dstack/plugins/_base.py @@ -0,0 +1,72 @@ +from dstack._internal.core.models.fleets import FleetSpec +from dstack._internal.core.models.gateways import GatewaySpec +from dstack._internal.core.models.runs import RunSpec +from dstack._internal.core.models.volumes import VolumeSpec +from dstack.plugins._models import ApplySpec + + +class ApplyPolicy: + """ + A base apply policy class to modify specs on `dstack apply`. + Subclass it and return the subclass instance in `Plugin.get_apply_policies()`. + """ + + def on_apply(self, user: str, project: str, spec: ApplySpec) -> ApplySpec: + """ + Modify `spec` before it's applied. + Raise `ValueError` for `spec` to be rejected as invalid. + + This method can be called twice: + * first when a user gets a plan + * second when a user applies a plan + + In both cases, the original spec is passed, so the method does not + need to check if it modified the spec before. + + It's safe to modify and return `spec` without copying. + """ + if isinstance(spec, RunSpec): + return self.on_run_apply(user=user, project=project, spec=spec) + if isinstance(spec, FleetSpec): + return self.on_fleet_apply(user=user, project=project, spec=spec) + if isinstance(spec, VolumeSpec): + return self.on_volume_apply(user=user, project=project, spec=spec) + if isinstance(spec, GatewaySpec): + return self.on_gateway_apply(user=user, project=project, spec=spec) + raise ValueError(f"Unknown spec type {type(spec)}") + + def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec: + """ + Called by the default `on_apply()` implementation for runs. + """ + return spec + + def on_fleet_apply(self, user: str, project: str, spec: FleetSpec) -> FleetSpec: + """ + Called by the default `on_apply()` implementation for fleets. + """ + return spec + + def on_volume_apply(self, user: str, project: str, spec: VolumeSpec) -> VolumeSpec: + """ + Called by the default `on_apply()` implementation for volumes. + """ + return spec + + def on_gateway_apply(self, user: str, project: str, spec: GatewaySpec) -> GatewaySpec: + """ + Called by the default `on_apply()` implementation for gateways. + """ + return spec + + +class Plugin: + """ + A base plugin class. + Plugins must subclass it, implement public methods, + and register the subclass as an entrypoint of the package + (https://packaging.python.org/en/latest/specifications/entry-points/). + """ + + def get_apply_policies(self) -> list[ApplyPolicy]: + return [] diff --git a/src/dstack/plugins/_models.py b/src/dstack/plugins/_models.py new file mode 100644 index 000000000..124e0e593 --- /dev/null +++ b/src/dstack/plugins/_models.py @@ -0,0 +1,8 @@ +from typing import TypeVar + +from dstack._internal.core.models.fleets import FleetSpec +from dstack._internal.core.models.gateways import GatewaySpec +from dstack._internal.core.models.runs import RunSpec +from dstack._internal.core.models.volumes import VolumeSpec + +ApplySpec = TypeVar("ApplySpec", RunSpec, FleetSpec, VolumeSpec, GatewaySpec) diff --git a/src/dstack/plugins/_utils.py b/src/dstack/plugins/_utils.py new file mode 100644 index 000000000..9de3ff260 --- /dev/null +++ b/src/dstack/plugins/_utils.py @@ -0,0 +1,19 @@ +import logging + +from dstack._internal.utils.logging import get_logger + + +def get_plugin_logger(name: str) -> logging.Logger: + """ + Use this function to set up loggers in plugins. + + Put at the top of the plugin modules: + + ``` + from dstack.plugins import get_plugin_logger + + logger = get_plugin_logger(__name__) + ``` + + """ + return get_logger(f"dstack.plugins.{name}") diff --git a/src/tests/_internal/server/services/test_plugins.py b/src/tests/_internal/server/services/test_plugins.py new file mode 100644 index 000000000..460764c4d --- /dev/null +++ b/src/tests/_internal/server/services/test_plugins.py @@ -0,0 +1,245 @@ +import logging +from importlib.metadata import EntryPoint +from unittest.mock import MagicMock, patch + +import pytest + +from dstack._internal.server.services.plugins import _PLUGINS, load_plugins +from dstack.plugins import Plugin + + +class DummyPlugin1(Plugin): + pass + + +class DummyPlugin2(Plugin): + pass + + +class NotAPlugin: + pass + + +@pytest.fixture(autouse=True) +def clear_plugins(): + _PLUGINS.clear() + yield + _PLUGINS.clear() + + +class TestLoadPlugins: + @patch("dstack._internal.server.services.plugins.entry_points") + @patch("dstack._internal.server.services.plugins.import_module") + def test_load_single_plugin(self, mock_import_module, mock_entry_points, caplog): + mock_entry_points.return_value = [ + EntryPoint( + name="plugin1", + value="dummy.plugins:DummyPlugin1", + group="dstack.plugins", + ) + ] + mock_module = MagicMock() + mock_module.DummyPlugin1 = DummyPlugin1 + mock_import_module.return_value = mock_module + + with caplog.at_level(logging.INFO): + load_plugins(["plugin1"]) + + assert len(_PLUGINS) == 1 + assert isinstance(_PLUGINS[0], DummyPlugin1) + mock_entry_points.assert_called_once_with(group="dstack.plugins") + mock_import_module.assert_called_once_with("dummy.plugins") + assert "Loaded plugin plugin1" in caplog.text + + @patch("dstack._internal.server.services.plugins.entry_points") + @patch("dstack._internal.server.services.plugins.import_module") + def test_load_multiple_plugins(self, mock_import_module, mock_entry_points, caplog): + mock_entry_points.return_value = [ + EntryPoint( + name="plugin1", + value="dummy.plugins:DummyPlugin1", + group="dstack.plugins", + ), + EntryPoint( + name="plugin2", + value="dummy.plugins:DummyPlugin2", + group="dstack.plugins", + ), + ] + mock_module = MagicMock() + mock_module.DummyPlugin1 = DummyPlugin1 + mock_module.DummyPlugin2 = DummyPlugin2 + mock_import_module.return_value = mock_module + + with caplog.at_level(logging.INFO): + load_plugins(["plugin1", "plugin2"]) + + assert len(_PLUGINS) == 2 + assert isinstance(_PLUGINS[0], DummyPlugin1) + assert isinstance(_PLUGINS[1], DummyPlugin2) + assert "Loaded plugin plugin1" in caplog.text + assert "Loaded plugin plugin2" in caplog.text + + @patch("dstack._internal.server.services.plugins.entry_points") + @patch("dstack._internal.server.services.plugins.import_module") + def test_plugin_not_enabled(self, mock_import_module, mock_entry_points, caplog): + mock_entry_points.return_value = [ + EntryPoint( + name="plugin1", + value="dummy.plugins:DummyPlugin1", + group="dstack.plugins", + ) + ] + + with caplog.at_level(logging.INFO): + load_plugins([]) # Enable no plugins + + assert len(_PLUGINS) == 0 + mock_import_module.assert_not_called() + assert "Found not enabled plugin plugin1" in caplog.text + + @patch("dstack._internal.server.services.plugins.entry_points") + @patch("dstack._internal.server.services.plugins.import_module") + def test_enabled_plugin_not_found(self, mock_import_module, mock_entry_points, caplog): + mock_entry_points.return_value = [ + EntryPoint( + name="plugin1", + value="dummy.plugins:DummyPlugin1", + group="dstack.plugins", + ) + ] + + with caplog.at_level(logging.INFO): + load_plugins(["plugin2"]) # Enable a plugin that doesn't have an entry point + + assert len(_PLUGINS) == 0 + mock_import_module.assert_not_called() + assert "Found not enabled plugin plugin1" in caplog.text + assert "Enabled plugins not found: ['plugin2']" in caplog.text + + @patch("dstack._internal.server.services.plugins.entry_points") + @patch( + "dstack._internal.server.services.plugins.import_module", + side_effect=ImportError("Module not found"), + ) + def test_import_error(self, mock_import_module, mock_entry_points, caplog): + mock_entry_points.return_value = [ + EntryPoint( + name="plugin1", + value="dummy.plugins:DummyPlugin1", + group="dstack.plugins", + ) + ] + + with caplog.at_level(logging.INFO): + load_plugins(["plugin1"]) + + assert len(_PLUGINS) == 0 + assert ( + "Failed to load plugin plugin1 when importing dummy.plugins:DummyPlugin1" + in caplog.text + ) + assert "Enabled plugins not found: ['plugin1']" in caplog.text # Because loading failed + + @patch("dstack._internal.server.services.plugins.entry_points") + @patch("dstack._internal.server.services.plugins.import_module") + def test_class_not_found(self, mock_import_module, mock_entry_points, caplog): + mock_entry_points.return_value = [ + EntryPoint( + name="plugin1", + value="dummy.plugins:NonExistentClass", + group="dstack.plugins", + ) + ] + mock_module = MagicMock() + # Simulate the class not being present + del mock_module.NonExistentClass + mock_import_module.return_value = mock_module + + with caplog.at_level(logging.INFO): + load_plugins(["plugin1"]) + + assert len(_PLUGINS) == 0 + assert ( + "Failed to load plugin plugin1: plugin class NonExistentClass not found" in caplog.text + ) + assert "Enabled plugins not found: ['plugin1']" in caplog.text + + @patch("dstack._internal.server.services.plugins.entry_points") + @patch("dstack._internal.server.services.plugins.import_module") + def test_not_a_plugin_subclass(self, mock_import_module, mock_entry_points, caplog): + mock_entry_points.return_value = [ + EntryPoint( + name="plugin1", + value="dummy.plugins:NotAPlugin", + group="dstack.plugins", + ) + ] + mock_module = MagicMock() + mock_module.NotAPlugin = NotAPlugin + mock_import_module.return_value = mock_module + + with caplog.at_level(logging.INFO): + load_plugins(["plugin1"]) + + assert len(_PLUGINS) == 0 + assert ( + "Failed to load plugin plugin1: plugin class NotAPlugin is not a subclass of Plugin" + in caplog.text + ) + assert "Enabled plugins not found: ['plugin1']" in caplog.text + + @patch("dstack._internal.server.services.plugins.entry_points") + @patch("dstack._internal.server.services.plugins.import_module") + def test_clears_existing_plugins(self, mock_import_module, mock_entry_points): + # Pre-populate _PLUGINS + _PLUGINS.append(DummyPlugin1()) + + mock_entry_points.return_value = [ + EntryPoint( + name="plugin2", + value="dummy.plugins:DummyPlugin2", + group="dstack.plugins", + ) + ] + mock_module = MagicMock() + mock_module.DummyPlugin2 = DummyPlugin2 + mock_import_module.return_value = mock_module + + load_plugins(["plugin2"]) + + assert len(_PLUGINS) == 1 # Should only contain plugin2 + assert isinstance(_PLUGINS[0], DummyPlugin2) + + @patch("dstack._internal.server.services.plugins.entry_points") + @patch("dstack._internal.server.services.plugins.import_module") + def test_load_no_plugins_found(self, mock_import_module, mock_entry_points, caplog): + mock_entry_points.return_value = [] # No entry points found + + with caplog.at_level(logging.INFO): + load_plugins(["plugin1"]) # Try to enable one + + assert len(_PLUGINS) == 0 + mock_import_module.assert_not_called() + assert "Enabled plugins not found: ['plugin1']" in caplog.text + + @patch("dstack._internal.server.services.plugins.entry_points") + @patch("dstack._internal.server.services.plugins.import_module") + def test_load_no_plugins_enabled(self, mock_import_module, mock_entry_points, caplog): + mock_entry_points.return_value = [ + EntryPoint( + name="plugin1", + value="dummy.plugins:DummyPlugin1", + group="dstack.plugins", + ) + ] + + with caplog.at_level(logging.INFO): + load_plugins([]) # Enable none + + assert len(_PLUGINS) == 0 + mock_import_module.assert_not_called() + assert "Found not enabled plugin plugin1" in caplog.text + assert ( + "Enabled plugins not found" not in caplog.text + ) # Should not warn if none were enabled