diff --git a/docs/docs/concepts/exports.md b/docs/docs/concepts/exports.md index 3b3194146..a773094b4 100644 --- a/docs/docs/concepts/exports.md +++ b/docs/docs/concepts/exports.md @@ -134,6 +134,19 @@ $ dstack fleet list Imported fleets can be used for runs just like the project's own fleets. +
+ +```yaml +type: dev-environment +ide: vscode + +fleets: +- my-local-fleet +- team-a/my-fleet +``` + +
+ !!! info "Tenant isolation" Exported fleets share the same access model as regular fleets. See [Tenant isolation](fleets.md#tenant-isolation) for details. diff --git a/src/dstack/_internal/core/compatibility/common.py b/src/dstack/_internal/core/compatibility/common.py new file mode 100644 index 000000000..6a3445a7b --- /dev/null +++ b/src/dstack/_internal/core/compatibility/common.py @@ -0,0 +1,14 @@ +from dstack._internal.core.models.common import EntityReference +from dstack._internal.core.models.profiles import ProfileParams + + +def patch_profile_params(params: ProfileParams) -> None: + # If there are no project-prefixed fleets, replace all EntityReference with str + # for compatibility with pre-0.20.14 servers that don't support EntityReference. + if params.fleets is not None and all( + EntityReference.parse(f).project is None for f in params.fleets + ): + params.fleets = [ + fleet_ref.format() if isinstance(fleet_ref, EntityReference) else fleet_ref + for fleet_ref in params.fleets + ] diff --git a/src/dstack/_internal/core/compatibility/fleets.py b/src/dstack/_internal/core/compatibility/fleets.py index e22903df3..04a92ed77 100644 --- a/src/dstack/_internal/core/compatibility/fleets.py +++ b/src/dstack/_internal/core/compatibility/fleets.py @@ -1,6 +1,10 @@ from typing import Optional -from dstack._internal.core.models.common import IncludeExcludeDictType, IncludeExcludeSetType +from dstack._internal.core.compatibility.common import patch_profile_params +from dstack._internal.core.models.common import ( + IncludeExcludeDictType, + IncludeExcludeSetType, +) from dstack._internal.core.models.fleets import ApplyFleetPlanInput, FleetSpec @@ -56,3 +60,7 @@ def get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[IncludeExcludeDic if spec_excludes: return spec_excludes return None + + +def patch_fleet_spec(spec: FleetSpec) -> None: + patch_profile_params(spec.profile) diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index 2b0e3a6b4..e08b8260f 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -1,6 +1,10 @@ from typing import Optional -from dstack._internal.core.models.common import IncludeExcludeDictType, IncludeExcludeSetType +from dstack._internal.core.compatibility.common import patch_profile_params +from dstack._internal.core.models.common import ( + IncludeExcludeDictType, + IncludeExcludeSetType, +) from dstack._internal.core.models.configurations import ServiceConfiguration from dstack._internal.core.models.routers import SGLangServiceRouterConfig from dstack._internal.core.models.runs import ( @@ -138,3 +142,9 @@ def get_job_submission_excludes(job_submissions: list[JobSubmission]) -> Include submission_excludes["job_runtime_data"] = jrd_excludes return submission_excludes + + +def patch_run_spec(run_spec: RunSpec) -> None: + patch_profile_params(run_spec.configuration) + if run_spec.profile is not None: + patch_profile_params(run_spec.profile) diff --git a/src/dstack/_internal/core/models/common.py b/src/dstack/_internal/core/models/common.py index f55a032ba..a3bf68cff 100644 --- a/src/dstack/_internal/core/models/common.py +++ b/src/dstack/_internal/core/models/common.py @@ -143,3 +143,31 @@ class ApplyAction(str, Enum): class NetworkMode(str, Enum): HOST = "host" BRIDGE = "bridge" + + +class EntityReference(CoreModel): + """ + Cross-project entity reference. + """ + + project: Annotated[ + Optional[str], + Field(description="The project name. If unspecified, refers to the current project"), + ] + name: Annotated[str, Field(description="The entity name")] + + @classmethod + def parse(cls, v: Union[str, "EntityReference"]) -> "EntityReference": + if isinstance(v, EntityReference): + return v + parts = v.split("/") + if len(parts) == 1: + return cls(project=None, name=parts[0]) + if len(parts) == 2: + return cls(project=parts[0], name=parts[1]) + raise ValueError("Invalid entity reference. Only `/` format is allowed") + + def format(self) -> str: + if self.project is None: + return self.name + return f"{self.project}/{self.name}" diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index 97ce591aa..517514694 100644 --- a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -10,6 +10,7 @@ CoreConfig, CoreModel, Duration, + EntityReference, generate_dual_core_model, ) from dstack._internal.utils.common import list_enum_values_for_annotation @@ -360,7 +361,21 @@ class ProfileParams(CoreModel): Field(description=("The schedule for starting the run at specified time")), ] = None fleets: Annotated[ - Optional[list[str]], Field(description="The fleets considered for reuse") + Optional[ + list[ + Union[ + EntityReference, + str, # For server response compatibility with pre-0.20.14 clients + ] + ] + ], + Field( + description=( + "The fleets considered for reuse." + " For fleets owned by the current project, specify fleet names." + " For imported fleets, specify `/`" + ), + ), ] = None tags: Annotated[ Optional[Dict[str, str]], @@ -382,6 +397,7 @@ class ProfileParams(CoreModel): _validate_idle_duration = validator("idle_duration", pre=True, allow_reuse=True)( parse_idle_duration ) + _validate_fleets = validator("fleets", allow_reuse=True, each_item=True)(EntityReference.parse) _validate_tags = validator("tags", pre=True, allow_reuse=True)(tags_validator) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py index 52217eefb..eaa452380 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py @@ -975,6 +975,7 @@ async def _refetch_fleet_models_with_instances( ) -> list[FleetModel]: res = await session.execute( select(FleetModel) + .join(FleetModel.project) # can be referenced by fleet_filters .outerjoin(FleetModel.instances) .where( FleetModel.id.in_(fleets_ids), diff --git a/src/dstack/_internal/server/compatibility/common.py b/src/dstack/_internal/server/compatibility/common.py index 227b45fda..ce982b673 100644 --- a/src/dstack/_internal/server/compatibility/common.py +++ b/src/dstack/_internal/server/compatibility/common.py @@ -2,10 +2,23 @@ from packaging.version import Version +from dstack._internal.core.models.common import EntityReference from dstack._internal.core.models.instances import ( InstanceAvailability, InstanceOfferWithAvailability, ) +from dstack._internal.core.models.profiles import ProfileParams + + +def patch_profile_params(params: ProfileParams, client_version: Optional[Version]) -> None: + if client_version is None: + return + # Clients prior to 0.20.14 only support `list[str]` in `fleets` + if client_version < Version("0.20.14") and params.fleets is not None: + params.fleets = [ + fleet_ref.format() if isinstance(fleet_ref, EntityReference) else fleet_ref + for fleet_ref in params.fleets + ] def patch_offers_list( diff --git a/src/dstack/_internal/server/compatibility/fleets.py b/src/dstack/_internal/server/compatibility/fleets.py new file mode 100644 index 000000000..ddd90d14d --- /dev/null +++ b/src/dstack/_internal/server/compatibility/fleets.py @@ -0,0 +1,23 @@ +from typing import Optional + +from packaging.version import Version + +from dstack._internal.core.models.fleets import Fleet, FleetPlan, FleetSpec +from dstack._internal.server.compatibility.common import patch_offers_list, patch_profile_params + + +def patch_fleet_plan(fleet_plan: FleetPlan, client_version: Optional[Version]) -> None: + patch_fleet_spec(fleet_plan.spec, client_version) + if fleet_plan.effective_spec is not None: + patch_fleet_spec(fleet_plan.effective_spec, client_version) + if fleet_plan.current_resource is not None: + patch_fleet(fleet_plan.current_resource, client_version) + patch_offers_list(fleet_plan.offers, client_version) + + +def patch_fleet(fleet: Fleet, client_version: Optional[Version]) -> None: + patch_fleet_spec(fleet.spec, client_version) + + +def patch_fleet_spec(fleet_spec: FleetSpec, client_version: Optional[Version]) -> None: + patch_profile_params(fleet_spec.profile, client_version) diff --git a/src/dstack/_internal/server/compatibility/runs.py b/src/dstack/_internal/server/compatibility/runs.py index 77df3f5fa..9a5d0d99b 100644 --- a/src/dstack/_internal/server/compatibility/runs.py +++ b/src/dstack/_internal/server/compatibility/runs.py @@ -4,7 +4,7 @@ from dstack._internal.core.models.configurations import SERVICE_HTTPS_DEFAULT, ServiceConfiguration from dstack._internal.core.models.runs import Run, RunPlan, RunSpec -from dstack._internal.server.compatibility.common import patch_offers_list +from dstack._internal.server.compatibility.common import patch_offers_list, patch_profile_params def patch_run_plan(run_plan: RunPlan, client_version: Optional[Version]) -> None: @@ -41,3 +41,6 @@ def patch_run_spec(run_spec: RunSpec, client_version: Optional[Version]) -> None and run_spec.configuration.https is None ): run_spec.configuration.https = SERVICE_HTTPS_DEFAULT + patch_profile_params(run_spec.configuration, client_version) + if run_spec.profile is not None: + patch_profile_params(run_spec.profile, client_version) diff --git a/src/dstack/_internal/server/routers/fleets.py b/src/dstack/_internal/server/routers/fleets.py index 58c87d653..4aea9a03c 100644 --- a/src/dstack/_internal/server/routers/fleets.py +++ b/src/dstack/_internal/server/routers/fleets.py @@ -7,7 +7,7 @@ import dstack._internal.server.services.fleets as fleets_services from dstack._internal.core.errors import ResourceNotExistsError from dstack._internal.core.models.fleets import Fleet, FleetPlan -from dstack._internal.server.compatibility.common import patch_offers_list +from dstack._internal.server.compatibility.fleets import patch_fleet, patch_fleet_plan from dstack._internal.server.db import get_session from dstack._internal.server.deps import Project from dstack._internal.server.models import ProjectModel, UserModel @@ -50,6 +50,7 @@ async def list_fleets( body: ListFleetsRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), + client_version: Optional[Version] = Depends(get_client_version), ): """ Returns all fleets and instances within them visible to user sorted by descending `created_at`. @@ -59,19 +60,20 @@ async def list_fleets( The results are paginated. To get the next page, pass `created_at` and `id` of the last fleet from the previous page as `prev_created_at` and `prev_id`. """ - return CustomORJSONResponse( - await fleets_services.list_fleets( - session=session, - user=user, - project_name=body.project_name, - only_active=body.only_active, - include_imported=body.include_imported, - prev_created_at=body.prev_created_at, - prev_id=body.prev_id, - limit=body.limit, - ascending=body.ascending, - ) + fleet_list = await fleets_services.list_fleets( + session=session, + user=user, + project_name=body.project_name, + only_active=body.only_active, + include_imported=body.include_imported, + prev_created_at=body.prev_created_at, + prev_id=body.prev_id, + limit=body.limit, + ascending=body.ascending, ) + for fleet in fleet_list: + patch_fleet(fleet, client_version) + return CustomORJSONResponse(fleet_list) @project_router.post("/list", response_model=List[Fleet]) @@ -79,6 +81,7 @@ async def list_project_fleets( body: Optional[ListProjectFleetsRequest] = None, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), + client_version: Optional[Version] = Depends(get_client_version), ): """ Returns all fleets in the project. @@ -87,13 +90,14 @@ async def list_project_fleets( _, project = user_project if body is None: body = ListProjectFleetsRequest() - return CustomORJSONResponse( - await fleets_services.list_project_fleets( - session=session, - project=project, - include_imported=body.include_imported, - ) + fleet_list = await fleets_services.list_project_fleets( + session=session, + project=project, + include_imported=body.include_imported, ) + for fleet in fleet_list: + patch_fleet(fleet, client_version) + return CustomORJSONResponse(fleet_list) @project_router.post("/get", response_model=Fleet) @@ -102,6 +106,7 @@ async def get_fleet( session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), project: ProjectModel = Depends(Project()), + client_version: Optional[Version] = Depends(get_client_version), ): """ Returns a fleet given `name` or `id`. @@ -116,6 +121,7 @@ async def get_fleet( ) if fleet is None: raise ResourceNotExistsError() + patch_fleet(fleet, client_version) return CustomORJSONResponse(fleet) @@ -136,7 +142,7 @@ async def get_plan( user=user, spec=body.spec, ) - patch_offers_list(plan.offers, client_version) + patch_fleet_plan(plan, client_version) return CustomORJSONResponse(plan) @@ -146,6 +152,7 @@ async def apply_plan( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), pipeline_hinter: PipelineHinterProtocol = Depends(get_pipeline_hinter), + client_version: Optional[Version] = Depends(get_client_version), ): """ Creates a new fleet or updates an existing fleet. @@ -153,16 +160,16 @@ async def apply_plan( Use `force: true` to apply even if the current resource does not match. """ user, project = user_project - return CustomORJSONResponse( - await fleets_services.apply_plan( - session=session, - user=user, - project=project, - plan=body.plan, - force=body.force, - pipeline_hinter=pipeline_hinter, - ) + fleet = await fleets_services.apply_plan( + session=session, + user=user, + project=project, + plan=body.plan, + force=body.force, + pipeline_hinter=pipeline_hinter, ) + patch_fleet(fleet, client_version) + return CustomORJSONResponse(fleet) @project_router.post("/create", response_model=Fleet, deprecated=True) @@ -171,20 +178,21 @@ async def create_fleet( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), pipeline_hinter: PipelineHinterProtocol = Depends(get_pipeline_hinter), + client_version: Optional[Version] = Depends(get_client_version), ): """ Creates a fleet given a fleet configuration. """ user, project = user_project - return CustomORJSONResponse( - await fleets_services.create_fleet( - session=session, - project=project, - user=user, - spec=body.spec, - pipeline_hinter=pipeline_hinter, - ) + fleet = await fleets_services.create_fleet( + session=session, + project=project, + user=user, + spec=body.spec, + pipeline_hinter=pipeline_hinter, ) + patch_fleet(fleet, client_version) + return CustomORJSONResponse(fleet) @project_router.post("/delete") diff --git a/src/dstack/_internal/server/services/runs/plan.py b/src/dstack/_internal/server/services/runs/plan.py index 2e608019d..ecf2f8bdd 100644 --- a/src/dstack/_internal/server/services/runs/plan.py +++ b/src/dstack/_internal/server/services/runs/plan.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import contains_eager, noload from dstack._internal.core.backends.base.backend import Backend +from dstack._internal.core.models.common import EntityReference from dstack._internal.core.models.fleets import FleetSpec, InstanceGroupPlacement from dstack._internal.core.models.instances import ( InstanceAvailability, @@ -228,7 +229,23 @@ async def get_run_candidate_fleet_models_filters( if run_model is not None and run_model.fleet is not None: fleet_filters.append(FleetModel.id == run_model.fleet_id) if run_spec.merged_profile.fleets is not None: - fleet_filters.append(FleetModel.name.in_(run_spec.merged_profile.fleets)) + fleet_conditions = [] + for ref in map(EntityReference.parse, run_spec.merged_profile.fleets): + if ref.project is None: + fleet_conditions.append( + and_( + FleetModel.name == ref.name, + FleetModel.project_id == project.id, + ) + ) + else: + fleet_conditions.append( + and_( + FleetModel.name == ref.name, + ProjectModel.name == ref.project, + ) + ) + fleet_filters.append(or_(*fleet_conditions)) instance_filters = [ InstanceModel.deleted == False, InstanceModel.id.not_in(detaching_instances_ids), @@ -247,6 +264,7 @@ async def select_run_candidate_fleet_models_with_filters( # Then select left out fleets without instances. stmt = ( select(FleetModel) + .join(FleetModel.project) # can be referenced by fleet_filters .join(FleetModel.instances) .where(*fleet_filters) .where(*instance_filters) @@ -265,6 +283,7 @@ async def select_run_candidate_fleet_models_with_filters( fleet_models_with_instances_ids = [f.id for f in fleet_models_with_instances] res = await session.execute( select(FleetModel) + .join(FleetModel.project) # can be referenced by fleet_filters .outerjoin(FleetModel.instances) .where( *fleet_filters, diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 437da4761..87db5ec8e 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -667,13 +667,17 @@ async def create_fleet( return fm -def get_fleet_spec(conf: Optional[FleetConfiguration] = None) -> FleetSpec: +def get_fleet_spec( + conf: Optional[FleetConfiguration] = None, profile: Optional[Profile] = None +) -> FleetSpec: if conf is None: conf = get_fleet_configuration() + if profile is None: + profile = Profile() return FleetSpec( configuration=conf, configuration_path="fleet.dstack.yml", - profile=Profile(), + profile=profile, ) diff --git a/src/dstack/api/server/_fleets.py b/src/dstack/api/server/_fleets.py index 95bb22e82..93f27e672 100644 --- a/src/dstack/api/server/_fleets.py +++ b/src/dstack/api/server/_fleets.py @@ -1,3 +1,4 @@ +import copy from typing import List, Optional, Union from uuid import UUID @@ -7,6 +8,7 @@ get_apply_plan_excludes, get_create_fleet_excludes, get_get_plan_excludes, + patch_fleet_spec, ) from dstack._internal.core.models.fleets import ApplyFleetPlanInput, Fleet, FleetPlan, FleetSpec from dstack._internal.server.schemas.fleets import ( @@ -47,6 +49,8 @@ def get_plan( spec: FleetSpec, ) -> FleetPlan: body = GetFleetPlanRequest(spec=spec) + body = copy.deepcopy(body) + patch_fleet_spec(body.spec) body_json = body.json(exclude=get_get_plan_excludes(spec)) resp = self._request(f"/api/project/{project_name}/fleets/get_plan", body=body_json) return parse_obj_as(FleetPlan.__response__, resp.json()) @@ -59,6 +63,10 @@ def apply_plan( ) -> Fleet: plan_input = ApplyFleetPlanInput.__response__.parse_obj(plan) body = ApplyFleetPlanRequest(plan=plan_input, force=force) + body = copy.deepcopy(body) + patch_fleet_spec(body.plan.spec) + if body.plan.current_resource is not None: + patch_fleet_spec(body.plan.current_resource.spec) body_json = body.json(exclude=get_apply_plan_excludes(plan_input)) resp = self._request(f"/api/project/{project_name}/fleets/apply", body=body_json) return parse_obj_as(Fleet.__response__, resp.json()) @@ -79,6 +87,8 @@ def create( spec: FleetSpec, ) -> Fleet: body = CreateFleetRequest(spec=spec) + body = copy.deepcopy(body) + patch_fleet_spec(body.spec) body_json = body.json(exclude=get_create_fleet_excludes(spec)) resp = self._request(f"/api/project/{project_name}/fleets/create", body=body_json) return parse_obj_as(Fleet.__response__, resp.json()) diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py index ead179763..0543f3384 100644 --- a/src/dstack/api/server/_runs.py +++ b/src/dstack/api/server/_runs.py @@ -1,3 +1,4 @@ +import copy from datetime import datetime from typing import List, Optional, Union from uuid import UUID @@ -8,6 +9,7 @@ get_apply_plan_excludes, get_get_plan_excludes, get_list_runs_excludes, + patch_run_spec, ) from dstack._internal.core.models.runs import ( ApplyRunPlanInput, @@ -73,6 +75,8 @@ def get_plan( self, project_name: str, run_spec: RunSpec, max_offers: Optional[int] = None ) -> RunPlan: body = GetRunPlanRequest(run_spec=run_spec, max_offers=max_offers) + body = copy.deepcopy(body) + patch_run_spec(body.run_spec) resp = self._request( f"/api/project/{project_name}/runs/get_plan", body=body.json(exclude=get_get_plan_excludes(body)), @@ -87,6 +91,10 @@ def apply_plan( ) -> Run: plan_input: ApplyRunPlanInput = ApplyRunPlanInput.__response__.parse_obj(plan) body = ApplyRunPlanRequest(plan=plan_input, force=force) + body = copy.deepcopy(body) + patch_run_spec(body.plan.run_spec) + if body.plan.current_resource is not None: + patch_run_spec(body.plan.current_resource.run_spec) resp = self._request( f"/api/project/{project_name}/runs/apply", body=body.json(exclude=get_apply_plan_excludes(plan_input)), diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py index a0d6124d3..1191567b5 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py @@ -8,7 +8,10 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import NetworkMode -from dstack._internal.core.models.configurations import TaskConfiguration +from dstack._internal.core.models.configurations import ( + DevEnvironmentConfiguration, + TaskConfiguration, +) from dstack._internal.core.models.fleets import FleetNodesSpec, InstanceGroupPlacement from dstack._internal.core.models.health import HealthStatus from dstack._internal.core.models.instances import ( @@ -34,6 +37,7 @@ ) from dstack._internal.server.models import ( ComputeGroupModel, + FleetModel, InstanceModel, JobModel, PlacementGroupModel, @@ -1218,6 +1222,120 @@ async def test_assigns_job_to_elastic_non_empty_busy_fleet_if_fleets_specified( assert job.instance_id is None assert job.fleet_id == fleet.id + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize( + ("configured_fleet", "expected_fleet_project_name"), + [ + ("exporter-a/test-fleet", "exporter-a"), + ("exporter-b/test-fleet", "exporter-b"), + ("importer/test-fleet", "importer"), + ("test-fleet", "importer"), + ], + ) + async def test_assigns_job_to_specified_fleet_across_projects( + self, + test_db, + session: AsyncSession, + configured_fleet: str, + expected_fleet_project_name: str, + ): + user = await create_user(session) + exporter_a = await create_project(session, name="exporter-a", owner=user) + exporter_b = await create_project(session, name="exporter-b", owner=user) + importer = await create_project(session, name="importer", owner=user) + fleet_a = await create_fleet(session=session, project=exporter_a, name="test-fleet") + fleet_b = await create_fleet(session=session, project=exporter_b, name="test-fleet") + await create_fleet(session=session, project=importer, name="test-fleet") + await create_export( + session=session, + exporter_project=exporter_a, + importer_projects=[importer], + exported_fleets=[fleet_a], + ) + await create_export( + session=session, + exporter_project=exporter_b, + importer_projects=[importer], + exported_fleets=[fleet_b], + ) + repo = await create_repo(session=session, project_id=importer.id) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=DevEnvironmentConfiguration.parse_obj( + {"type": "dev-environment", "ide": "vscode", "fleets": [configured_fleet]} + ), + ) + run = await create_run( + session=session, + project=importer, + repo=repo, + user=user, + run_spec=run_spec, + ) + job = await create_job(session=session, run=run, instance_assigned=False) + + await process_submitted_jobs() + res = await session.execute( + select(JobModel) + .where(JobModel.id == job.id) + .options(joinedload(JobModel.fleet).joinedload(FleetModel.project)) + .execution_options(populate_existing=True) + ) + job = res.scalar_one() + assert job.status == JobStatus.SUBMITTED + assert job.instance_assigned + assert job.fleet is not None + assert job.fleet.project.name == expected_fleet_project_name + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_not_assigns_job_to_imported_fleet_if_specified_without_project_prefix( + self, + test_db, + session: AsyncSession, + ): + user = await create_user(session) + exporter_a = await create_project(session, name="exporter-a", owner=user) + importer = await create_project(session, name="importer", owner=user) + fleet_a = await create_fleet(session=session, project=exporter_a, name="test-fleet") + await create_export( + session=session, + exporter_project=exporter_a, + importer_projects=[importer], + exported_fleets=[fleet_a], + ) + repo = await create_repo(session=session, project_id=importer.id) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=DevEnvironmentConfiguration.parse_obj( + { + "type": "dev-environment", + "ide": "vscode", + "fleets": ["test-fleet"], # won't work, should be exporter-a/test-fleet + } + ), + ) + run = await create_run( + session=session, + project=importer, + repo=repo, + user=user, + run_spec=run_spec, + ) + job = await create_job(session=session, run=run, instance_assigned=False) + + await process_submitted_jobs() + res = await session.execute( + select(JobModel) + .where(JobModel.id == job.id) + .options(joinedload(JobModel.fleet).joinedload(FleetModel.project)) + .execution_options(populate_existing=True) + ) + job = res.scalar_one() + assert job.status == JobStatus.TERMINATING + assert job.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_creates_new_instance_in_existing_empty_fleet( diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index d14e74e80..47544b921 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -1,6 +1,6 @@ import json from datetime import datetime, timezone -from typing import Optional +from typing import Optional, Union from unittest.mock import Mock, patch from uuid import uuid4 @@ -11,6 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.common import EntityReference from dstack._internal.core.models.fleets import ( FleetConfiguration, FleetStatus, @@ -25,6 +26,7 @@ Resources, SSHKey, ) +from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.users import GlobalRole, ProjectRole from dstack._internal.server.models import FleetModel, InstanceModel from dstack._internal.server.services.fleets import fleet_model_to_fleet @@ -847,6 +849,66 @@ async def test_returns_403_on_foreign_fleet_if_not_imported( ) assert response.status_code == 403 + @pytest.mark.asyncio + @pytest.mark.parametrize( + "client_version,expected_fleets", + [ + ( + "0.20.13", + [ + "my-fleet", + "other-project/other-fleet", + ], + ), + ( + "0.20.14", + [ + {"project": None, "name": "my-fleet"}, + {"project": "other-project", "name": "other-fleet"}, + ], + ), + ( + None, + [ + {"project": None, "name": "my-fleet"}, + {"project": "other-project", "name": "other-fleet"}, + ], + ), + ], + ) + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_patches_profile_fleets_for_old_clients( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + client_version: Optional[str], + expected_fleets: list, + ) -> None: + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + + fleets: list[Union[EntityReference, str]] = [ + EntityReference(project=None, name="my-fleet"), + EntityReference(project="other-project", name="other-fleet"), + ] + spec = get_fleet_spec( + profile=Profile(fleets=fleets), + ) + fleet = await create_fleet(session=session, project=project, spec=spec) + + headers = get_auth_headers(user.token) + if client_version is not None: + headers["X-API-Version"] = client_version + response = await client.post( + f"/api/project/{project.name}/fleets/get", + headers=headers, + json={"id": str(fleet.id)}, + ) + + assert response.status_code == 200 + assert response.json()["spec"]["profile"]["fleets"] == expected_fleets + class TestApplyFleetPlan: @pytest.mark.asyncio diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 1d72ba552..e17d71188 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -15,7 +15,7 @@ from dstack._internal import settings from dstack._internal.core.errors import GatewayError from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.common import ApplyAction +from dstack._internal.core.models.common import ApplyAction, EntityReference from dstack._internal.core.models.configurations import ( AnyRunConfiguration, DevEnvironmentConfiguration, @@ -32,7 +32,7 @@ InstanceType, Resources, ) -from dstack._internal.core.models.profiles import Schedule +from dstack._internal.core.models.profiles import Profile, Schedule from dstack._internal.core.models.resources import Range from dstack._internal.core.models.runs import ( ApplyRunPlanInput, @@ -1122,6 +1122,77 @@ async def test_patches_service_configuration_probes_for_old_clients( assert response.status_code == 200 assert response.json()["run_spec"]["configuration"]["probes"] == expected_probes + @pytest.mark.asyncio + @pytest.mark.parametrize( + "client_version,expected_fleets", + [ + ( + "0.20.13", + [ + "my-fleet", + "other-project/other-fleet", + ], + ), + ( + "0.20.14", + [ + {"project": None, "name": "my-fleet"}, + {"project": "other-project", "name": "other-fleet"}, + ], + ), + ( + None, + [ + {"project": None, "name": "my-fleet"}, + {"project": "other-project", "name": "other-fleet"}, + ], + ), + ], + ) + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_patches_fleets_for_old_clients( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + client_version: Optional[str], + expected_fleets: list, + ) -> None: + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + + fleets: list[Union[EntityReference, str]] = [ + EntityReference(project=None, name="my-fleet"), + EntityReference(project="other-project", name="other-fleet"), + ] + run_spec = get_run_spec( + configuration=TaskConfiguration( + commands=["echo hello"], + fleets=fleets, + ), + repo_id=repo.name, + profile=Profile( + fleets=fleets, + ), + ) + run = await create_run( + session=session, project=project, repo=repo, user=user, run_spec=run_spec + ) + + headers = get_auth_headers(user.token) + if client_version is not None: + headers["X-API-Version"] = client_version + response = await client.post( + f"/api/project/{project.name}/runs/get", + headers=headers, + json={"run_name": run.run_name}, + ) + + assert response.status_code == 200 + assert response.json()["run_spec"]["configuration"]["fleets"] == expected_fleets + assert response.json()["run_spec"]["profile"]["fleets"] == expected_fleets + class TestGetRunPlan: @pytest.mark.asyncio @@ -1479,6 +1550,161 @@ async def test_returns_no_offers_if_imported_ssh_fleet_is_empty( assert response_json["project_name"] == "importer-project" assert len(response_json["job_plans"][0]["offers"]) == 0 + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize( + ("configured_fleet", "expected_price"), + [ + ("exporter-a/test-fleet", 1.0), + ("exporter-b/test-fleet", 2.0), + ("importer/test-fleet", 3.0), + ("test-fleet", 3.0), + ], + ) + async def test_returns_run_plan_offers_from_specified_fleet_across_projects( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + configured_fleet: str, + expected_price: float, + ) -> None: + user = await create_user(session, global_role=GlobalRole.USER) + exporter_a = await create_project(session, name="exporter-a", owner=user) + exporter_b = await create_project(session, name="exporter-b", owner=user) + importer = await create_project(session, name="importer", owner=user) + await add_project_member( + session=session, + project=importer, + user=user, + project_role=ProjectRole.USER, + ) + fleet_a = await create_fleet( + session=session, + project=exporter_a, + name="test-fleet", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + await create_instance( + session=session, + project=exporter_a, + fleet=fleet_a, + backend=BackendType.REMOTE, + price=1.0, + ) + fleet_b = await create_fleet( + session=session, + project=exporter_b, + name="test-fleet", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + await create_instance( + session=session, + project=exporter_b, + fleet=fleet_b, + backend=BackendType.REMOTE, + price=2.0, + ) + fleet_importer = await create_fleet( + session=session, + project=importer, + name="test-fleet", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + await create_instance( + session=session, + project=importer, + fleet=fleet_importer, + backend=BackendType.REMOTE, + price=3.0, + ) + await create_export( + session=session, + exporter_project=exporter_a, + importer_projects=[importer], + exported_fleets=[fleet_a], + ) + await create_export( + session=session, + exporter_project=exporter_b, + importer_projects=[importer], + exported_fleets=[fleet_b], + ) + + run_spec = { + "configuration": { + "type": "dev-environment", + "ide": "vscode", + "fleets": [configured_fleet], + } + } + body = {"run_spec": run_spec} + response = await client.post( + "/api/project/importer/runs/get_plan", + headers=get_auth_headers(user.token), + json=body, + ) + assert response.status_code == 200, response.json() + response_json = response.json() + assert response_json["project_name"] == "importer" + offers = response_json["job_plans"][0]["offers"] + assert offers[0]["price"] == expected_price + assert len(offers) == 1 + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_no_offers_if_imported_fleet_specified_without_project_prefix( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + ) -> None: + importer_user = await create_user(session, global_role=GlobalRole.USER) + exporter_a = await create_project(session, name="exporter-a") + importer = await create_project(session, name="importer", owner=importer_user) + await add_project_member( + session=session, + project=importer, + user=importer_user, + project_role=ProjectRole.USER, + ) + fleet_a = await create_fleet( + session=session, + project=exporter_a, + name="test-fleet", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + await create_instance( + session=session, + project=exporter_a, + fleet=fleet_a, + backend=BackendType.REMOTE, + ) + await create_export( + session=session, + exporter_project=exporter_a, + importer_projects=[importer], + exported_fleets=[fleet_a], + ) + + run_spec = { + "configuration": { + "type": "dev-environment", + "ide": "vscode", + "fleets": ["test-fleet"], # won't work, should be exporter-a/test-fleet + } + } + body = {"run_spec": run_spec} + response = await client.post( + "/api/project/importer/runs/get_plan", + headers=get_auth_headers(importer_user.token), + json=body, + ) + assert response.status_code == 200, response.json() + response_json = response.json() + assert response_json["project_name"] == "importer" + assert len(response_json["job_plans"][0]["offers"]) == 0 + @pytest.mark.parametrize( ("client_version", "expected_availability"), [