From 9f2ac41d83b696bd699e46977cecc602baff7919 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 18 Mar 2026 14:57:04 +0500 Subject: [PATCH 01/14] Drop unrechable non-services code in _handle_run_replicas --- .../server/background/scheduled_tasks/runs.py | 46 ++----------------- 1 file changed, 5 insertions(+), 41 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index 5bb373ae0..0e1e0370f 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -259,6 +259,7 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel): await scale_run_replicas_per_group(session, run_model, replicas) else: + # Non-service pending runs may have 0 job submissions and require new submission, e.g. scheduled tasks. run_model.desired_replica_count = 1 await scale_run_replicas(session, run_model, replicas_diff=run_model.desired_replica_count) @@ -570,47 +571,10 @@ async def _handle_run_replicas( run_spec=run_spec, ) if has_out_of_date_replicas(run_model): - assert run_spec.configuration.type == "service", ( - "Rolling deployment is only supported for services" - ) - non_terminated_replica_count = len( - {j.replica_num for j in run_model.jobs if not j.status.is_finished()} - ) - # Avoid using too much hardware during a deployment - never have - # more than max_replica_count non-terminated replicas. - if non_terminated_replica_count < max_replica_count: - # Start more up-to-date replicas that will eventually replace out-of-date replicas. 
- await scale_run_replicas( - session, - run_model, - replicas_diff=max_replica_count - non_terminated_replica_count, - ) - - replicas_to_stop_count = 0 - # stop any out-of-date replicas that are not registered - replicas_to_stop_count += sum( - any(j.deployment_num < run_model.deployment_num for j in jobs) - and any( - j.status not in [JobStatus.TERMINATING] + JobStatus.finished_statuses() - for j in jobs - ) - and not is_replica_registered(jobs) - for _, jobs in group_jobs_by_replica_latest(run_model.jobs) - ) - # stop excessive registered out-of-date replicas, except those that are already `terminating` - non_terminating_registered_replicas_count = sum( - is_replica_registered(jobs) and all(j.status != JobStatus.TERMINATING for j in jobs) - for _, jobs in group_jobs_by_replica_latest(run_model.jobs) - ) - replicas_to_stop_count += max( - 0, non_terminating_registered_replicas_count - run_model.desired_replica_count - ) - if replicas_to_stop_count: - await scale_run_replicas( - session, - run_model, - replicas_diff=-replicas_to_stop_count, - ) + # Currently, only services can change job spec on update, + # so for other runs out-of-date replicas are not possible. + # Keeping assert in case this changes. 
+        assert run_spec.configuration.type == "service", "Rolling deployment is only supported for services"
session=session, + run=run, + status=JobStatus.FAILED, + termination_reason=JobTerminationReason.EXECUTOR_ERROR, + last_processed_at=now - datetime.timedelta(minutes=4), + replica_num=0, + job_provisioning_data=get_job_provisioning_data(), + ) + await create_job( + session=session, + run=run, + status=JobStatus.FAILED, + termination_reason=JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY, + replica_num=0, + submission_num=1, + last_processed_at=now - datetime.timedelta(minutes=2), + job_provisioning_data=None, + ) + + with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock: + datetime_mock.return_value = now + await process_runs.process_runs() + + await session.refresh(run) + assert run.status == RunStatus.TERMINATING + assert run.termination_reason == RunTerminationReason.RETRY_LIMIT_EXCEEDED + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_calculates_retry_duration_since_last_successful_submission( @@ -283,6 +327,53 @@ async def test_pending_to_submitted(self, test_db, session: AsyncSession): assert run.jobs[0].status == JobStatus.FAILED assert run.jobs[1].status == JobStatus.SUBMITTED + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_stops_when_master_job_done(self, test_db, session: AsyncSession): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=TaskConfiguration( + commands=["echo hello"], + nodes=2, + ), + profile=Profile( + name="test-profile", + stop_criteria=StopCriteria.MASTER_DONE, + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=run_spec, + status=RunStatus.RUNNING, + ) + await create_job( + session=session, + run=run, + status=JobStatus.DONE, + 
termination_reason=JobTerminationReason.DONE_BY_RUNNER, + replica_num=0, + job_num=0, + ) + await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=0, + job_num=1, + ) + + await process_runs.process_runs() + + await session.refresh(run) + assert run.status == RunStatus.TERMINATING + assert run.termination_reason == RunTerminationReason.ALL_JOBS_DONE + class TestProcessRunsReplicas: @pytest.mark.asyncio @@ -478,6 +569,64 @@ async def test_considers_replicas_inactive_only_when_all_jobs_done( # Should not create new replica with new jobs assert len(run.jobs) == 2 + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_retrying_multinode_replica_terminates_active_sibling_jobs( + self, + test_db, + session: AsyncSession, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=TaskConfiguration( + commands=["echo hello"], + nodes=2, + ), + profile=Profile( + name="test-profile", + retry=ProfileRetry(duration="10m", on_events=[RetryEvent.NO_CAPACITY]), + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=run_spec, + status=RunStatus.RUNNING, + ) + failed_job = await create_job( + session=session, + run=run, + status=JobStatus.FAILED, + termination_reason=JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY, + replica_num=0, + job_num=0, + ) + running_job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=0, + job_num=1, + job_provisioning_data=get_job_provisioning_data(), + ) + + with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock: + datetime_mock.return_value = run.submitted_at + datetime.timedelta(minutes=1) + await process_runs.process_runs() + + await session.refresh(run) + 
await session.refresh(failed_job) + await session.refresh(running_job) + assert run.status == RunStatus.PENDING + assert failed_job.status == JobStatus.FAILED + assert running_job.status == JobStatus.TERMINATING + assert running_job.termination_reason == JobTerminationReason.TERMINATED_BY_SERVER + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_pending_to_submitted_adds_replicas(self, test_db, session: AsyncSession): @@ -708,6 +857,82 @@ async def test_not_updates_deployment_num_in_place_for_finished_replica( assert len(events[0].targets) == 1 assert events[0].targets[0].entity_id == run.jobs[0].id + async def test_terminates_replicas_from_removed_group(self, test_db, session: AsyncSession): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=ServiceConfiguration( + port=8000, + image="ubuntu:latest", + replicas=[ + ReplicaGroup( + name="group-a", + count=parse_obj_as(Range[int], 1), + commands=["echo group-a"], + ), + ReplicaGroup( + name="group-b", + count=parse_obj_as(Range[int], 1), + commands=["echo group-b"], + ), + ], + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=run_spec, + status=RunStatus.RUNNING, + ) + group_a_job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=0, + registered=True, + ) + group_b_job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=1, + registered=True, + ) + + group_a_job_spec = cast(JobSpec, JobSpec.__response__.parse_raw(group_a_job.job_spec_data)) + group_a_job_spec.replica_group = "group-a" + group_a_job.job_spec_data = group_a_job_spec.json() + + group_b_job_spec = cast(JobSpec, JobSpec.__response__.parse_raw(group_b_job.job_spec_data)) + 
group_b_job_spec.replica_group = "group-b" + group_b_job.job_spec_data = group_b_job_spec.json() + + updated_run_spec: RunSpec = RunSpec.__response__.parse_raw(run.run_spec) + assert isinstance(updated_run_spec.configuration, ServiceConfiguration) + updated_run_spec.configuration.replicas = [ + ReplicaGroup( + name="group-a", + count=parse_obj_as(Range[int], 1), + commands=["echo group-a"], + ) + ] + run.run_spec = updated_run_spec.json() + await session.commit() + + await process_runs.process_runs() + + await session.refresh(run) + await session.refresh(group_a_job) + await session.refresh(group_b_job) + assert run.status == RunStatus.RUNNING + assert group_a_job.status == JobStatus.RUNNING + assert group_b_job.status == JobStatus.TERMINATING + assert group_b_job.termination_reason == JobTerminationReason.SCALED_DOWN + async def test_starts_new_replica(self, test_db, session: AsyncSession) -> None: run = await make_run(session, status=RunStatus.RUNNING, replicas=2, image="old") for replica_num in range(2): From aad98ec9ebfb1b29f8896515c5f9df638c12e754 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 18 Mar 2026 15:32:04 +0500 Subject: [PATCH 03/14] Refactor _process_active_run loop --- .../server/background/scheduled_tasks/runs.py | 414 ++++++++++++------ 1 file changed, 287 insertions(+), 127 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index 0e1e0370f..d009a749d 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -1,7 +1,8 @@ import asyncio import datetime import json -from typing import List, Optional, Set, Tuple +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Set, Tuple from sqlalchemy import and_, func, or_, select from sqlalchemy.ext.asyncio import AsyncSession @@ -31,7 +32,6 @@ ) from 
dstack._internal.server.services import events from dstack._internal.server.services.jobs import ( - find_job, get_job_spec, get_job_specs_from_run_spec, group_jobs_by_replica_latest, @@ -297,6 +297,30 @@ def _get_retry_delay(resubmission_attempt: int) -> datetime.timedelta: return _PENDING_RETRY_DELAYS[-1] +@dataclass +class _ReplicaAnalysis: + replica_num: int + job_models: List[JobModel] + replica_info: autoscalers.ReplicaInfo + replica_statuses: Set[RunStatus] = field(default_factory=set) + run_termination_reasons: Set[RunTerminationReason] = field(default_factory=set) + replica_needs_retry: bool = False + + +@dataclass +class _ActiveRunAnalysis: + run_statuses: Set[RunStatus] = field(default_factory=set) + run_termination_reasons: Set[RunTerminationReason] = field(default_factory=set) + replicas_to_retry: List[Tuple[int, List[JobModel]]] = field(default_factory=list) + replicas_info: List[autoscalers.ReplicaInfo] = field(default_factory=list) + + +@dataclass +class _ActiveRunTransition: + new_status: RunStatus + termination_reason: Optional[RunTerminationReason] = None + + async def _process_active_run(session: AsyncSession, run_model: RunModel): """ Run is submitted, provisioning, or running. 
@@ -305,149 +329,285 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel): run = run_model_to_run(run_model) run_spec = run.run_spec retry_single_job = _can_retry_single_job(run_spec) + _maybe_set_run_fleet_id_from_jobs(run_model) + run_jobs_by_position = _get_run_jobs_by_position(run) + analysis = await _analyze_active_run( + session=session, + run_model=run_model, + run=run, + run_jobs_by_position=run_jobs_by_position, + retry_single_job=retry_single_job, + ) + transition = _get_active_run_transition(run, analysis) + await _apply_active_run_transition( + session=session, + run_model=run_model, + run_spec=run_spec, + transition=transition, + replicas_to_retry=analysis.replicas_to_retry, + retry_single_job=retry_single_job, + replicas_info=analysis.replicas_info, + ) - run_statuses: Set[RunStatus] = set() - run_termination_reasons: Set[RunTerminationReason] = set() - replicas_to_retry: List[Tuple[int, List[JobModel]]] = [] - replicas_info: List[autoscalers.ReplicaInfo] = [] +def _get_run_jobs_by_position(run: Run) -> Dict[Tuple[int, int], Job]: + return {(job.job_spec.replica_num, job.job_spec.job_num): job for job in run.jobs} + + +async def _analyze_active_run( + session: AsyncSession, + run_model: RunModel, + run: Run, + run_jobs_by_position: Dict[Tuple[int, int], Job], + retry_single_job: bool, +) -> _ActiveRunAnalysis: + analysis = _ActiveRunAnalysis() for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs): - replica_statuses: Set[RunStatus] = set() - replica_needs_retry = False - replica_active = True - jobs_done_num = 0 - for job_model in job_models: - job = find_job(run.jobs, job_model.replica_num, job_model.job_num) - if ( - run_model.fleet_id is None - and job_model.instance is not None - and job_model.instance.fleet_id is not None - ): - run_model.fleet_id = job_model.instance.fleet_id - if job_model.status == JobStatus.DONE or ( - job_model.status == JobStatus.TERMINATING - and job_model.termination_reason == 
JobTerminationReason.DONE_BY_RUNNER - ): - # the job is done or going to be done - replica_statuses.add(RunStatus.DONE) - jobs_done_num += 1 - elif job_model.termination_reason == JobTerminationReason.SCALED_DOWN: - # the job was scaled down - replica_active = False - elif job_model.status == JobStatus.RUNNING: - # the job is running - replica_statuses.add(RunStatus.RUNNING) - elif job_model.status in {JobStatus.PROVISIONING, JobStatus.PULLING}: - # the job is provisioning - replica_statuses.add(RunStatus.PROVISIONING) - elif job_model.status == JobStatus.SUBMITTED: - # the job is submitted - replica_statuses.add(RunStatus.SUBMITTED) - elif job_model.status == JobStatus.FAILED or ( - job_model.status - in [JobStatus.TERMINATING, JobStatus.TERMINATED, JobStatus.ABORTED] - and job_model.termination_reason - not in {JobTerminationReason.DONE_BY_RUNNER, JobTerminationReason.SCALED_DOWN} - ): - current_duration = await _should_retry_job(session, run, job, job_model) - if current_duration is None: - replica_statuses.add(RunStatus.FAILED) - run_termination_reasons.add(RunTerminationReason.JOB_FAILED) - else: - if _is_retry_duration_exceeded(job, current_duration): - replica_statuses.add(RunStatus.FAILED) - run_termination_reasons.add(RunTerminationReason.RETRY_LIMIT_EXCEEDED) - else: - replica_needs_retry = True - else: - raise ValueError(f"Unexpected job status {job_model.status}") + replica_analysis = await _analyze_active_run_replica( + session=session, + run_model=run_model, + run=run, + run_jobs_by_position=run_jobs_by_position, + replica_num=replica_num, + job_models=job_models, + ) + _update_active_run_analysis(analysis, replica_analysis, retry_single_job) + return analysis - if RunStatus.FAILED in replica_statuses: - run_statuses.add(RunStatus.FAILED) - else: - if replica_needs_retry: - replicas_to_retry.append((replica_num, job_models)) - if not replica_needs_retry or retry_single_job: - run_statuses.update(replica_statuses) - - if jobs_done_num == 
len(job_models): - # Consider replica inactive if all its jobs are done for some reason. - # If only some jobs are done, replica is considered active to avoid - # provisioning new replicas for partially done multi-node tasks. + +async def _analyze_active_run_replica( + session: AsyncSession, + run_model: RunModel, + run: Run, + run_jobs_by_position: Dict[Tuple[int, int], Job], + replica_num: int, + job_models: List[JobModel], +) -> _ReplicaAnalysis: + replica_statuses: Set[RunStatus] = set() + run_termination_reasons: Set[RunTerminationReason] = set() + replica_needs_retry = False + replica_active = True + jobs_done_num = 0 + + for job_model in job_models: + job = run_jobs_by_position[(job_model.replica_num, job_model.job_num)] + + if _job_is_done_or_finishing_done(job_model): + replica_statuses.add(RunStatus.DONE) + jobs_done_num += 1 + continue + + if _job_was_scaled_down(job_model): replica_active = False + continue - replica_info = _get_replica_info(job_models, replica_active) - replicas_info.append(replica_info) + replica_status = _get_non_terminal_replica_status(job_model) + if replica_status is not None: + replica_statuses.add(replica_status) + continue - termination_reason: Optional[RunTerminationReason] = None - if RunStatus.FAILED in run_statuses: - new_status = RunStatus.TERMINATING - if RunTerminationReason.JOB_FAILED in run_termination_reasons: + if _job_needs_retry_evaluation(job_model): + current_duration = await _should_retry_job(session, run, job, job_model) + if current_duration is None: + replica_statuses.add(RunStatus.FAILED) + run_termination_reasons.add(RunTerminationReason.JOB_FAILED) + elif _is_retry_duration_exceeded(job, current_duration): + replica_statuses.add(RunStatus.FAILED) + run_termination_reasons.add(RunTerminationReason.RETRY_LIMIT_EXCEEDED) + else: + replica_needs_retry = True + continue + + raise ValueError(f"Unexpected job status {job_model.status}") + + if jobs_done_num == len(job_models): + # Consider replica inactive if all 
its jobs are done for some reason. + # If only some jobs are done, replica is considered active to avoid + # provisioning new replicas for partially done multi-node tasks. + replica_active = False + + return _ReplicaAnalysis( + replica_num=replica_num, + job_models=job_models, + replica_info=_get_replica_info(job_models, replica_active), + replica_statuses=replica_statuses, + run_termination_reasons=run_termination_reasons, + replica_needs_retry=replica_needs_retry, + ) + + +def _update_active_run_analysis( + analysis: _ActiveRunAnalysis, + replica_analysis: _ReplicaAnalysis, + retry_single_job: bool, +) -> None: + analysis.replicas_info.append(replica_analysis.replica_info) + + if RunStatus.FAILED in replica_analysis.replica_statuses: + analysis.run_statuses.add(RunStatus.FAILED) + analysis.run_termination_reasons.update(replica_analysis.run_termination_reasons) + return + + if replica_analysis.replica_needs_retry: + analysis.replicas_to_retry.append( + (replica_analysis.replica_num, replica_analysis.job_models) + ) + + if not replica_analysis.replica_needs_retry or retry_single_job: + analysis.run_statuses.update(replica_analysis.replica_statuses) + + +def _maybe_set_run_fleet_id_from_jobs(run_model: RunModel) -> None: + """ + The master job gets fleet assigned with the instance. + The run then gets from the master job's instance, and non-master jobs wait for the run's fleet to be assigned. 
+ """ + if run_model.fleet_id is not None: + return + + for job_model in run_model.jobs: + if job_model.instance is not None and job_model.instance.fleet_id is not None: + run_model.fleet_id = job_model.instance.fleet_id + return + + +def _job_is_done_or_finishing_done(job_model: JobModel) -> bool: + return job_model.status == JobStatus.DONE or ( + job_model.status == JobStatus.TERMINATING + and job_model.termination_reason == JobTerminationReason.DONE_BY_RUNNER + ) + + +def _job_was_scaled_down(job_model: JobModel) -> bool: + return job_model.termination_reason == JobTerminationReason.SCALED_DOWN + + +def _get_non_terminal_replica_status(job_model: JobModel) -> Optional[RunStatus]: + if job_model.status == JobStatus.RUNNING: + return RunStatus.RUNNING + if job_model.status in {JobStatus.PROVISIONING, JobStatus.PULLING}: + return RunStatus.PROVISIONING + if job_model.status == JobStatus.SUBMITTED: + return RunStatus.SUBMITTED + return None + + +def _job_needs_retry_evaluation(job_model: JobModel) -> bool: + return job_model.status == JobStatus.FAILED or ( + job_model.status in [JobStatus.TERMINATING, JobStatus.TERMINATED, JobStatus.ABORTED] + and job_model.termination_reason + not in {JobTerminationReason.DONE_BY_RUNNER, JobTerminationReason.SCALED_DOWN} + ) + + +def _get_active_run_transition(run: Run, analysis: _ActiveRunAnalysis) -> _ActiveRunTransition: + if RunStatus.FAILED in analysis.run_statuses: + if RunTerminationReason.JOB_FAILED in analysis.run_termination_reasons: termination_reason = RunTerminationReason.JOB_FAILED - elif RunTerminationReason.RETRY_LIMIT_EXCEEDED in run_termination_reasons: + elif RunTerminationReason.RETRY_LIMIT_EXCEEDED in analysis.run_termination_reasons: termination_reason = RunTerminationReason.RETRY_LIMIT_EXCEEDED else: - raise ValueError(f"Unexpected termination reason {run_termination_reasons}") - elif _should_stop_on_master_done(run): - new_status = RunStatus.TERMINATING + raise ValueError(f"Unexpected termination reason 
{analysis.run_termination_reasons}") + return _ActiveRunTransition( + new_status=RunStatus.TERMINATING, + termination_reason=termination_reason, + ) + + if _should_stop_on_master_done(run): # ALL_JOBS_DONE is used for all DONE reasons including master-done - termination_reason = RunTerminationReason.ALL_JOBS_DONE - elif RunStatus.RUNNING in run_statuses: - new_status = RunStatus.RUNNING - elif RunStatus.PROVISIONING in run_statuses: - new_status = RunStatus.PROVISIONING - elif RunStatus.SUBMITTED in run_statuses: - new_status = RunStatus.SUBMITTED - elif RunStatus.DONE in run_statuses and not replicas_to_retry: - new_status = RunStatus.TERMINATING - termination_reason = RunTerminationReason.ALL_JOBS_DONE - else: - new_status = RunStatus.PENDING + return _ActiveRunTransition( + new_status=RunStatus.TERMINATING, + termination_reason=RunTerminationReason.ALL_JOBS_DONE, + ) - # Terminate active jobs if the run is to be resubmitted - if new_status == RunStatus.PENDING and not retry_single_job: - for _, replica_jobs in replicas_to_retry: - for job_model in replica_jobs: - if not ( - job_model.status.is_finished() or job_model.status == JobStatus.TERMINATING - ): - job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER - job_model.termination_reason_message = "Run is to be resubmitted" - switch_job_status(session, job_model, JobStatus.TERMINATING) - - if new_status not in {RunStatus.TERMINATING, RunStatus.PENDING}: + if RunStatus.RUNNING in analysis.run_statuses: + return _ActiveRunTransition(new_status=RunStatus.RUNNING) + if RunStatus.PROVISIONING in analysis.run_statuses: + return _ActiveRunTransition(new_status=RunStatus.PROVISIONING) + if RunStatus.SUBMITTED in analysis.run_statuses: + return _ActiveRunTransition(new_status=RunStatus.SUBMITTED) + if RunStatus.DONE in analysis.run_statuses and not analysis.replicas_to_retry: + return _ActiveRunTransition( + new_status=RunStatus.TERMINATING, + termination_reason=RunTerminationReason.ALL_JOBS_DONE, + ) 
+ return _ActiveRunTransition(new_status=RunStatus.PENDING) + + +async def _apply_active_run_transition( + session: AsyncSession, + run_model: RunModel, + run_spec: RunSpec, + transition: _ActiveRunTransition, + replicas_to_retry: List[Tuple[int, List[JobModel]]], + retry_single_job: bool, + replicas_info: List[autoscalers.ReplicaInfo], +) -> None: + if transition.new_status == RunStatus.PENDING and not retry_single_job: + _terminate_retrying_replica_jobs(session, replicas_to_retry) + + if transition.new_status not in {RunStatus.TERMINATING, RunStatus.PENDING}: # No need to retry, scale, or redeploy replicas if the run is terminating, # pending run will retry replicas in `process_pending_run` await _handle_run_replicas( - session, run_model, run_spec, replicas_to_retry, retry_single_job, replicas_info + session, + run_model, + run_spec, + replicas_to_retry, + retry_single_job, + replicas_info, ) - if run_model.status != new_status: - if run_model.status == RunStatus.SUBMITTED and new_status == RunStatus.PROVISIONING: - current_time = common.get_current_datetime() - submit_to_provision_duration = (current_time - run_model.submitted_at).total_seconds() - logger.info( - "%s: run took %.2f seconds from submission to provisioning.", - fmt(run_model), - submit_to_provision_duration, - ) - project_name = run_model.project.name - run_metrics.log_submit_to_provision_duration( - submit_to_provision_duration, project_name, run_spec.configuration.type - ) + _maybe_switch_active_run_status(session, run_model, run_spec, transition) + + +def _terminate_retrying_replica_jobs( + session: AsyncSession, + replicas_to_retry: List[Tuple[int, List[JobModel]]], +) -> None: + for _, replica_jobs in replicas_to_retry: + for job_model in replica_jobs: + if job_model.status.is_finished() or job_model.status == JobStatus.TERMINATING: + continue + job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER + job_model.termination_reason_message = "Run is to be resubmitted" + 
switch_job_status(session, job_model, JobStatus.TERMINATING) + + +def _maybe_switch_active_run_status( + session: AsyncSession, + run_model: RunModel, + run_spec: RunSpec, + transition: _ActiveRunTransition, +) -> None: + if run_model.status == transition.new_status: + return + + if run_model.status == RunStatus.SUBMITTED and transition.new_status == RunStatus.PROVISIONING: + current_time = common.get_current_datetime() + submit_to_provision_duration = (current_time - run_model.submitted_at).total_seconds() + logger.info( + "%s: run took %.2f seconds from submission to provisioning.", + fmt(run_model), + submit_to_provision_duration, + ) + project_name = run_model.project.name + run_metrics.log_submit_to_provision_duration( + submit_to_provision_duration, project_name, run_spec.configuration.type + ) - if new_status == RunStatus.PENDING: - run_metrics.increment_pending_runs(run_model.project.name, run_spec.configuration.type) - # Unassign run from fleet so that the new fleet can be chosen when retrying - run_model.fleet = None - - run_model.termination_reason = termination_reason - switch_run_status(session, run_model, new_status) - # While a run goes to pending without provisioning, resubmission_attempt increases. - if new_status == RunStatus.PROVISIONING: - run_model.resubmission_attempt = 0 - elif new_status == RunStatus.PENDING: - run_model.resubmission_attempt += 1 + if transition.new_status == RunStatus.PENDING: + run_metrics.increment_pending_runs(run_model.project.name, run_spec.configuration.type) + # Unassign run from fleet so that the new fleet can be chosen when retrying + run_model.fleet = None + + run_model.termination_reason = transition.termination_reason + switch_run_status(session=session, run_model=run_model, new_status=transition.new_status) + # While a run goes to pending without provisioning, resubmission_attempt increases. 
+ if transition.new_status == RunStatus.PROVISIONING: + run_model.resubmission_attempt = 0 + elif transition.new_status == RunStatus.PENDING: + run_model.resubmission_attempt += 1 def _get_replica_info( From 051d5af0c0bd2b363cc7818166d33ca34cb1f119 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 18 Mar 2026 15:59:03 +0500 Subject: [PATCH 04/14] Document analysis classes --- .../server/background/scheduled_tasks/runs.py | 83 ++++++++++++------- 1 file changed, 51 insertions(+), 32 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index d009a749d..93459cffb 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -299,20 +299,38 @@ def _get_retry_delay(resubmission_attempt: int) -> datetime.timedelta: @dataclass class _ReplicaAnalysis: + """Per-replica classification of job states for determining the run's next status.""" + replica_num: int job_models: List[JobModel] replica_info: autoscalers.ReplicaInfo - replica_statuses: Set[RunStatus] = field(default_factory=set) - run_termination_reasons: Set[RunTerminationReason] = field(default_factory=set) - replica_needs_retry: bool = False + contributed_statuses: Set[RunStatus] = field(default_factory=set) + """`RunStatus` values derived from this replica's jobs. Merged into the run-level + analysis unless the replica is being retried as a whole.""" + termination_reasons: Set[RunTerminationReason] = field(default_factory=set) + """Why the replica failed. Only populated when `FAILED` is in `contributed_statuses`.""" + needs_retry: bool = False + """At least one job failed with a retryable reason and the retry duration hasn't been + exceeded. 
When `True`, the replica does not contribute its statuses to the run-level + analysis (unless `retry_single_job` is enabled) and is added to `replicas_to_retry` instead.""" @dataclass class _ActiveRunAnalysis: - run_statuses: Set[RunStatus] = field(default_factory=set) - run_termination_reasons: Set[RunTerminationReason] = field(default_factory=set) + """Aggregated replica analysis used to determine the run's next status. + + Each replica contributes `RunStatus` based on its jobs. + The run's new status is the highest-priority value across all + contributing replicas: FAILED > RUNNING > PROVISIONING > SUBMITTED > DONE. + Replicas that need full retry do not contribute and instead cause a PENDING transition. + """ + + contributed_statuses: Set[RunStatus] = field(default_factory=set) + termination_reasons: Set[RunTerminationReason] = field(default_factory=set) replicas_to_retry: List[Tuple[int, List[JobModel]]] = field(default_factory=list) + """Replicas with retryable failures that haven't exceeded the retry duration.""" replicas_info: List[autoscalers.ReplicaInfo] = field(default_factory=list) + """Per-replica active/inactive info for the autoscaler.""" @dataclass @@ -383,9 +401,9 @@ async def _analyze_active_run_replica( replica_num: int, job_models: List[JobModel], ) -> _ReplicaAnalysis: - replica_statuses: Set[RunStatus] = set() - run_termination_reasons: Set[RunTerminationReason] = set() - replica_needs_retry = False + contributed_statuses: Set[RunStatus] = set() + termination_reasons: Set[RunTerminationReason] = set() + needs_retry = False replica_active = True jobs_done_num = 0 @@ -393,7 +411,7 @@ async def _analyze_active_run_replica( job = run_jobs_by_position[(job_model.replica_num, job_model.job_num)] if _job_is_done_or_finishing_done(job_model): - replica_statuses.add(RunStatus.DONE) + contributed_statuses.add(RunStatus.DONE) jobs_done_num += 1 continue @@ -403,19 +421,19 @@ async def _analyze_active_run_replica( replica_status = 
_get_non_terminal_replica_status(job_model) if replica_status is not None: - replica_statuses.add(replica_status) + contributed_statuses.add(replica_status) continue if _job_needs_retry_evaluation(job_model): current_duration = await _should_retry_job(session, run, job, job_model) if current_duration is None: - replica_statuses.add(RunStatus.FAILED) - run_termination_reasons.add(RunTerminationReason.JOB_FAILED) + contributed_statuses.add(RunStatus.FAILED) + termination_reasons.add(RunTerminationReason.JOB_FAILED) elif _is_retry_duration_exceeded(job, current_duration): - replica_statuses.add(RunStatus.FAILED) - run_termination_reasons.add(RunTerminationReason.RETRY_LIMIT_EXCEEDED) + contributed_statuses.add(RunStatus.FAILED) + termination_reasons.add(RunTerminationReason.RETRY_LIMIT_EXCEEDED) else: - replica_needs_retry = True + needs_retry = True continue raise ValueError(f"Unexpected job status {job_model.status}") @@ -430,9 +448,9 @@ async def _analyze_active_run_replica( replica_num=replica_num, job_models=job_models, replica_info=_get_replica_info(job_models, replica_active), - replica_statuses=replica_statuses, - run_termination_reasons=run_termination_reasons, - replica_needs_retry=replica_needs_retry, + contributed_statuses=contributed_statuses, + termination_reasons=termination_reasons, + needs_retry=needs_retry, ) @@ -443,18 +461,18 @@ def _update_active_run_analysis( ) -> None: analysis.replicas_info.append(replica_analysis.replica_info) - if RunStatus.FAILED in replica_analysis.replica_statuses: - analysis.run_statuses.add(RunStatus.FAILED) - analysis.run_termination_reasons.update(replica_analysis.run_termination_reasons) + if RunStatus.FAILED in replica_analysis.contributed_statuses: + analysis.contributed_statuses.add(RunStatus.FAILED) + analysis.termination_reasons.update(replica_analysis.termination_reasons) return - if replica_analysis.replica_needs_retry: + if replica_analysis.needs_retry: analysis.replicas_to_retry.append( 
(replica_analysis.replica_num, replica_analysis.job_models) ) - if not replica_analysis.replica_needs_retry or retry_single_job: - analysis.run_statuses.update(replica_analysis.replica_statuses) + if not replica_analysis.needs_retry or retry_single_job: + analysis.contributed_statuses.update(replica_analysis.contributed_statuses) def _maybe_set_run_fleet_id_from_jobs(run_model: RunModel) -> None: @@ -501,13 +519,13 @@ def _job_needs_retry_evaluation(job_model: JobModel) -> bool: def _get_active_run_transition(run: Run, analysis: _ActiveRunAnalysis) -> _ActiveRunTransition: - if RunStatus.FAILED in analysis.run_statuses: - if RunTerminationReason.JOB_FAILED in analysis.run_termination_reasons: + if RunStatus.FAILED in analysis.contributed_statuses: + if RunTerminationReason.JOB_FAILED in analysis.termination_reasons: termination_reason = RunTerminationReason.JOB_FAILED - elif RunTerminationReason.RETRY_LIMIT_EXCEEDED in analysis.run_termination_reasons: + elif RunTerminationReason.RETRY_LIMIT_EXCEEDED in analysis.termination_reasons: termination_reason = RunTerminationReason.RETRY_LIMIT_EXCEEDED else: - raise ValueError(f"Unexpected termination reason {analysis.run_termination_reasons}") + raise ValueError(f"Unexpected termination reason {analysis.termination_reasons}") return _ActiveRunTransition( new_status=RunStatus.TERMINATING, termination_reason=termination_reason, @@ -520,17 +538,18 @@ def _get_active_run_transition(run: Run, analysis: _ActiveRunAnalysis) -> _Activ termination_reason=RunTerminationReason.ALL_JOBS_DONE, ) - if RunStatus.RUNNING in analysis.run_statuses: + if RunStatus.RUNNING in analysis.contributed_statuses: return _ActiveRunTransition(new_status=RunStatus.RUNNING) - if RunStatus.PROVISIONING in analysis.run_statuses: + if RunStatus.PROVISIONING in analysis.contributed_statuses: return _ActiveRunTransition(new_status=RunStatus.PROVISIONING) - if RunStatus.SUBMITTED in analysis.run_statuses: + if RunStatus.SUBMITTED in 
analysis.contributed_statuses: return _ActiveRunTransition(new_status=RunStatus.SUBMITTED) - if RunStatus.DONE in analysis.run_statuses and not analysis.replicas_to_retry: + if RunStatus.DONE in analysis.contributed_statuses and not analysis.replicas_to_retry: return _ActiveRunTransition( new_status=RunStatus.TERMINATING, termination_reason=RunTerminationReason.ALL_JOBS_DONE, ) + # All contributing replicas need full retry — resubmit the entire run. return _ActiveRunTransition(new_status=RunStatus.PENDING) From 96ff0da8493d8f47eadd0d82df250b6ac38c7a09 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 18 Mar 2026 16:19:14 +0500 Subject: [PATCH 05/14] Fix implicit RunStatus.PENDING --- .../_internal/server/background/scheduled_tasks/runs.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index 93459cffb..980b9bc47 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -549,8 +549,12 @@ def _get_active_run_transition(run: Run, analysis: _ActiveRunAnalysis) -> _Activ new_status=RunStatus.TERMINATING, termination_reason=RunTerminationReason.ALL_JOBS_DONE, ) - # All contributing replicas need full retry — resubmit the entire run. - return _ActiveRunTransition(new_status=RunStatus.PENDING) + if not analysis.contributed_statuses or analysis.contributed_statuses == {RunStatus.DONE}: + # No active replicas remain — resubmit the entire run. + # `contributed_statuses` is either empty (every replica is retrying) or contains + # only DONE (some replicas finished, others need retry). 
+ return _ActiveRunTransition(new_status=RunStatus.PENDING) + raise ServerError("Failed to determine run transition: unexpected active run state") async def _apply_active_run_transition( From 81dce21f564fdc2f381f952281ca2d54d67b86bc Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 18 Mar 2026 16:25:22 +0500 Subject: [PATCH 06/14] Replace unhandled ValueError with ServerError and fix broken assert --- .../_internal/server/background/scheduled_tasks/runs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index 980b9bc47..47db798f9 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -436,7 +436,7 @@ async def _analyze_active_run_replica( needs_retry = True continue - raise ValueError(f"Unexpected job status {job_model.status}") + raise ServerError(f"Unexpected job status {job_model.status}") if jobs_done_num == len(job_models): # Consider replica inactive if all its jobs are done for some reason. @@ -525,7 +525,7 @@ def _get_active_run_transition(run: Run, analysis: _ActiveRunAnalysis) -> _Activ elif RunTerminationReason.RETRY_LIMIT_EXCEEDED in analysis.termination_reasons: termination_reason = RunTerminationReason.RETRY_LIMIT_EXCEEDED else: - raise ValueError(f"Unexpected termination reason {analysis.termination_reasons}") + raise ServerError(f"Unexpected termination reason {analysis.termination_reasons}") return _ActiveRunTransition( new_status=RunStatus.TERMINATING, termination_reason=termination_reason, @@ -757,7 +757,7 @@ async def _handle_run_replicas( # Currently, only services can change job spec on update, # so for other runs out-of-date replicas are not possible. # Keeping assert in case this changes. 
- assert "Rolling deployment is only supported for services" + assert False, "Rolling deployment is only supported for services" async def _update_jobs_to_new_deployment_in_place( From 73f784e42d3cf812a4018c1048c6ab0e63fdd87d Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 18 Mar 2026 16:34:22 +0500 Subject: [PATCH 07/14] Drop lazy imports --- .../server/background/scheduled_tasks/runs.py | 12 ++++-------- .../_internal/server/services/runs/replicas.py | 2 +- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index 47db798f9..efd3d14fc 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -55,7 +55,8 @@ retry_run_replica_jobs, scale_down_replicas, scale_run_replicas, - scale_run_replicas_per_group, + scale_run_replicas_for_all_groups, + scale_run_replicas_for_group, ) from dstack._internal.server.services.secrets import get_project_secrets_mapping from dstack._internal.server.services.services import update_service_desired_replica_count @@ -257,7 +258,7 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel): replicas: List[ReplicaGroup] = run.run_spec.configuration.replica_groups - await scale_run_replicas_per_group(session, run_model, replicas) + await scale_run_replicas_for_all_groups(session, run_model, replicas) else: # Non-service pending runs may have 0 job submissions and require new submission, e.g. scheduled tasks. 
run_model.desired_replica_count = 1 @@ -685,7 +686,7 @@ async def _handle_run_replicas( replicas: List[ReplicaGroup] = run_spec.configuration.replica_groups assert replicas, "replica groups should always return at least one group" - await scale_run_replicas_per_group(session, run_model, replicas) + await scale_run_replicas_for_all_groups(session, run_model, replicas) # Handle per-group rolling deployment await _update_jobs_to_new_deployment_in_place( @@ -899,11 +900,6 @@ async def _handle_rolling_deployment_for_group( """ Handle rolling deployment for a single replica group. """ - from dstack._internal.server.services.runs.replicas import ( - build_replica_lists, - scale_run_replicas_for_group, - ) - desired_replica_counts = ( json.loads(run_model.desired_replica_counts) if run_model.desired_replica_counts else {} ) diff --git a/src/dstack/_internal/server/services/runs/replicas.py b/src/dstack/_internal/server/services/runs/replicas.py index ffb7bd216..dffc2d81b 100644 --- a/src/dstack/_internal/server/services/runs/replicas.py +++ b/src/dstack/_internal/server/services/runs/replicas.py @@ -220,7 +220,7 @@ async def _scale_up_replicas( run_model.jobs.append(job_model) -async def scale_run_replicas_per_group( +async def scale_run_replicas_for_all_groups( session: AsyncSession, run_model: RunModel, replicas: List[ReplicaGroup], From 58a3c9f221365492abae50d49bb12e4e8017d966 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 18 Mar 2026 16:40:51 +0500 Subject: [PATCH 08/14] Drop redundant except Exception --- .../_internal/server/background/scheduled_tasks/runs.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index efd3d14fc..497fb41d9 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -705,11 +705,8 @@ async def 
_handle_run_replicas( for job in run_model.jobs: if job.status.is_finished(): continue - try: - job_spec = get_job_spec(job) - existing_group_names.add(job_spec.replica_group) - except Exception: - continue + job_spec = get_job_spec(job) + existing_group_names.add(job_spec.replica_group) new_group_names = {group.name for group in replicas} removed_group_names = existing_group_names - new_group_names for removed_group_name in removed_group_names: From f1d01806db6a4e90e824865385c368134e59b37e Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 19 Mar 2026 09:56:59 +0500 Subject: [PATCH 09/14] Drop dummy non-services code in _handle_run_replicas --- .../server/background/scheduled_tasks/runs.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index 497fb41d9..fb74e44c3 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -732,20 +732,6 @@ async def _handle_run_replicas( scale_down_replicas(session, inactive_replicas, len(inactive_replicas)) return - max_replica_count = run_model.desired_replica_count - if has_out_of_date_replicas(run_model): - # allow extra replicas when deployment is in progress - max_replica_count += ROLLING_DEPLOYMENT_MAX_SURGE - - active_replica_count = sum(1 for r in replicas_info if r.active) - if active_replica_count not in range(run_model.desired_replica_count, max_replica_count + 1): - await scale_run_replicas( - session, - run_model, - replicas_diff=run_model.desired_replica_count - active_replica_count, - ) - return - await _update_jobs_to_new_deployment_in_place( session=session, run_model=run_model, From ecb8ba7339b30d390ad762571306bc5a187c160b Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 19 Mar 2026 10:21:16 +0500 Subject: [PATCH 10/14] Extracted _terminate_removed_replica_groups 
--- .../server/background/scheduled_tasks/runs.py | 100 +++++++++--------- 1 file changed, 48 insertions(+), 52 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index fb74e44c3..32567fd2e 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -320,7 +320,7 @@ class _ReplicaAnalysis: class _ActiveRunAnalysis: """Aggregated replica analysis used to determine the run's next status. - Each replica contributes `RunStatus` based on its jobs. + Each replica contributes `RunStatus` based on its jobs' statuses. The run's new status is the highest-priority value across all contributing replicas: FAILED > RUNNING > PROVISIONING > SUBMITTED > DONE. Replicas that need full retry do not contribute and instead cause a PENDING transition. @@ -520,6 +520,7 @@ def _job_needs_retry_evaluation(job_model: JobModel) -> bool: def _get_active_run_transition(run: Run, analysis: _ActiveRunAnalysis) -> _ActiveRunTransition: + # Check `analysis.contributed_statuses` in the priority order. if RunStatus.FAILED in analysis.contributed_statuses: if RunTerminationReason.JOB_FAILED in analysis.termination_reasons: termination_reason = RunTerminationReason.JOB_FAILED @@ -660,12 +661,10 @@ async def _handle_run_replicas( replicas_info: list[autoscalers.ReplicaInfo], ) -> None: """ - Does ONE of: - - replica retry - - replica scaling - - replica rolling deployment - - Does not do everything at once to avoid conflicts between the stages and long DB transactions. 
+ Performs one or more steps: + - replicas retry + - replicas scaling + - replicas rolling deployment """ if replicas_to_retry: @@ -683,53 +682,26 @@ async def _handle_run_replicas( # FIXME: should only include scaling events, not retries and deployments last_scaled_at=max((r.timestamp for r in replicas_info), default=None), ) - replicas: List[ReplicaGroup] = run_spec.configuration.replica_groups - assert replicas, "replica groups should always return at least one group" + replica_groups: List[ReplicaGroup] = run_spec.configuration.replica_groups + assert replica_groups, "replica groups should always return at least one group" - await scale_run_replicas_for_all_groups(session, run_model, replicas) + await scale_run_replicas_for_all_groups(session, run_model, replica_groups) - # Handle per-group rolling deployment await _update_jobs_to_new_deployment_in_place( session=session, run_model=run_model, run_spec=run_spec, - replicas=replicas, + replicas=replica_groups, ) - # Process per-group rolling deployment - for group in replicas: + + for group in replica_groups: await _handle_rolling_deployment_for_group( session=session, run_model=run_model, group=group, run_spec=run_spec ) - # Terminate replicas from groups that were removed from the configuration - existing_group_names = set() - for job in run_model.jobs: - if job.status.is_finished(): - continue - job_spec = get_job_spec(job) - existing_group_names.add(job_spec.replica_group) - new_group_names = {group.name for group in replicas} - removed_group_names = existing_group_names - new_group_names - for removed_group_name in removed_group_names: - # Build replica lists for this removed group - active_replicas, inactive_replicas = build_replica_lists( - run_model=run_model, - group_filter=removed_group_name, - ) - total_replicas = len(active_replicas) + len(inactive_replicas) - if total_replicas > 0: - logger.info( - "%s: terminating %d replica(s) from removed group '%s'", - fmt(run_model), - total_replicas, - 
removed_group_name, - ) - # Terminate all active replicas in the removed group - if active_replicas: - scale_down_replicas(session, active_replicas, len(active_replicas)) - # Terminate all inactive replicas in the removed group - if inactive_replicas: - scale_down_replicas(session, inactive_replicas, len(inactive_replicas)) + _terminate_removed_replica_groups( + session=session, run_model=run_model, replica_groups=replica_groups + ) return await _update_jobs_to_new_deployment_in_place( @@ -883,21 +855,15 @@ async def _handle_rolling_deployment_for_group( """ Handle rolling deployment for a single replica group. """ + if not has_out_of_date_replicas(run_model, group_filter=group.name): + return + desired_replica_counts = ( json.loads(run_model.desired_replica_counts) if run_model.desired_replica_counts else {} ) - group_desired = desired_replica_counts.get(group.name, group.count.min or 0) - - # Check if group has out-of-date replicas - if not has_out_of_date_replicas(run_model, group_filter=group.name): - return # Group is up-to-date - - # Calculate max replicas (allow surge during deployment) group_max_replica_count = group_desired + ROLLING_DEPLOYMENT_MAX_SURGE - # Count non-terminated replicas for this group only - non_terminated_replica_count = len( { j.replica_num @@ -970,3 +936,33 @@ async def _handle_rolling_deployment_for_group( active_replicas=active_replicas, inactive_replicas=inactive_replicas, ) + + +def _terminate_removed_replica_groups( + session: AsyncSession, run_model: RunModel, replica_groups: List[ReplicaGroup] +): + existing_group_names = set() + for job in run_model.jobs: + if job.status.is_finished(): + continue + job_spec = get_job_spec(job) + existing_group_names.add(job_spec.replica_group) + new_group_names = {group.name for group in replica_groups} + removed_group_names = existing_group_names - new_group_names + for removed_group_name in removed_group_names: + active_replicas, inactive_replicas = build_replica_lists( + run_model=run_model, 
+ group_filter=removed_group_name, + ) + total_replicas = len(active_replicas) + len(inactive_replicas) + if total_replicas > 0: + logger.info( + "%s: terminating %d replica(s) from removed group '%s'", + fmt(run_model), + total_replicas, + removed_group_name, + ) + if active_replicas: + scale_down_replicas(session, active_replicas, len(active_replicas)) + if inactive_replicas: + scale_down_replicas(session, inactive_replicas, len(inactive_replicas)) From e5f2c0d21bf49a9e55b4663c6666dcfe9cba14df Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 19 Mar 2026 10:36:16 +0500 Subject: [PATCH 11/14] Drop redundant replicas arg from _update_jobs_to_new_deployment_in_place --- .../_internal/server/background/scheduled_tasks/runs.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index 32567fd2e..dc7cb173a 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -691,7 +691,6 @@ async def _handle_run_replicas( session=session, run_model=run_model, run_spec=run_spec, - replicas=replica_groups, ) for group in replica_groups: @@ -720,7 +719,6 @@ async def _update_jobs_to_new_deployment_in_place( session: AsyncSession, run_model: RunModel, run_spec: RunSpec, - replicas: Optional[List] = None, ) -> None: """ Bump deployment_num for jobs that do not require redeployment. 
@@ -736,10 +734,8 @@ async def _update_jobs_to_new_deployment_in_place( if all(j.deployment_num == run_model.deployment_num for j in job_models): continue - # Determine which group this replica belongs to replica_group_name = None - - if replicas: + if run_spec.configuration.type == "service": job_spec = get_job_spec(job_models[0]) replica_group_name = job_spec.replica_group From fe16dc9feb6e109be9e4fdbe28b4817535f688c3 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 19 Mar 2026 10:43:25 +0500 Subject: [PATCH 12/14] Minor renaming --- .../server/background/scheduled_tasks/runs.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index dc7cb173a..aa71e9a11 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -317,7 +317,7 @@ class _ReplicaAnalysis: @dataclass -class _ActiveRunAnalysis: +class _RunAnalysis: """Aggregated replica analysis used to determine the run's next status. Each replica contributes `RunStatus` based on its jobs' statuses. 
@@ -379,8 +379,8 @@ async def _analyze_active_run( run: Run, run_jobs_by_position: Dict[Tuple[int, int], Job], retry_single_job: bool, -) -> _ActiveRunAnalysis: - analysis = _ActiveRunAnalysis() +) -> _RunAnalysis: + run_analysis = _RunAnalysis() for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs): replica_analysis = await _analyze_active_run_replica( session=session, @@ -390,8 +390,8 @@ async def _analyze_active_run( replica_num=replica_num, job_models=job_models, ) - _update_active_run_analysis(analysis, replica_analysis, retry_single_job) - return analysis + _apply_replica_analysis(run_analysis, replica_analysis, retry_single_job) + return run_analysis async def _analyze_active_run_replica( @@ -455,8 +455,8 @@ async def _analyze_active_run_replica( ) -def _update_active_run_analysis( - analysis: _ActiveRunAnalysis, +def _apply_replica_analysis( + analysis: _RunAnalysis, replica_analysis: _ReplicaAnalysis, retry_single_job: bool, ) -> None: @@ -519,7 +519,7 @@ def _job_needs_retry_evaluation(job_model: JobModel) -> bool: ) -def _get_active_run_transition(run: Run, analysis: _ActiveRunAnalysis) -> _ActiveRunTransition: +def _get_active_run_transition(run: Run, analysis: _RunAnalysis) -> _ActiveRunTransition: # Check `analysis.contributed_statuses` in the priority order. 
if RunStatus.FAILED in analysis.contributed_statuses: if RunTerminationReason.JOB_FAILED in analysis.termination_reasons: From 3704dd2b5aa230f0abbf4d6b68a8596a054f1809 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 19 Mar 2026 10:53:25 +0500 Subject: [PATCH 13/14] Extract get_group_rollout_state --- .../server/background/scheduled_tasks/runs.py | 80 ++++------------- .../server/services/runs/replicas.py | 90 ++++++++++++++----- 2 files changed, 85 insertions(+), 85 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index aa71e9a11..bbe5b4256 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -1,6 +1,5 @@ import asyncio import datetime -import json from dataclasses import dataclass, field from typing import Dict, List, Optional, Set, Tuple @@ -49,9 +48,9 @@ ) from dstack._internal.server.services.runs.replicas import ( build_replica_lists, + get_group_desired_replica_count, + get_group_rollout_state, has_out_of_date_replicas, - is_replica_registered, - job_belongs_to_group, retry_run_replica_jobs, scale_down_replicas, scale_run_replicas, @@ -851,86 +850,41 @@ async def _handle_rolling_deployment_for_group( """ Handle rolling deployment for a single replica group. 
""" - if not has_out_of_date_replicas(run_model, group_filter=group.name): + group_desired = get_group_desired_replica_count(run_model, group) + state = get_group_rollout_state(run_model, group) + if not state.has_out_of_date_replicas: return - desired_replica_counts = ( - json.loads(run_model.desired_replica_counts) if run_model.desired_replica_counts else {} - ) - group_desired = desired_replica_counts.get(group.name, group.count.min or 0) group_max_replica_count = group_desired + ROLLING_DEPLOYMENT_MAX_SURGE - non_terminated_replica_count = len( - { - j.replica_num - for j in run_model.jobs - if not j.status.is_finished() - and group.name is not None - and job_belongs_to_group(job=j, group_name=group.name) - } - ) - # Start new up-to-date replicas if needed - if non_terminated_replica_count < group_max_replica_count: - active_replicas, inactive_replicas = build_replica_lists( - run_model=run_model, - group_filter=group.name, - ) - + if state.non_terminated_replica_count < group_max_replica_count: await scale_run_replicas_for_group( session=session, run_model=run_model, group=group, - replicas_diff=group_max_replica_count - non_terminated_replica_count, + replicas_diff=group_max_replica_count - state.non_terminated_replica_count, run_spec=run_spec, - active_replicas=active_replicas, - inactive_replicas=inactive_replicas, + active_replicas=state.active_replicas, + inactive_replicas=state.inactive_replicas, ) + state = get_group_rollout_state(run_model, group) - # Stop out-of-date replicas that are not registered - replicas_to_stop_count = 0 - for _, jobs in group_jobs_by_replica_latest(run_model.jobs): - assert group.name is not None, "Group name is always set" - if not job_belongs_to_group(jobs[0], group.name): - continue - # Check if replica is out-of-date and not registered - if ( - any(j.deployment_num < run_model.deployment_num for j in jobs) - and any( - j.status not in [JobStatus.TERMINATING] + JobStatus.finished_statuses() - for j in jobs - ) - and not 
is_replica_registered(jobs) - ): - replicas_to_stop_count += 1 - - # Stop excessive registered out-of-date replicas - non_terminating_registered_replicas_count = 0 - for _, jobs in group_jobs_by_replica_latest(run_model.jobs): - assert group.name is not None, "Group name is always set" - if not job_belongs_to_group(jobs[0], group.name): - continue - - if is_replica_registered(jobs) and all(j.status != JobStatus.TERMINATING for j in jobs): - non_terminating_registered_replicas_count += 1 - - replicas_to_stop_count += max(0, non_terminating_registered_replicas_count - group_desired) + replicas_to_stop_count = state.unregistered_out_of_date_replica_count + replicas_to_stop_count += max( + 0, + state.registered_non_terminating_replica_count - group_desired, + ) if replicas_to_stop_count > 0: - # Build lists again to get current state - active_replicas, inactive_replicas = build_replica_lists( - run_model=run_model, - group_filter=group.name, - ) - await scale_run_replicas_for_group( session=session, run_model=run_model, group=group, replicas_diff=-replicas_to_stop_count, run_spec=run_spec, - active_replicas=active_replicas, - inactive_replicas=inactive_replicas, + active_replicas=state.active_replicas, + inactive_replicas=state.inactive_replicas, ) diff --git a/src/dstack/_internal/server/services/runs/replicas.py b/src/dstack/_internal/server/services/runs/replicas.py index dffc2d81b..0ef4daa24 100644 --- a/src/dstack/_internal/server/services/runs/replicas.py +++ b/src/dstack/_internal/server/services/runs/replicas.py @@ -1,4 +1,5 @@ import json +from dataclasses import dataclass from typing import List, Optional, Tuple from sqlalchemy.ext.asyncio import AsyncSession @@ -220,6 +221,65 @@ async def _scale_up_replicas( run_model.jobs.append(job_model) +@dataclass +class GroupRolloutState: + active_replicas: List[Tuple[int, bool, int, List[JobModel]]] + inactive_replicas: List[Tuple[int, bool, int, List[JobModel]]] + has_out_of_date_replicas: bool + 
non_terminated_replica_count: int + unregistered_out_of_date_replica_count: int + registered_non_terminating_replica_count: int + + +def get_group_desired_replica_count(run_model: RunModel, group: ReplicaGroup) -> int: + assert group.name is not None, "Group name is always set" + desired_replica_counts = ( + json.loads(run_model.desired_replica_counts) if run_model.desired_replica_counts else {} + ) + return desired_replica_counts.get(group.name, group.count.min or 0) + + +def get_group_rollout_state(run_model: RunModel, group: ReplicaGroup) -> GroupRolloutState: + assert group.name is not None, "Group name is always set" + active_replicas, inactive_replicas = build_replica_lists( + run_model=run_model, + group_filter=group.name, + ) + + non_terminated_replica_nums = set() + unregistered_out_of_date_replica_count = 0 + registered_non_terminating_replica_count = 0 + + for _, jobs in group_jobs_by_replica_latest(run_model.jobs): + if not job_belongs_to_group(jobs[0], group.name): + continue + + if any(not j.status.is_finished() for j in jobs): + non_terminated_replica_nums.add(jobs[0].replica_num) + + if ( + any(j.deployment_num < run_model.deployment_num for j in jobs) + and any( + j.status not in [JobStatus.TERMINATING] + JobStatus.finished_statuses() + for j in jobs + ) + and not is_replica_registered(jobs) + ): + unregistered_out_of_date_replica_count += 1 + + if is_replica_registered(jobs) and all(j.status != JobStatus.TERMINATING for j in jobs): + registered_non_terminating_replica_count += 1 + + return GroupRolloutState( + active_replicas=active_replicas, + inactive_replicas=inactive_replicas, + has_out_of_date_replicas=has_out_of_date_replicas(run_model, group_filter=group.name), + non_terminated_replica_count=len(non_terminated_replica_nums), + unregistered_out_of_date_replica_count=unregistered_out_of_date_replica_count, + registered_non_terminating_replica_count=registered_non_terminating_replica_count, + ) + + async def scale_run_replicas_for_all_groups( 
session: AsyncSession, run_model: RunModel, @@ -229,31 +289,17 @@ async def scale_run_replicas_for_all_groups( if not replicas: return - desired_replica_counts = ( - json.loads(run_model.desired_replica_counts) if run_model.desired_replica_counts else {} - ) + run_spec = get_run_spec(run_model) for group in replicas: - assert group.name is not None, "Group name is always set" - group_desired = desired_replica_counts.get(group.name, group.count.min or 0) - - # Build replica lists filtered by this group - active_replicas, inactive_replicas = build_replica_lists( - run_model=run_model, group_filter=group.name - ) - - # Count active replicas - active_group_count = len(active_replicas) - group_diff = group_desired - active_group_count + group_desired = get_group_desired_replica_count(run_model, group) + state = get_group_rollout_state(run_model, group) + group_diff = group_desired - len(state.active_replicas) if group_diff != 0: - # Check if rolling deployment is in progress for THIS GROUP - - group_has_out_of_date = has_out_of_date_replicas(run_model, group_filter=group.name) - # During rolling deployment, don't scale down old replicas # Let rolling deployment handle stopping old replicas - if group_diff < 0 and group_has_out_of_date: + if group_diff < 0 and state.has_out_of_date_replicas: # Skip scaling down during rolling deployment continue await scale_run_replicas_for_group( @@ -261,9 +307,9 @@ async def scale_run_replicas_for_all_groups( run_model=run_model, group=group, replicas_diff=group_diff, - run_spec=get_run_spec(run_model), - active_replicas=active_replicas, - inactive_replicas=inactive_replicas, + run_spec=run_spec, + active_replicas=state.active_replicas, + inactive_replicas=state.inactive_replicas, ) From b54fd317b53c6ea82109403c592ab18e799f332d Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 19 Mar 2026 10:59:31 +0500 Subject: [PATCH 14/14] Move run job lookup into analysis --- .../server/background/scheduled_tasks/runs.py | 10 +++------- 1 
file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index bbe5b4256..1b3313725 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -348,12 +348,10 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel): run_spec = run.run_spec retry_single_job = _can_retry_single_job(run_spec) _maybe_set_run_fleet_id_from_jobs(run_model) - run_jobs_by_position = _get_run_jobs_by_position(run) analysis = await _analyze_active_run( session=session, run_model=run_model, run=run, - run_jobs_by_position=run_jobs_by_position, retry_single_job=retry_single_job, ) transition = _get_active_run_transition(run, analysis) @@ -368,18 +366,16 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel): ) -def _get_run_jobs_by_position(run: Run) -> Dict[Tuple[int, int], Job]: - return {(job.job_spec.replica_num, job.job_spec.job_num): job for job in run.jobs} - - async def _analyze_active_run( session: AsyncSession, run_model: RunModel, run: Run, - run_jobs_by_position: Dict[Tuple[int, int], Job], retry_single_job: bool, ) -> _RunAnalysis: run_analysis = _RunAnalysis() + run_jobs_by_position = { + (job.job_spec.replica_num, job.job_spec.job_num): job for job in run.jobs + } for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs): replica_analysis = await _analyze_active_run_replica( session=session,