diff --git a/docs/docs/concepts/exports.md b/docs/docs/concepts/exports.md
index 3b3194146..a773094b4 100644
--- a/docs/docs/concepts/exports.md
+++ b/docs/docs/concepts/exports.md
@@ -134,6 +134,19 @@ $ dstack fleet list
Imported fleets can be used for runs just like the project's own fleets.
+
+
+```yaml
+type: dev-environment
+ide: vscode
+
+fleets:
+- my-local-fleet
+- team-a/my-fleet
+```
+
+
+
!!! info "Tenant isolation"
Exported fleets share the same access model as regular fleets. See [Tenant isolation](fleets.md#tenant-isolation) for details.
diff --git a/src/dstack/_internal/core/compatibility/common.py b/src/dstack/_internal/core/compatibility/common.py
new file mode 100644
index 000000000..6a3445a7b
--- /dev/null
+++ b/src/dstack/_internal/core/compatibility/common.py
@@ -0,0 +1,14 @@
+from dstack._internal.core.models.common import EntityReference
+from dstack._internal.core.models.profiles import ProfileParams
+
+
+def patch_profile_params(params: ProfileParams) -> None:
+ # If there are no project-prefixed fleets, replace all EntityReference with str
+ # for compatibility with pre-0.20.14 servers that don't support EntityReference.
+ if params.fleets is not None and all(
+ EntityReference.parse(f).project is None for f in params.fleets
+ ):
+ params.fleets = [
+ fleet_ref.format() if isinstance(fleet_ref, EntityReference) else fleet_ref
+ for fleet_ref in params.fleets
+ ]
diff --git a/src/dstack/_internal/core/compatibility/fleets.py b/src/dstack/_internal/core/compatibility/fleets.py
index e22903df3..04a92ed77 100644
--- a/src/dstack/_internal/core/compatibility/fleets.py
+++ b/src/dstack/_internal/core/compatibility/fleets.py
@@ -1,6 +1,10 @@
from typing import Optional
-from dstack._internal.core.models.common import IncludeExcludeDictType, IncludeExcludeSetType
+from dstack._internal.core.compatibility.common import patch_profile_params
+from dstack._internal.core.models.common import (
+ IncludeExcludeDictType,
+ IncludeExcludeSetType,
+)
from dstack._internal.core.models.fleets import ApplyFleetPlanInput, FleetSpec
@@ -56,3 +60,7 @@ def get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[IncludeExcludeDic
if spec_excludes:
return spec_excludes
return None
+
+
+def patch_fleet_spec(spec: FleetSpec) -> None:
+ patch_profile_params(spec.profile)
diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py
index 2b0e3a6b4..e08b8260f 100644
--- a/src/dstack/_internal/core/compatibility/runs.py
+++ b/src/dstack/_internal/core/compatibility/runs.py
@@ -1,6 +1,10 @@
from typing import Optional
-from dstack._internal.core.models.common import IncludeExcludeDictType, IncludeExcludeSetType
+from dstack._internal.core.compatibility.common import patch_profile_params
+from dstack._internal.core.models.common import (
+ IncludeExcludeDictType,
+ IncludeExcludeSetType,
+)
from dstack._internal.core.models.configurations import ServiceConfiguration
from dstack._internal.core.models.routers import SGLangServiceRouterConfig
from dstack._internal.core.models.runs import (
@@ -138,3 +142,9 @@ def get_job_submission_excludes(job_submissions: list[JobSubmission]) -> Include
submission_excludes["job_runtime_data"] = jrd_excludes
return submission_excludes
+
+
+def patch_run_spec(run_spec: RunSpec) -> None:
+ patch_profile_params(run_spec.configuration)
+ if run_spec.profile is not None:
+ patch_profile_params(run_spec.profile)
diff --git a/src/dstack/_internal/core/models/common.py b/src/dstack/_internal/core/models/common.py
index f55a032ba..a3bf68cff 100644
--- a/src/dstack/_internal/core/models/common.py
+++ b/src/dstack/_internal/core/models/common.py
@@ -143,3 +143,31 @@ class ApplyAction(str, Enum):
class NetworkMode(str, Enum):
HOST = "host"
BRIDGE = "bridge"
+
+
+class EntityReference(CoreModel):
+ """
+ Cross-project entity reference.
+ """
+
+ project: Annotated[
+ Optional[str],
+ Field(description="The project name. If unspecified, refers to the current project"),
+ ]
+ name: Annotated[str, Field(description="The entity name")]
+
+ @classmethod
+ def parse(cls, v: Union[str, "EntityReference"]) -> "EntityReference":
+ if isinstance(v, EntityReference):
+ return v
+ parts = v.split("/")
+ if len(parts) == 1:
+ return cls(project=None, name=parts[0])
+ if len(parts) == 2:
+ return cls(project=parts[0], name=parts[1])
+ raise ValueError("Invalid entity reference. Only `/` format is allowed")
+
+ def format(self) -> str:
+ if self.project is None:
+ return self.name
+ return f"{self.project}/{self.name}"
diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py
index 97ce591aa..517514694 100644
--- a/src/dstack/_internal/core/models/profiles.py
+++ b/src/dstack/_internal/core/models/profiles.py
@@ -10,6 +10,7 @@
CoreConfig,
CoreModel,
Duration,
+ EntityReference,
generate_dual_core_model,
)
from dstack._internal.utils.common import list_enum_values_for_annotation
@@ -360,7 +361,21 @@ class ProfileParams(CoreModel):
Field(description=("The schedule for starting the run at specified time")),
] = None
fleets: Annotated[
- Optional[list[str]], Field(description="The fleets considered for reuse")
+ Optional[
+ list[
+ Union[
+ EntityReference,
+ str, # For server response compatibility with pre-0.20.14 clients
+ ]
+ ]
+ ],
+ Field(
+ description=(
+ "The fleets considered for reuse."
+ " For fleets owned by the current project, specify fleet names."
+ " For imported fleets, specify `/`"
+ ),
+ ),
] = None
tags: Annotated[
Optional[Dict[str, str]],
@@ -382,6 +397,7 @@ class ProfileParams(CoreModel):
_validate_idle_duration = validator("idle_duration", pre=True, allow_reuse=True)(
parse_idle_duration
)
+ _validate_fleets = validator("fleets", allow_reuse=True, each_item=True)(EntityReference.parse)
_validate_tags = validator("tags", pre=True, allow_reuse=True)(tags_validator)
diff --git a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py
index 52217eefb..eaa452380 100644
--- a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py
+++ b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py
@@ -975,6 +975,7 @@ async def _refetch_fleet_models_with_instances(
) -> list[FleetModel]:
res = await session.execute(
select(FleetModel)
+ .join(FleetModel.project) # can be referenced by fleet_filters
.outerjoin(FleetModel.instances)
.where(
FleetModel.id.in_(fleets_ids),
diff --git a/src/dstack/_internal/server/compatibility/common.py b/src/dstack/_internal/server/compatibility/common.py
index 227b45fda..ce982b673 100644
--- a/src/dstack/_internal/server/compatibility/common.py
+++ b/src/dstack/_internal/server/compatibility/common.py
@@ -2,10 +2,23 @@
from packaging.version import Version
+from dstack._internal.core.models.common import EntityReference
from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceOfferWithAvailability,
)
+from dstack._internal.core.models.profiles import ProfileParams
+
+
+def patch_profile_params(params: ProfileParams, client_version: Optional[Version]) -> None:
+ if client_version is None:
+ return
+ # Clients prior to 0.20.14 only support `list[str]` in `fleets`
+ if client_version < Version("0.20.14") and params.fleets is not None:
+ params.fleets = [
+ fleet_ref.format() if isinstance(fleet_ref, EntityReference) else fleet_ref
+ for fleet_ref in params.fleets
+ ]
def patch_offers_list(
diff --git a/src/dstack/_internal/server/compatibility/fleets.py b/src/dstack/_internal/server/compatibility/fleets.py
new file mode 100644
index 000000000..ddd90d14d
--- /dev/null
+++ b/src/dstack/_internal/server/compatibility/fleets.py
@@ -0,0 +1,23 @@
+from typing import Optional
+
+from packaging.version import Version
+
+from dstack._internal.core.models.fleets import Fleet, FleetPlan, FleetSpec
+from dstack._internal.server.compatibility.common import patch_offers_list, patch_profile_params
+
+
+def patch_fleet_plan(fleet_plan: FleetPlan, client_version: Optional[Version]) -> None:
+ patch_fleet_spec(fleet_plan.spec, client_version)
+ if fleet_plan.effective_spec is not None:
+ patch_fleet_spec(fleet_plan.effective_spec, client_version)
+ if fleet_plan.current_resource is not None:
+ patch_fleet(fleet_plan.current_resource, client_version)
+ patch_offers_list(fleet_plan.offers, client_version)
+
+
+def patch_fleet(fleet: Fleet, client_version: Optional[Version]) -> None:
+ patch_fleet_spec(fleet.spec, client_version)
+
+
+def patch_fleet_spec(fleet_spec: FleetSpec, client_version: Optional[Version]) -> None:
+ patch_profile_params(fleet_spec.profile, client_version)
diff --git a/src/dstack/_internal/server/compatibility/runs.py b/src/dstack/_internal/server/compatibility/runs.py
index 77df3f5fa..9a5d0d99b 100644
--- a/src/dstack/_internal/server/compatibility/runs.py
+++ b/src/dstack/_internal/server/compatibility/runs.py
@@ -4,7 +4,7 @@
from dstack._internal.core.models.configurations import SERVICE_HTTPS_DEFAULT, ServiceConfiguration
from dstack._internal.core.models.runs import Run, RunPlan, RunSpec
-from dstack._internal.server.compatibility.common import patch_offers_list
+from dstack._internal.server.compatibility.common import patch_offers_list, patch_profile_params
def patch_run_plan(run_plan: RunPlan, client_version: Optional[Version]) -> None:
@@ -41,3 +41,6 @@ def patch_run_spec(run_spec: RunSpec, client_version: Optional[Version]) -> None
and run_spec.configuration.https is None
):
run_spec.configuration.https = SERVICE_HTTPS_DEFAULT
+ patch_profile_params(run_spec.configuration, client_version)
+ if run_spec.profile is not None:
+ patch_profile_params(run_spec.profile, client_version)
diff --git a/src/dstack/_internal/server/routers/fleets.py b/src/dstack/_internal/server/routers/fleets.py
index 58c87d653..4aea9a03c 100644
--- a/src/dstack/_internal/server/routers/fleets.py
+++ b/src/dstack/_internal/server/routers/fleets.py
@@ -7,7 +7,7 @@
import dstack._internal.server.services.fleets as fleets_services
from dstack._internal.core.errors import ResourceNotExistsError
from dstack._internal.core.models.fleets import Fleet, FleetPlan
-from dstack._internal.server.compatibility.common import patch_offers_list
+from dstack._internal.server.compatibility.fleets import patch_fleet, patch_fleet_plan
from dstack._internal.server.db import get_session
from dstack._internal.server.deps import Project
from dstack._internal.server.models import ProjectModel, UserModel
@@ -50,6 +50,7 @@ async def list_fleets(
body: ListFleetsRequest,
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(Authenticated()),
+ client_version: Optional[Version] = Depends(get_client_version),
):
"""
Returns all fleets and instances within them visible to user sorted by descending `created_at`.
@@ -59,19 +60,20 @@ async def list_fleets(
The results are paginated. To get the next page, pass `created_at` and `id` of
the last fleet from the previous page as `prev_created_at` and `prev_id`.
"""
- return CustomORJSONResponse(
- await fleets_services.list_fleets(
- session=session,
- user=user,
- project_name=body.project_name,
- only_active=body.only_active,
- include_imported=body.include_imported,
- prev_created_at=body.prev_created_at,
- prev_id=body.prev_id,
- limit=body.limit,
- ascending=body.ascending,
- )
+ fleet_list = await fleets_services.list_fleets(
+ session=session,
+ user=user,
+ project_name=body.project_name,
+ only_active=body.only_active,
+ include_imported=body.include_imported,
+ prev_created_at=body.prev_created_at,
+ prev_id=body.prev_id,
+ limit=body.limit,
+ ascending=body.ascending,
)
+ for fleet in fleet_list:
+ patch_fleet(fleet, client_version)
+ return CustomORJSONResponse(fleet_list)
@project_router.post("/list", response_model=List[Fleet])
@@ -79,6 +81,7 @@ async def list_project_fleets(
body: Optional[ListProjectFleetsRequest] = None,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
+ client_version: Optional[Version] = Depends(get_client_version),
):
"""
Returns all fleets in the project.
@@ -87,13 +90,14 @@ async def list_project_fleets(
_, project = user_project
if body is None:
body = ListProjectFleetsRequest()
- return CustomORJSONResponse(
- await fleets_services.list_project_fleets(
- session=session,
- project=project,
- include_imported=body.include_imported,
- )
+ fleet_list = await fleets_services.list_project_fleets(
+ session=session,
+ project=project,
+ include_imported=body.include_imported,
)
+ for fleet in fleet_list:
+ patch_fleet(fleet, client_version)
+ return CustomORJSONResponse(fleet_list)
@project_router.post("/get", response_model=Fleet)
@@ -102,6 +106,7 @@ async def get_fleet(
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(Authenticated()),
project: ProjectModel = Depends(Project()),
+ client_version: Optional[Version] = Depends(get_client_version),
):
"""
Returns a fleet given `name` or `id`.
@@ -116,6 +121,7 @@ async def get_fleet(
)
if fleet is None:
raise ResourceNotExistsError()
+ patch_fleet(fleet, client_version)
return CustomORJSONResponse(fleet)
@@ -136,7 +142,7 @@ async def get_plan(
user=user,
spec=body.spec,
)
- patch_offers_list(plan.offers, client_version)
+ patch_fleet_plan(plan, client_version)
return CustomORJSONResponse(plan)
@@ -146,6 +152,7 @@ async def apply_plan(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
pipeline_hinter: PipelineHinterProtocol = Depends(get_pipeline_hinter),
+ client_version: Optional[Version] = Depends(get_client_version),
):
"""
Creates a new fleet or updates an existing fleet.
@@ -153,16 +160,16 @@ async def apply_plan(
Use `force: true` to apply even if the current resource does not match.
"""
user, project = user_project
- return CustomORJSONResponse(
- await fleets_services.apply_plan(
- session=session,
- user=user,
- project=project,
- plan=body.plan,
- force=body.force,
- pipeline_hinter=pipeline_hinter,
- )
+ fleet = await fleets_services.apply_plan(
+ session=session,
+ user=user,
+ project=project,
+ plan=body.plan,
+ force=body.force,
+ pipeline_hinter=pipeline_hinter,
)
+ patch_fleet(fleet, client_version)
+ return CustomORJSONResponse(fleet)
@project_router.post("/create", response_model=Fleet, deprecated=True)
@@ -171,20 +178,21 @@ async def create_fleet(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
pipeline_hinter: PipelineHinterProtocol = Depends(get_pipeline_hinter),
+ client_version: Optional[Version] = Depends(get_client_version),
):
"""
Creates a fleet given a fleet configuration.
"""
user, project = user_project
- return CustomORJSONResponse(
- await fleets_services.create_fleet(
- session=session,
- project=project,
- user=user,
- spec=body.spec,
- pipeline_hinter=pipeline_hinter,
- )
+ fleet = await fleets_services.create_fleet(
+ session=session,
+ project=project,
+ user=user,
+ spec=body.spec,
+ pipeline_hinter=pipeline_hinter,
)
+ patch_fleet(fleet, client_version)
+ return CustomORJSONResponse(fleet)
@project_router.post("/delete")
diff --git a/src/dstack/_internal/server/services/runs/plan.py b/src/dstack/_internal/server/services/runs/plan.py
index 2e608019d..ecf2f8bdd 100644
--- a/src/dstack/_internal/server/services/runs/plan.py
+++ b/src/dstack/_internal/server/services/runs/plan.py
@@ -6,6 +6,7 @@
from sqlalchemy.orm import contains_eager, noload
from dstack._internal.core.backends.base.backend import Backend
+from dstack._internal.core.models.common import EntityReference
from dstack._internal.core.models.fleets import FleetSpec, InstanceGroupPlacement
from dstack._internal.core.models.instances import (
InstanceAvailability,
@@ -228,7 +229,23 @@ async def get_run_candidate_fleet_models_filters(
if run_model is not None and run_model.fleet is not None:
fleet_filters.append(FleetModel.id == run_model.fleet_id)
if run_spec.merged_profile.fleets is not None:
- fleet_filters.append(FleetModel.name.in_(run_spec.merged_profile.fleets))
+ fleet_conditions = []
+ for ref in map(EntityReference.parse, run_spec.merged_profile.fleets):
+ if ref.project is None:
+ fleet_conditions.append(
+ and_(
+ FleetModel.name == ref.name,
+ FleetModel.project_id == project.id,
+ )
+ )
+ else:
+ fleet_conditions.append(
+ and_(
+ FleetModel.name == ref.name,
+ ProjectModel.name == ref.project,
+ )
+ )
+ fleet_filters.append(or_(*fleet_conditions))
instance_filters = [
InstanceModel.deleted == False,
InstanceModel.id.not_in(detaching_instances_ids),
@@ -247,6 +264,7 @@ async def select_run_candidate_fleet_models_with_filters(
# Then select left out fleets without instances.
stmt = (
select(FleetModel)
+ .join(FleetModel.project) # can be referenced by fleet_filters
.join(FleetModel.instances)
.where(*fleet_filters)
.where(*instance_filters)
@@ -265,6 +283,7 @@ async def select_run_candidate_fleet_models_with_filters(
fleet_models_with_instances_ids = [f.id for f in fleet_models_with_instances]
res = await session.execute(
select(FleetModel)
+ .join(FleetModel.project) # can be referenced by fleet_filters
.outerjoin(FleetModel.instances)
.where(
*fleet_filters,
diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py
index 437da4761..87db5ec8e 100644
--- a/src/dstack/_internal/server/testing/common.py
+++ b/src/dstack/_internal/server/testing/common.py
@@ -667,13 +667,17 @@ async def create_fleet(
return fm
-def get_fleet_spec(conf: Optional[FleetConfiguration] = None) -> FleetSpec:
+def get_fleet_spec(
+ conf: Optional[FleetConfiguration] = None, profile: Optional[Profile] = None
+) -> FleetSpec:
if conf is None:
conf = get_fleet_configuration()
+ if profile is None:
+ profile = Profile()
return FleetSpec(
configuration=conf,
configuration_path="fleet.dstack.yml",
- profile=Profile(),
+ profile=profile,
)
diff --git a/src/dstack/api/server/_fleets.py b/src/dstack/api/server/_fleets.py
index 95bb22e82..93f27e672 100644
--- a/src/dstack/api/server/_fleets.py
+++ b/src/dstack/api/server/_fleets.py
@@ -1,3 +1,4 @@
+import copy
from typing import List, Optional, Union
from uuid import UUID
@@ -7,6 +8,7 @@
get_apply_plan_excludes,
get_create_fleet_excludes,
get_get_plan_excludes,
+ patch_fleet_spec,
)
from dstack._internal.core.models.fleets import ApplyFleetPlanInput, Fleet, FleetPlan, FleetSpec
from dstack._internal.server.schemas.fleets import (
@@ -47,6 +49,8 @@ def get_plan(
spec: FleetSpec,
) -> FleetPlan:
body = GetFleetPlanRequest(spec=spec)
+ body = copy.deepcopy(body)
+ patch_fleet_spec(body.spec)
body_json = body.json(exclude=get_get_plan_excludes(spec))
resp = self._request(f"/api/project/{project_name}/fleets/get_plan", body=body_json)
return parse_obj_as(FleetPlan.__response__, resp.json())
@@ -59,6 +63,10 @@ def apply_plan(
) -> Fleet:
plan_input = ApplyFleetPlanInput.__response__.parse_obj(plan)
body = ApplyFleetPlanRequest(plan=plan_input, force=force)
+ body = copy.deepcopy(body)
+ patch_fleet_spec(body.plan.spec)
+ if body.plan.current_resource is not None:
+ patch_fleet_spec(body.plan.current_resource.spec)
body_json = body.json(exclude=get_apply_plan_excludes(plan_input))
resp = self._request(f"/api/project/{project_name}/fleets/apply", body=body_json)
return parse_obj_as(Fleet.__response__, resp.json())
@@ -79,6 +87,8 @@ def create(
spec: FleetSpec,
) -> Fleet:
body = CreateFleetRequest(spec=spec)
+ body = copy.deepcopy(body)
+ patch_fleet_spec(body.spec)
body_json = body.json(exclude=get_create_fleet_excludes(spec))
resp = self._request(f"/api/project/{project_name}/fleets/create", body=body_json)
return parse_obj_as(Fleet.__response__, resp.json())
diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py
index ead179763..0543f3384 100644
--- a/src/dstack/api/server/_runs.py
+++ b/src/dstack/api/server/_runs.py
@@ -1,3 +1,4 @@
+import copy
from datetime import datetime
from typing import List, Optional, Union
from uuid import UUID
@@ -8,6 +9,7 @@
get_apply_plan_excludes,
get_get_plan_excludes,
get_list_runs_excludes,
+ patch_run_spec,
)
from dstack._internal.core.models.runs import (
ApplyRunPlanInput,
@@ -73,6 +75,8 @@ def get_plan(
self, project_name: str, run_spec: RunSpec, max_offers: Optional[int] = None
) -> RunPlan:
body = GetRunPlanRequest(run_spec=run_spec, max_offers=max_offers)
+ body = copy.deepcopy(body)
+ patch_run_spec(body.run_spec)
resp = self._request(
f"/api/project/{project_name}/runs/get_plan",
body=body.json(exclude=get_get_plan_excludes(body)),
@@ -87,6 +91,10 @@ def apply_plan(
) -> Run:
plan_input: ApplyRunPlanInput = ApplyRunPlanInput.__response__.parse_obj(plan)
body = ApplyRunPlanRequest(plan=plan_input, force=force)
+ body = copy.deepcopy(body)
+ patch_run_spec(body.plan.run_spec)
+ if body.plan.current_resource is not None:
+ patch_run_spec(body.plan.current_resource.run_spec)
resp = self._request(
f"/api/project/{project_name}/runs/apply",
body=body.json(exclude=get_apply_plan_excludes(plan_input)),
diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py
index a0d6124d3..1191567b5 100644
--- a/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py
+++ b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py
@@ -8,7 +8,10 @@
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import NetworkMode
-from dstack._internal.core.models.configurations import TaskConfiguration
+from dstack._internal.core.models.configurations import (
+ DevEnvironmentConfiguration,
+ TaskConfiguration,
+)
from dstack._internal.core.models.fleets import FleetNodesSpec, InstanceGroupPlacement
from dstack._internal.core.models.health import HealthStatus
from dstack._internal.core.models.instances import (
@@ -34,6 +37,7 @@
)
from dstack._internal.server.models import (
ComputeGroupModel,
+ FleetModel,
InstanceModel,
JobModel,
PlacementGroupModel,
@@ -1218,6 +1222,120 @@ async def test_assigns_job_to_elastic_non_empty_busy_fleet_if_fleets_specified(
assert job.instance_id is None
assert job.fleet_id == fleet.id
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ @pytest.mark.parametrize(
+ ("configured_fleet", "expected_fleet_project_name"),
+ [
+ ("exporter-a/test-fleet", "exporter-a"),
+ ("exporter-b/test-fleet", "exporter-b"),
+ ("importer/test-fleet", "importer"),
+ ("test-fleet", "importer"),
+ ],
+ )
+ async def test_assigns_job_to_specified_fleet_across_projects(
+ self,
+ test_db,
+ session: AsyncSession,
+ configured_fleet: str,
+ expected_fleet_project_name: str,
+ ):
+ user = await create_user(session)
+ exporter_a = await create_project(session, name="exporter-a", owner=user)
+ exporter_b = await create_project(session, name="exporter-b", owner=user)
+ importer = await create_project(session, name="importer", owner=user)
+ fleet_a = await create_fleet(session=session, project=exporter_a, name="test-fleet")
+ fleet_b = await create_fleet(session=session, project=exporter_b, name="test-fleet")
+ await create_fleet(session=session, project=importer, name="test-fleet")
+ await create_export(
+ session=session,
+ exporter_project=exporter_a,
+ importer_projects=[importer],
+ exported_fleets=[fleet_a],
+ )
+ await create_export(
+ session=session,
+ exporter_project=exporter_b,
+ importer_projects=[importer],
+ exported_fleets=[fleet_b],
+ )
+ repo = await create_repo(session=session, project_id=importer.id)
+ run_spec = get_run_spec(
+ repo_id=repo.name,
+ configuration=DevEnvironmentConfiguration.parse_obj(
+ {"type": "dev-environment", "ide": "vscode", "fleets": [configured_fleet]}
+ ),
+ )
+ run = await create_run(
+ session=session,
+ project=importer,
+ repo=repo,
+ user=user,
+ run_spec=run_spec,
+ )
+ job = await create_job(session=session, run=run, instance_assigned=False)
+
+ await process_submitted_jobs()
+ res = await session.execute(
+ select(JobModel)
+ .where(JobModel.id == job.id)
+ .options(joinedload(JobModel.fleet).joinedload(FleetModel.project))
+ .execution_options(populate_existing=True)
+ )
+ job = res.scalar_one()
+ assert job.status == JobStatus.SUBMITTED
+ assert job.instance_assigned
+ assert job.fleet is not None
+ assert job.fleet.project.name == expected_fleet_project_name
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_not_assigns_job_to_imported_fleet_if_specified_without_project_prefix(
+ self,
+ test_db,
+ session: AsyncSession,
+ ):
+ user = await create_user(session)
+ exporter_a = await create_project(session, name="exporter-a", owner=user)
+ importer = await create_project(session, name="importer", owner=user)
+ fleet_a = await create_fleet(session=session, project=exporter_a, name="test-fleet")
+ await create_export(
+ session=session,
+ exporter_project=exporter_a,
+ importer_projects=[importer],
+ exported_fleets=[fleet_a],
+ )
+ repo = await create_repo(session=session, project_id=importer.id)
+ run_spec = get_run_spec(
+ repo_id=repo.name,
+ configuration=DevEnvironmentConfiguration.parse_obj(
+ {
+ "type": "dev-environment",
+ "ide": "vscode",
+ "fleets": ["test-fleet"], # won't work, should be exporter-a/test-fleet
+ }
+ ),
+ )
+ run = await create_run(
+ session=session,
+ project=importer,
+ repo=repo,
+ user=user,
+ run_spec=run_spec,
+ )
+ job = await create_job(session=session, run=run, instance_assigned=False)
+
+ await process_submitted_jobs()
+ res = await session.execute(
+ select(JobModel)
+ .where(JobModel.id == job.id)
+ .options(joinedload(JobModel.fleet).joinedload(FleetModel.project))
+ .execution_options(populate_existing=True)
+ )
+ job = res.scalar_one()
+ assert job.status == JobStatus.TERMINATING
+ assert job.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY
+
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_creates_new_instance_in_existing_empty_fleet(
diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py
index d14e74e80..47544b921 100644
--- a/src/tests/_internal/server/routers/test_fleets.py
+++ b/src/tests/_internal/server/routers/test_fleets.py
@@ -1,6 +1,6 @@
import json
from datetime import datetime, timezone
-from typing import Optional
+from typing import Optional, Union
from unittest.mock import Mock, patch
from uuid import uuid4
@@ -11,6 +11,7 @@
from sqlalchemy.ext.asyncio import AsyncSession
from dstack._internal.core.models.backends.base import BackendType
+from dstack._internal.core.models.common import EntityReference
from dstack._internal.core.models.fleets import (
FleetConfiguration,
FleetStatus,
@@ -25,6 +26,7 @@
Resources,
SSHKey,
)
+from dstack._internal.core.models.profiles import Profile
from dstack._internal.core.models.users import GlobalRole, ProjectRole
from dstack._internal.server.models import FleetModel, InstanceModel
from dstack._internal.server.services.fleets import fleet_model_to_fleet
@@ -847,6 +849,66 @@ async def test_returns_403_on_foreign_fleet_if_not_imported(
)
assert response.status_code == 403
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "client_version,expected_fleets",
+ [
+ (
+ "0.20.13",
+ [
+ "my-fleet",
+ "other-project/other-fleet",
+ ],
+ ),
+ (
+ "0.20.14",
+ [
+ {"project": None, "name": "my-fleet"},
+ {"project": "other-project", "name": "other-fleet"},
+ ],
+ ),
+ (
+ None,
+ [
+ {"project": None, "name": "my-fleet"},
+ {"project": "other-project", "name": "other-fleet"},
+ ],
+ ),
+ ],
+ )
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_patches_profile_fleets_for_old_clients(
+ self,
+ test_db,
+ session: AsyncSession,
+ client: AsyncClient,
+ client_version: Optional[str],
+ expected_fleets: list,
+ ) -> None:
+ user = await create_user(session=session)
+ project = await create_project(session=session, owner=user)
+
+ fleets: list[Union[EntityReference, str]] = [
+ EntityReference(project=None, name="my-fleet"),
+ EntityReference(project="other-project", name="other-fleet"),
+ ]
+ spec = get_fleet_spec(
+ profile=Profile(fleets=fleets),
+ )
+ fleet = await create_fleet(session=session, project=project, spec=spec)
+
+ headers = get_auth_headers(user.token)
+ if client_version is not None:
+ headers["X-API-Version"] = client_version
+ response = await client.post(
+ f"/api/project/{project.name}/fleets/get",
+ headers=headers,
+ json={"id": str(fleet.id)},
+ )
+
+ assert response.status_code == 200
+ assert response.json()["spec"]["profile"]["fleets"] == expected_fleets
+
class TestApplyFleetPlan:
@pytest.mark.asyncio
diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py
index 1d72ba552..e17d71188 100644
--- a/src/tests/_internal/server/routers/test_runs.py
+++ b/src/tests/_internal/server/routers/test_runs.py
@@ -15,7 +15,7 @@
from dstack._internal import settings
from dstack._internal.core.errors import GatewayError
from dstack._internal.core.models.backends.base import BackendType
-from dstack._internal.core.models.common import ApplyAction
+from dstack._internal.core.models.common import ApplyAction, EntityReference
from dstack._internal.core.models.configurations import (
AnyRunConfiguration,
DevEnvironmentConfiguration,
@@ -32,7 +32,7 @@
InstanceType,
Resources,
)
-from dstack._internal.core.models.profiles import Schedule
+from dstack._internal.core.models.profiles import Profile, Schedule
from dstack._internal.core.models.resources import Range
from dstack._internal.core.models.runs import (
ApplyRunPlanInput,
@@ -1122,6 +1122,77 @@ async def test_patches_service_configuration_probes_for_old_clients(
assert response.status_code == 200
assert response.json()["run_spec"]["configuration"]["probes"] == expected_probes
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "client_version,expected_fleets",
+ [
+ (
+ "0.20.13",
+ [
+ "my-fleet",
+ "other-project/other-fleet",
+ ],
+ ),
+ (
+ "0.20.14",
+ [
+ {"project": None, "name": "my-fleet"},
+ {"project": "other-project", "name": "other-fleet"},
+ ],
+ ),
+ (
+ None,
+ [
+ {"project": None, "name": "my-fleet"},
+ {"project": "other-project", "name": "other-fleet"},
+ ],
+ ),
+ ],
+ )
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_patches_fleets_for_old_clients(
+ self,
+ test_db,
+ session: AsyncSession,
+ client: AsyncClient,
+ client_version: Optional[str],
+ expected_fleets: list,
+ ) -> None:
+ user = await create_user(session=session)
+ project = await create_project(session=session, owner=user)
+ repo = await create_repo(session=session, project_id=project.id)
+
+ fleets: list[Union[EntityReference, str]] = [
+ EntityReference(project=None, name="my-fleet"),
+ EntityReference(project="other-project", name="other-fleet"),
+ ]
+ run_spec = get_run_spec(
+ configuration=TaskConfiguration(
+ commands=["echo hello"],
+ fleets=fleets,
+ ),
+ repo_id=repo.name,
+ profile=Profile(
+ fleets=fleets,
+ ),
+ )
+ run = await create_run(
+ session=session, project=project, repo=repo, user=user, run_spec=run_spec
+ )
+
+ headers = get_auth_headers(user.token)
+ if client_version is not None:
+ headers["X-API-Version"] = client_version
+ response = await client.post(
+ f"/api/project/{project.name}/runs/get",
+ headers=headers,
+ json={"run_name": run.run_name},
+ )
+
+ assert response.status_code == 200
+ assert response.json()["run_spec"]["configuration"]["fleets"] == expected_fleets
+ assert response.json()["run_spec"]["profile"]["fleets"] == expected_fleets
+
class TestGetRunPlan:
@pytest.mark.asyncio
@@ -1479,6 +1550,161 @@ async def test_returns_no_offers_if_imported_ssh_fleet_is_empty(
assert response_json["project_name"] == "importer-project"
assert len(response_json["job_plans"][0]["offers"]) == 0
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ @pytest.mark.parametrize(
+ ("configured_fleet", "expected_price"),
+ [
+ ("exporter-a/test-fleet", 1.0),
+ ("exporter-b/test-fleet", 2.0),
+ ("importer/test-fleet", 3.0),
+ ("test-fleet", 3.0),
+ ],
+ )
+ async def test_returns_run_plan_offers_from_specified_fleet_across_projects(
+ self,
+ test_db,
+ session: AsyncSession,
+ client: AsyncClient,
+ configured_fleet: str,
+ expected_price: float,
+ ) -> None:
+ user = await create_user(session, global_role=GlobalRole.USER)
+ exporter_a = await create_project(session, name="exporter-a", owner=user)
+ exporter_b = await create_project(session, name="exporter-b", owner=user)
+ importer = await create_project(session, name="importer", owner=user)
+ await add_project_member(
+ session=session,
+ project=importer,
+ user=user,
+ project_role=ProjectRole.USER,
+ )
+ fleet_a = await create_fleet(
+ session=session,
+ project=exporter_a,
+ name="test-fleet",
+ spec=get_fleet_spec(get_ssh_fleet_configuration()),
+ )
+ await create_instance(
+ session=session,
+ project=exporter_a,
+ fleet=fleet_a,
+ backend=BackendType.REMOTE,
+ price=1.0,
+ )
+ fleet_b = await create_fleet(
+ session=session,
+ project=exporter_b,
+ name="test-fleet",
+ spec=get_fleet_spec(get_ssh_fleet_configuration()),
+ )
+ await create_instance(
+ session=session,
+ project=exporter_b,
+ fleet=fleet_b,
+ backend=BackendType.REMOTE,
+ price=2.0,
+ )
+ fleet_importer = await create_fleet(
+ session=session,
+ project=importer,
+ name="test-fleet",
+ spec=get_fleet_spec(get_ssh_fleet_configuration()),
+ )
+ await create_instance(
+ session=session,
+ project=importer,
+ fleet=fleet_importer,
+ backend=BackendType.REMOTE,
+ price=3.0,
+ )
+ await create_export(
+ session=session,
+ exporter_project=exporter_a,
+ importer_projects=[importer],
+ exported_fleets=[fleet_a],
+ )
+ await create_export(
+ session=session,
+ exporter_project=exporter_b,
+ importer_projects=[importer],
+ exported_fleets=[fleet_b],
+ )
+
+ run_spec = {
+ "configuration": {
+ "type": "dev-environment",
+ "ide": "vscode",
+ "fleets": [configured_fleet],
+ }
+ }
+ body = {"run_spec": run_spec}
+ response = await client.post(
+ "/api/project/importer/runs/get_plan",
+ headers=get_auth_headers(user.token),
+ json=body,
+ )
+ assert response.status_code == 200, response.json()
+ response_json = response.json()
+ assert response_json["project_name"] == "importer"
+ offers = response_json["job_plans"][0]["offers"]
+ assert offers[0]["price"] == expected_price
+ assert len(offers) == 1
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_returns_no_offers_if_imported_fleet_specified_without_project_prefix(
+ self,
+ test_db,
+ session: AsyncSession,
+ client: AsyncClient,
+ ) -> None:
+ importer_user = await create_user(session, global_role=GlobalRole.USER)
+ exporter_a = await create_project(session, name="exporter-a")
+ importer = await create_project(session, name="importer", owner=importer_user)
+ await add_project_member(
+ session=session,
+ project=importer,
+ user=importer_user,
+ project_role=ProjectRole.USER,
+ )
+ fleet_a = await create_fleet(
+ session=session,
+ project=exporter_a,
+ name="test-fleet",
+ spec=get_fleet_spec(get_ssh_fleet_configuration()),
+ )
+ await create_instance(
+ session=session,
+ project=exporter_a,
+ fleet=fleet_a,
+ backend=BackendType.REMOTE,
+ )
+ await create_export(
+ session=session,
+ exporter_project=exporter_a,
+ importer_projects=[importer],
+ exported_fleets=[fleet_a],
+ )
+
+ run_spec = {
+ "configuration": {
+ "type": "dev-environment",
+ "ide": "vscode",
+ "fleets": ["test-fleet"], # won't work, should be exporter-a/test-fleet
+ }
+ }
+ body = {"run_spec": run_spec}
+ response = await client.post(
+ "/api/project/importer/runs/get_plan",
+ headers=get_auth_headers(importer_user.token),
+ json=body,
+ )
+ assert response.status_code == 200, response.json()
+ response_json = response.json()
+ assert response_json["project_name"] == "importer"
+ assert len(response_json["job_plans"][0]["offers"]) == 0
+
@pytest.mark.parametrize(
("client_version", "expected_availability"),
[