Skip to content

Implement plugins #2581

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions docs/docs/concepts/plugins.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Plugins

The `dstack` plugin system allows extending `dstack` server functionality using external Python packages.

!!! info "Experimental"
Plugins are currently an _experimental_ feature.
Backward compatibility is not guaranteed across releases.

## Enable plugins

To enable a plugin, list it under `plugins` in [`server/config.yml`](../reference/server/config.yml.md):

<div editor-title="server/config.yml">

```yaml
plugins:
- my_dstack_plugin
- some_other_plugin
projects:
- name: main
```

</div>

On the next server restart, you should see a log message indicating that the plugin is loaded.

## Create plugins

To create a plugin, create a Python package that implements a subclass of
`dstack.plugins.Plugin` and exports this subclass as a "dstack.plugins" entry point.

1. Init the plugin package:

<div class="termy">

```shell
$ uv init --library
```

</div>

2. Define `ApplyPolicy` and `Plugin` subclasses:

<div editor-title="src/example_plugin/__init__.py">

```python
from dstack.plugins import ApplyPolicy, Plugin, RunSpec, get_plugin_logger

logger = get_plugin_logger(__name__)

class ExamplePolicy(ApplyPolicy):
def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec:
# ...
return spec

class ExamplePlugin(Plugin):

def get_apply_policies(self) -> list[ApplyPolicy]:
return [ExamplePolicy()]
```

</div>

3. Specify a "dstack.plugins" entry point in `pyproject.toml`:

<div editor-title="pyproject.toml">

```toml
[project.entry-points."dstack.plugins"]
example_plugin = "example_plugin:ExamplePlugin"
```

</div>

Then you can install the plugin package into your Python environment and enable it via `server/config.yml`.

??? info "Plugins in Docker"
If you deploy `dstack` using a Docker image you can add plugins either
by including them in your custom image built upon the `dstack` server image,
or by mounting installed plugins as volumes.

## Apply policies

Currently the only plugin functionality is apply policies.
Apply policies allow modifying specs of runs, fleets, volumes, and gateways submitted on `dstack apply`.
Subclass `dstack.plugins.ApplyPolicy` to implement them.

Here's an example of how to enforce certain rules using apply policies:

<div editor-title="src/example_plugin/__init__.py">

```python
class ExamplePolicy(ApplyPolicy):
def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec:
# Forcing some limits
spec.configuration.max_price = 2.0
spec.configuration.max_duration = "1d"
# Setting some extra tags
if spec.configuration.tags is None:
spec.configuration.tags = {}
spec.configuration.tags |= {
"team": "my_team",
}
# Forbid something
if spec.configuration.privileged:
logger.warning("User %s tries to run privileged containers", user)
raise ValueError("Running privileged containers is forbidden")
# Set some service-specific properties
if isinstance(spec.configuration, Service):
spec.configuration.https = True
return spec
```

</div>

For more information on the plugin development, see the [plugin example](https://github.com/dstackai/dstack/tree/master/examples/plugins/example_plugin).
1 change: 1 addition & 0 deletions examples/plugins/example_plugin/.python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.11
52 changes: 52 additions & 0 deletions examples/plugins/example_plugin/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
## Overview

This is a basic `dstack` plugin example.
You can use it as a reference point when implementing new `dstack` plugins.

## Steps

1. Init the plugin package:

```
uv init --library
```

2. Define `ApplyPolicy` and `Plugin` subclasses:

```python
from dstack.plugins import ApplyPolicy, Plugin, RunSpec, get_plugin_logger


logger = get_plugin_logger(__name__)


class ExamplePolicy(ApplyPolicy):

def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec:
# ...
return spec


class ExamplePlugin(Plugin):

def get_apply_policies(self) -> list[ApplyPolicy]:
return [ExamplePolicy()]

```

3. Specify a "dstack.plugins" entry point in `pyproject.toml`:

```toml
[project.entry-points."dstack.plugins"]
example_plugin = "example_plugin:ExamplePlugin"
```

4. Make sure to install the plugin and enable it in the `server/config.yml`:

```yaml
plugins:
- example_plugin
projects:
- name: main
# ...
```
17 changes: 17 additions & 0 deletions examples/plugins/example_plugin/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[project]
name = "example-plugin"
version = "0.1.0"
description = "A dstack plugin example"
readme = "README.md"
authors = [
{ name = "Victor Skvortsov", email = "[email protected]" }
]
requires-python = ">=3.9"
dependencies = []

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project.entry-points."dstack.plugins"]
example_plugin = "example_plugin:ExamplePlugin"
34 changes: 34 additions & 0 deletions examples/plugins/example_plugin/src/example_plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from dstack.api import Service
from dstack.plugins import ApplyPolicy, GatewaySpec, Plugin, RunSpec, get_plugin_logger

logger = get_plugin_logger(__name__)


class ExamplePolicy(ApplyPolicy):
def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec:
# Forcing some limits
spec.configuration.max_price = 2.0
spec.configuration.max_duration = "1d"
# Setting some extra tags
if spec.configuration.tags is None:
spec.configuration.tags = {}
spec.configuration.tags |= {
"team": "my_team",
}
# Forbid something
if spec.configuration.privileged:
logger.warning("User %s tries to run privileged containers", user)
raise ValueError("Running privileged containers is forbidden")
# Set some service-specific properties
if isinstance(spec.configuration, Service):
spec.configuration.https = True
return spec

def on_gateway_apply(self, user: str, project: str, spec: GatewaySpec) -> GatewaySpec:
# Forbid creating new gateways altogether
raise ValueError("Creating gateways is forbidden")


class ExamplePlugin(Plugin):
def get_apply_policies(self) -> list[ApplyPolicy]:
return [ExamplePolicy()]
Empty file.
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ nav:
- Backends: docs/concepts/backends.md
- Projects: docs/concepts/projects.md
- Gateways: docs/concepts/gateways.md
- Plugins: docs/concepts/plugins.md
- Guides:
- Protips: docs/guides/protips.md
- Server deployment: docs/guides/server-deployment.md
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ server = [
"python-json-logger>=3.1.0",
"prometheus-client",
"grpcio>=1.50",
"backports.entry-points-selectable",
]
aws = [
"boto3",
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/core/models/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,8 @@ class FleetSpec(CoreModel):
configuration_path: Optional[str] = None
profile: Profile
autocreated: bool = False
# merged_profile stores profile parameters merged from profile and configuration.
# Read profile parameters from merged_profile instead of profile directly.
# TODO: make merged_profile a computed field after migrating to pydanticV2
merged_profile: Annotated[Profile, Field(exclude=True)] = None

Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ class RunSpec(CoreModel):
description="The contents of the SSH public key that will be used to connect to the run."
),
]
# merged_profile stores profile parameters merged from profile and configuration.
# Read profile parameters from merged_profile instead of profile directly.
# TODO: make merged_profile a computed field after migrating to pydanticV2
merged_profile: Annotated[Profile, Field(exclude=True)] = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
pool_instances = list(res.unique().scalars().all())
instances_ids = sorted([i.id for i in pool_instances])
if get_db().dialect_name == "sqlite":
# Start new transaction to see commited changes after lock
# Start new transaction to see committed changes after lock
await session.commit()
async with get_locker().lock_ctx(InstanceModel.__tablename__, instances_ids):
# If another job freed the instance but is still trying to detach volumes,
Expand Down
3 changes: 2 additions & 1 deletion src/dstack/_internal/server/routers/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ async def create_gateway(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
) -> models.Gateway:
_, project = user_project
user, project = user_project
return await gateways.create_gateway(
session=session,
user=user,
project=project,
configuration=body.configuration,
)
Expand Down
9 changes: 7 additions & 2 deletions src/dstack/_internal/server/services/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
DefaultPermissions,
set_default_permissions,
)
from dstack._internal.server.services.plugins import load_plugins
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)
Expand All @@ -38,7 +39,7 @@
# If a collection has nested collections, it will be assigned the block style. Otherwise it will have the flow style.
#
# We want mapping to always be displayed in block-style but lists without nested objects in flow-style.
# So we define a custom representeter
# So we define a custom representer.


def seq_representer(dumper, sequence):
Expand Down Expand Up @@ -75,7 +76,10 @@ class ServerConfig(CoreModel):
] = None
default_permissions: Annotated[
Optional[DefaultPermissions], Field(description="The default user permissions")
]
] = None
plugins: Annotated[
Optional[List[str]], Field(description="The server-side plugins to enable")
] = None


class ServerConfigManager:
Expand Down Expand Up @@ -112,6 +116,7 @@ async def apply_config(self, session: AsyncSession, owner: UserModel):
await self._apply_project_config(
session=session, owner=owner, project_config=project_config
)
load_plugins(enabled_plugins=self.config.plugins or [])

async def _apply_project_config(
self,
Expand Down
15 changes: 15 additions & 0 deletions src/dstack/_internal/server/services/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
get_locker,
string_to_lock_id,
)
from dstack._internal.server.services.plugins import apply_plugin_policies
from dstack._internal.server.services.projects import (
get_member,
get_member_permissions,
Expand Down Expand Up @@ -234,7 +235,14 @@ async def get_plan(
user: UserModel,
spec: FleetSpec,
) -> FleetPlan:
# Spec must be copied by parsing to calculate merged_profile
effective_spec = FleetSpec.parse_obj(spec.dict())
effective_spec = apply_plugin_policies(
user=user.name,
project=project.name,
spec=effective_spec,
)
effective_spec = FleetSpec.parse_obj(effective_spec.dict())
current_fleet: Optional[Fleet] = None
current_fleet_id: Optional[uuid.UUID] = None
if effective_spec.configuration.name is not None:
Expand Down Expand Up @@ -330,6 +338,13 @@ async def create_fleet(
user: UserModel,
spec: FleetSpec,
) -> Fleet:
# Spec must be copied by parsing to calculate merged_profile
spec = apply_plugin_policies(
user=user.name,
project=project.name,
spec=spec,
)
spec = FleetSpec.parse_obj(spec.dict())
_validate_fleet_spec(spec)

if spec.configuration.ssh_config is not None:
Expand Down
Loading
Loading