diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index 5bb373ae0..1b3313725 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -1,7 +1,7 @@ import asyncio import datetime -import json -from typing import List, Optional, Set, Tuple +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Set, Tuple from sqlalchemy import and_, func, or_, select from sqlalchemy.ext.asyncio import AsyncSession @@ -31,7 +31,6 @@ ) from dstack._internal.server.services import events from dstack._internal.server.services.jobs import ( - find_job, get_job_spec, get_job_specs_from_run_spec, group_jobs_by_replica_latest, @@ -49,13 +48,14 @@ ) from dstack._internal.server.services.runs.replicas import ( build_replica_lists, + get_group_desired_replica_count, + get_group_rollout_state, has_out_of_date_replicas, - is_replica_registered, - job_belongs_to_group, retry_run_replica_jobs, scale_down_replicas, scale_run_replicas, - scale_run_replicas_per_group, + scale_run_replicas_for_all_groups, + scale_run_replicas_for_group, ) from dstack._internal.server.services.secrets import get_project_secrets_mapping from dstack._internal.server.services.services import update_service_desired_replica_count @@ -257,8 +257,9 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel): replicas: List[ReplicaGroup] = run.run_spec.configuration.replica_groups - await scale_run_replicas_per_group(session, run_model, replicas) + await scale_run_replicas_for_all_groups(session, run_model, replicas) else: + # Non-service pending runs may have 0 job submissions and require new submission, e.g. scheduled tasks. 
run_model.desired_replica_count = 1 await scale_run_replicas(session, run_model, replicas_diff=run_model.desired_replica_count) @@ -296,6 +297,48 @@ def _get_retry_delay(resubmission_attempt: int) -> datetime.timedelta: return _PENDING_RETRY_DELAYS[-1] +@dataclass +class _ReplicaAnalysis: + """Per-replica classification of job states for determining the run's next status.""" + + replica_num: int + job_models: List[JobModel] + replica_info: autoscalers.ReplicaInfo + contributed_statuses: Set[RunStatus] = field(default_factory=set) + """`RunStatus` values derived from this replica's jobs. Merged into the run-level + analysis unless the replica is being retried as a whole.""" + termination_reasons: Set[RunTerminationReason] = field(default_factory=set) + """Why the replica failed. Only populated when `FAILED` is in `contributed_statuses`.""" + needs_retry: bool = False + """At least one job failed with a retryable reason and the retry duration hasn't been + exceeded. When `True`, the replica does not contribute its statuses to the run-level + analysis (unless `retry_single_job` is enabled) and is added to `replicas_to_retry` instead.""" + + +@dataclass +class _RunAnalysis: + """Aggregated replica analysis used to determine the run's next status. + + Each replica contributes `RunStatus` based on its jobs' statuses. + The run's new status is the highest-priority value across all + contributing replicas: FAILED > RUNNING > PROVISIONING > SUBMITTED > DONE. + Replicas that need full retry do not contribute and instead cause a PENDING transition. 
+ """ + + contributed_statuses: Set[RunStatus] = field(default_factory=set) + termination_reasons: Set[RunTerminationReason] = field(default_factory=set) + replicas_to_retry: List[Tuple[int, List[JobModel]]] = field(default_factory=list) + """Replicas with retryable failures that haven't exceeded the retry duration.""" + replicas_info: List[autoscalers.ReplicaInfo] = field(default_factory=list) + """Per-replica active/inactive info for the autoscaler.""" + + +@dataclass +class _ActiveRunTransition: + new_status: RunStatus + termination_reason: Optional[RunTerminationReason] = None + + async def _process_active_run(session: AsyncSession, run_model: RunModel): """ Run is submitted, provisioning, or running. @@ -304,149 +347,287 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel): run = run_model_to_run(run_model) run_spec = run.run_spec retry_single_job = _can_retry_single_job(run_spec) + _maybe_set_run_fleet_id_from_jobs(run_model) + analysis = await _analyze_active_run( + session=session, + run_model=run_model, + run=run, + retry_single_job=retry_single_job, + ) + transition = _get_active_run_transition(run, analysis) + await _apply_active_run_transition( + session=session, + run_model=run_model, + run_spec=run_spec, + transition=transition, + replicas_to_retry=analysis.replicas_to_retry, + retry_single_job=retry_single_job, + replicas_info=analysis.replicas_info, + ) - run_statuses: Set[RunStatus] = set() - run_termination_reasons: Set[RunTerminationReason] = set() - replicas_to_retry: List[Tuple[int, List[JobModel]]] = [] - replicas_info: List[autoscalers.ReplicaInfo] = [] +async def _analyze_active_run( + session: AsyncSession, + run_model: RunModel, + run: Run, + retry_single_job: bool, +) -> _RunAnalysis: + run_analysis = _RunAnalysis() + run_jobs_by_position = { + (job.job_spec.replica_num, job.job_spec.job_num): job for job in run.jobs + } for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs): - 
replica_statuses: Set[RunStatus] = set() - replica_needs_retry = False - replica_active = True - jobs_done_num = 0 - for job_model in job_models: - job = find_job(run.jobs, job_model.replica_num, job_model.job_num) - if ( - run_model.fleet_id is None - and job_model.instance is not None - and job_model.instance.fleet_id is not None - ): - run_model.fleet_id = job_model.instance.fleet_id - if job_model.status == JobStatus.DONE or ( - job_model.status == JobStatus.TERMINATING - and job_model.termination_reason == JobTerminationReason.DONE_BY_RUNNER - ): - # the job is done or going to be done - replica_statuses.add(RunStatus.DONE) - jobs_done_num += 1 - elif job_model.termination_reason == JobTerminationReason.SCALED_DOWN: - # the job was scaled down - replica_active = False - elif job_model.status == JobStatus.RUNNING: - # the job is running - replica_statuses.add(RunStatus.RUNNING) - elif job_model.status in {JobStatus.PROVISIONING, JobStatus.PULLING}: - # the job is provisioning - replica_statuses.add(RunStatus.PROVISIONING) - elif job_model.status == JobStatus.SUBMITTED: - # the job is submitted - replica_statuses.add(RunStatus.SUBMITTED) - elif job_model.status == JobStatus.FAILED or ( - job_model.status - in [JobStatus.TERMINATING, JobStatus.TERMINATED, JobStatus.ABORTED] - and job_model.termination_reason - not in {JobTerminationReason.DONE_BY_RUNNER, JobTerminationReason.SCALED_DOWN} - ): - current_duration = await _should_retry_job(session, run, job, job_model) - if current_duration is None: - replica_statuses.add(RunStatus.FAILED) - run_termination_reasons.add(RunTerminationReason.JOB_FAILED) - else: - if _is_retry_duration_exceeded(job, current_duration): - replica_statuses.add(RunStatus.FAILED) - run_termination_reasons.add(RunTerminationReason.RETRY_LIMIT_EXCEEDED) - else: - replica_needs_retry = True - else: - raise ValueError(f"Unexpected job status {job_model.status}") + replica_analysis = await _analyze_active_run_replica( + session=session, + 
run_model=run_model, + run=run, + run_jobs_by_position=run_jobs_by_position, + replica_num=replica_num, + job_models=job_models, + ) + _apply_replica_analysis(run_analysis, replica_analysis, retry_single_job) + return run_analysis - if RunStatus.FAILED in replica_statuses: - run_statuses.add(RunStatus.FAILED) - else: - if replica_needs_retry: - replicas_to_retry.append((replica_num, job_models)) - if not replica_needs_retry or retry_single_job: - run_statuses.update(replica_statuses) - - if jobs_done_num == len(job_models): - # Consider replica inactive if all its jobs are done for some reason. - # If only some jobs are done, replica is considered active to avoid - # provisioning new replicas for partially done multi-node tasks. + +async def _analyze_active_run_replica( + session: AsyncSession, + run_model: RunModel, + run: Run, + run_jobs_by_position: Dict[Tuple[int, int], Job], + replica_num: int, + job_models: List[JobModel], +) -> _ReplicaAnalysis: + contributed_statuses: Set[RunStatus] = set() + termination_reasons: Set[RunTerminationReason] = set() + needs_retry = False + replica_active = True + jobs_done_num = 0 + + for job_model in job_models: + job = run_jobs_by_position[(job_model.replica_num, job_model.job_num)] + + if _job_is_done_or_finishing_done(job_model): + contributed_statuses.add(RunStatus.DONE) + jobs_done_num += 1 + continue + + if _job_was_scaled_down(job_model): replica_active = False + continue - replica_info = _get_replica_info(job_models, replica_active) - replicas_info.append(replica_info) + replica_status = _get_non_terminal_replica_status(job_model) + if replica_status is not None: + contributed_statuses.add(replica_status) + continue - termination_reason: Optional[RunTerminationReason] = None - if RunStatus.FAILED in run_statuses: - new_status = RunStatus.TERMINATING - if RunTerminationReason.JOB_FAILED in run_termination_reasons: + if _job_needs_retry_evaluation(job_model): + current_duration = await _should_retry_job(session, run, 
job, job_model) + if current_duration is None: + contributed_statuses.add(RunStatus.FAILED) + termination_reasons.add(RunTerminationReason.JOB_FAILED) + elif _is_retry_duration_exceeded(job, current_duration): + contributed_statuses.add(RunStatus.FAILED) + termination_reasons.add(RunTerminationReason.RETRY_LIMIT_EXCEEDED) + else: + needs_retry = True + continue + + raise ServerError(f"Unexpected job status {job_model.status}") + + if jobs_done_num == len(job_models): + # Consider replica inactive if all its jobs are done for some reason. + # If only some jobs are done, replica is considered active to avoid + # provisioning new replicas for partially done multi-node tasks. + replica_active = False + + return _ReplicaAnalysis( + replica_num=replica_num, + job_models=job_models, + replica_info=_get_replica_info(job_models, replica_active), + contributed_statuses=contributed_statuses, + termination_reasons=termination_reasons, + needs_retry=needs_retry, + ) + + +def _apply_replica_analysis( + analysis: _RunAnalysis, + replica_analysis: _ReplicaAnalysis, + retry_single_job: bool, +) -> None: + analysis.replicas_info.append(replica_analysis.replica_info) + + if RunStatus.FAILED in replica_analysis.contributed_statuses: + analysis.contributed_statuses.add(RunStatus.FAILED) + analysis.termination_reasons.update(replica_analysis.termination_reasons) + return + + if replica_analysis.needs_retry: + analysis.replicas_to_retry.append( + (replica_analysis.replica_num, replica_analysis.job_models) + ) + + if not replica_analysis.needs_retry or retry_single_job: + analysis.contributed_statuses.update(replica_analysis.contributed_statuses) + + +def _maybe_set_run_fleet_id_from_jobs(run_model: RunModel) -> None: + """ + The master job gets the fleet assigned with the instance. + The run then gets the fleet from the master job's instance, and non-master jobs wait for the run's fleet to be assigned. 
+ """ + if run_model.fleet_id is not None: + return + + for job_model in run_model.jobs: + if job_model.instance is not None and job_model.instance.fleet_id is not None: + run_model.fleet_id = job_model.instance.fleet_id + return + + +def _job_is_done_or_finishing_done(job_model: JobModel) -> bool: + return job_model.status == JobStatus.DONE or ( + job_model.status == JobStatus.TERMINATING + and job_model.termination_reason == JobTerminationReason.DONE_BY_RUNNER + ) + + +def _job_was_scaled_down(job_model: JobModel) -> bool: + return job_model.termination_reason == JobTerminationReason.SCALED_DOWN + + +def _get_non_terminal_replica_status(job_model: JobModel) -> Optional[RunStatus]: + if job_model.status == JobStatus.RUNNING: + return RunStatus.RUNNING + if job_model.status in {JobStatus.PROVISIONING, JobStatus.PULLING}: + return RunStatus.PROVISIONING + if job_model.status == JobStatus.SUBMITTED: + return RunStatus.SUBMITTED + return None + + +def _job_needs_retry_evaluation(job_model: JobModel) -> bool: + return job_model.status == JobStatus.FAILED or ( + job_model.status in [JobStatus.TERMINATING, JobStatus.TERMINATED, JobStatus.ABORTED] + and job_model.termination_reason + not in {JobTerminationReason.DONE_BY_RUNNER, JobTerminationReason.SCALED_DOWN} + ) + + +def _get_active_run_transition(run: Run, analysis: _RunAnalysis) -> _ActiveRunTransition: + # Check `analysis.contributed_statuses` in the priority order. 
+ if RunStatus.FAILED in analysis.contributed_statuses: + if RunTerminationReason.JOB_FAILED in analysis.termination_reasons: termination_reason = RunTerminationReason.JOB_FAILED - elif RunTerminationReason.RETRY_LIMIT_EXCEEDED in run_termination_reasons: + elif RunTerminationReason.RETRY_LIMIT_EXCEEDED in analysis.termination_reasons: termination_reason = RunTerminationReason.RETRY_LIMIT_EXCEEDED else: - raise ValueError(f"Unexpected termination reason {run_termination_reasons}") - elif _should_stop_on_master_done(run): - new_status = RunStatus.TERMINATING + raise ServerError(f"Unexpected termination reason {analysis.termination_reasons}") + return _ActiveRunTransition( + new_status=RunStatus.TERMINATING, + termination_reason=termination_reason, + ) + + if _should_stop_on_master_done(run): # ALL_JOBS_DONE is used for all DONE reasons including master-done - termination_reason = RunTerminationReason.ALL_JOBS_DONE - elif RunStatus.RUNNING in run_statuses: - new_status = RunStatus.RUNNING - elif RunStatus.PROVISIONING in run_statuses: - new_status = RunStatus.PROVISIONING - elif RunStatus.SUBMITTED in run_statuses: - new_status = RunStatus.SUBMITTED - elif RunStatus.DONE in run_statuses and not replicas_to_retry: - new_status = RunStatus.TERMINATING - termination_reason = RunTerminationReason.ALL_JOBS_DONE - else: - new_status = RunStatus.PENDING + return _ActiveRunTransition( + new_status=RunStatus.TERMINATING, + termination_reason=RunTerminationReason.ALL_JOBS_DONE, + ) - # Terminate active jobs if the run is to be resubmitted - if new_status == RunStatus.PENDING and not retry_single_job: - for _, replica_jobs in replicas_to_retry: - for job_model in replica_jobs: - if not ( - job_model.status.is_finished() or job_model.status == JobStatus.TERMINATING - ): - job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER - job_model.termination_reason_message = "Run is to be resubmitted" - switch_job_status(session, job_model, JobStatus.TERMINATING) - - if 
new_status not in {RunStatus.TERMINATING, RunStatus.PENDING}: + if RunStatus.RUNNING in analysis.contributed_statuses: + return _ActiveRunTransition(new_status=RunStatus.RUNNING) + if RunStatus.PROVISIONING in analysis.contributed_statuses: + return _ActiveRunTransition(new_status=RunStatus.PROVISIONING) + if RunStatus.SUBMITTED in analysis.contributed_statuses: + return _ActiveRunTransition(new_status=RunStatus.SUBMITTED) + if RunStatus.DONE in analysis.contributed_statuses and not analysis.replicas_to_retry: + return _ActiveRunTransition( + new_status=RunStatus.TERMINATING, + termination_reason=RunTerminationReason.ALL_JOBS_DONE, + ) + if not analysis.contributed_statuses or analysis.contributed_statuses == {RunStatus.DONE}: + # No active replicas remain — resubmit the entire run. + # `contributed_statuses` is either empty (every replica is retrying) or contains + # only DONE (some replicas finished, others need retry). + return _ActiveRunTransition(new_status=RunStatus.PENDING) + raise ServerError("Failed to determine run transition: unexpected active run state") + + +async def _apply_active_run_transition( + session: AsyncSession, + run_model: RunModel, + run_spec: RunSpec, + transition: _ActiveRunTransition, + replicas_to_retry: List[Tuple[int, List[JobModel]]], + retry_single_job: bool, + replicas_info: List[autoscalers.ReplicaInfo], +) -> None: + if transition.new_status == RunStatus.PENDING and not retry_single_job: + _terminate_retrying_replica_jobs(session, replicas_to_retry) + + if transition.new_status not in {RunStatus.TERMINATING, RunStatus.PENDING}: # No need to retry, scale, or redeploy replicas if the run is terminating, # pending run will retry replicas in `process_pending_run` await _handle_run_replicas( - session, run_model, run_spec, replicas_to_retry, retry_single_job, replicas_info + session, + run_model, + run_spec, + replicas_to_retry, + retry_single_job, + replicas_info, ) - if run_model.status != new_status: - if run_model.status == 
RunStatus.SUBMITTED and new_status == RunStatus.PROVISIONING: - current_time = common.get_current_datetime() - submit_to_provision_duration = (current_time - run_model.submitted_at).total_seconds() - logger.info( - "%s: run took %.2f seconds from submission to provisioning.", - fmt(run_model), - submit_to_provision_duration, - ) - project_name = run_model.project.name - run_metrics.log_submit_to_provision_duration( - submit_to_provision_duration, project_name, run_spec.configuration.type - ) + _maybe_switch_active_run_status(session, run_model, run_spec, transition) + + +def _terminate_retrying_replica_jobs( + session: AsyncSession, + replicas_to_retry: List[Tuple[int, List[JobModel]]], +) -> None: + for _, replica_jobs in replicas_to_retry: + for job_model in replica_jobs: + if job_model.status.is_finished() or job_model.status == JobStatus.TERMINATING: + continue + job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER + job_model.termination_reason_message = "Run is to be resubmitted" + switch_job_status(session, job_model, JobStatus.TERMINATING) - if new_status == RunStatus.PENDING: - run_metrics.increment_pending_runs(run_model.project.name, run_spec.configuration.type) - # Unassign run from fleet so that the new fleet can be chosen when retrying - run_model.fleet = None - run_model.termination_reason = termination_reason - switch_run_status(session, run_model, new_status) - # While a run goes to pending without provisioning, resubmission_attempt increases. 
- if new_status == RunStatus.PROVISIONING: - run_model.resubmission_attempt = 0 - elif new_status == RunStatus.PENDING: - run_model.resubmission_attempt += 1 +def _maybe_switch_active_run_status( + session: AsyncSession, + run_model: RunModel, + run_spec: RunSpec, + transition: _ActiveRunTransition, +) -> None: + if run_model.status == transition.new_status: + return + + if run_model.status == RunStatus.SUBMITTED and transition.new_status == RunStatus.PROVISIONING: + current_time = common.get_current_datetime() + submit_to_provision_duration = (current_time - run_model.submitted_at).total_seconds() + logger.info( + "%s: run took %.2f seconds from submission to provisioning.", + fmt(run_model), + submit_to_provision_duration, + ) + project_name = run_model.project.name + run_metrics.log_submit_to_provision_duration( + submit_to_provision_duration, project_name, run_spec.configuration.type + ) + + if transition.new_status == RunStatus.PENDING: + run_metrics.increment_pending_runs(run_model.project.name, run_spec.configuration.type) + # Unassign run from fleet so that the new fleet can be chosen when retrying + run_model.fleet = None + + run_model.termination_reason = transition.termination_reason + switch_run_status(session=session, run_model=run_model, new_status=transition.new_status) + # While a run goes to pending without provisioning, resubmission_attempt increases. + if transition.new_status == RunStatus.PROVISIONING: + run_model.resubmission_attempt = 0 + elif transition.new_status == RunStatus.PENDING: + run_model.resubmission_attempt += 1 def _get_replica_info( @@ -475,12 +656,10 @@ async def _handle_run_replicas( replicas_info: list[autoscalers.ReplicaInfo], ) -> None: """ - Does ONE of: - - replica retry - - replica scaling - - replica rolling deployment - - Does not do everything at once to avoid conflicts between the stages and long DB transactions. 
+ Performs one or more steps: + - replicas retry + - replicas scaling + - replicas rolling deployment """ if replicas_to_retry: @@ -498,69 +677,24 @@ async def _handle_run_replicas( # FIXME: should only include scaling events, not retries and deployments last_scaled_at=max((r.timestamp for r in replicas_info), default=None), ) - replicas: List[ReplicaGroup] = run_spec.configuration.replica_groups - assert replicas, "replica groups should always return at least one group" + replica_groups: List[ReplicaGroup] = run_spec.configuration.replica_groups + assert replica_groups, "replica groups should always return at least one group" - await scale_run_replicas_per_group(session, run_model, replicas) + await scale_run_replicas_for_all_groups(session, run_model, replica_groups) - # Handle per-group rolling deployment await _update_jobs_to_new_deployment_in_place( session=session, run_model=run_model, run_spec=run_spec, - replicas=replicas, ) - # Process per-group rolling deployment - for group in replicas: + + for group in replica_groups: await _handle_rolling_deployment_for_group( session=session, run_model=run_model, group=group, run_spec=run_spec ) - # Terminate replicas from groups that were removed from the configuration - existing_group_names = set() - for job in run_model.jobs: - if job.status.is_finished(): - continue - try: - job_spec = get_job_spec(job) - existing_group_names.add(job_spec.replica_group) - except Exception: - continue - new_group_names = {group.name for group in replicas} - removed_group_names = existing_group_names - new_group_names - for removed_group_name in removed_group_names: - # Build replica lists for this removed group - active_replicas, inactive_replicas = build_replica_lists( - run_model=run_model, - group_filter=removed_group_name, - ) - total_replicas = len(active_replicas) + len(inactive_replicas) - if total_replicas > 0: - logger.info( - "%s: terminating %d replica(s) from removed group '%s'", - fmt(run_model), - total_replicas, - 
removed_group_name, - ) - # Terminate all active replicas in the removed group - if active_replicas: - scale_down_replicas(session, active_replicas, len(active_replicas)) - # Terminate all inactive replicas in the removed group - if inactive_replicas: - scale_down_replicas(session, inactive_replicas, len(inactive_replicas)) - return - - max_replica_count = run_model.desired_replica_count - if has_out_of_date_replicas(run_model): - # allow extra replicas when deployment is in progress - max_replica_count += ROLLING_DEPLOYMENT_MAX_SURGE - - active_replica_count = sum(1 for r in replicas_info if r.active) - if active_replica_count not in range(run_model.desired_replica_count, max_replica_count + 1): - await scale_run_replicas( - session, - run_model, - replicas_diff=run_model.desired_replica_count - active_replica_count, + _terminate_removed_replica_groups( + session=session, run_model=run_model, replica_groups=replica_groups ) return @@ -570,54 +704,16 @@ async def _handle_run_replicas( run_spec=run_spec, ) if has_out_of_date_replicas(run_model): - assert run_spec.configuration.type == "service", ( - "Rolling deployment is only supported for services" - ) - non_terminated_replica_count = len( - {j.replica_num for j in run_model.jobs if not j.status.is_finished()} - ) - # Avoid using too much hardware during a deployment - never have - # more than max_replica_count non-terminated replicas. - if non_terminated_replica_count < max_replica_count: - # Start more up-to-date replicas that will eventually replace out-of-date replicas. 
- await scale_run_replicas( - session, - run_model, - replicas_diff=max_replica_count - non_terminated_replica_count, - ) - - replicas_to_stop_count = 0 - # stop any out-of-date replicas that are not registered - replicas_to_stop_count += sum( - any(j.deployment_num < run_model.deployment_num for j in jobs) - and any( - j.status not in [JobStatus.TERMINATING] + JobStatus.finished_statuses() - for j in jobs - ) - and not is_replica_registered(jobs) - for _, jobs in group_jobs_by_replica_latest(run_model.jobs) - ) - # stop excessive registered out-of-date replicas, except those that are already `terminating` - non_terminating_registered_replicas_count = sum( - is_replica_registered(jobs) and all(j.status != JobStatus.TERMINATING for j in jobs) - for _, jobs in group_jobs_by_replica_latest(run_model.jobs) - ) - replicas_to_stop_count += max( - 0, non_terminating_registered_replicas_count - run_model.desired_replica_count - ) - if replicas_to_stop_count: - await scale_run_replicas( - session, - run_model, - replicas_diff=-replicas_to_stop_count, - ) + # Currently, only services can change job spec on update, + # so for other runs out-of-date replicas are not possible. + # Keeping assert in case this changes. + assert False, "Rolling deployment is only supported for services" async def _update_jobs_to_new_deployment_in_place( session: AsyncSession, run_model: RunModel, run_spec: RunSpec, - replicas: Optional[List] = None, ) -> None: """ Bump deployment_num for jobs that do not require redeployment. 
@@ -633,10 +729,8 @@ async def _update_jobs_to_new_deployment_in_place( if all(j.deployment_num == run_model.deployment_num for j in job_models): continue - # Determine which group this replica belongs to replica_group_name = None - - if replicas: + if run_spec.configuration.type == "service": job_spec = get_job_spec(job_models[0]) replica_group_name = job_spec.replica_group @@ -752,95 +846,69 @@ async def _handle_rolling_deployment_for_group( """ Handle rolling deployment for a single replica group. """ - from dstack._internal.server.services.runs.replicas import ( - build_replica_lists, - scale_run_replicas_for_group, - ) - - desired_replica_counts = ( - json.loads(run_model.desired_replica_counts) if run_model.desired_replica_counts else {} - ) - - group_desired = desired_replica_counts.get(group.name, group.count.min or 0) - - # Check if group has out-of-date replicas - if not has_out_of_date_replicas(run_model, group_filter=group.name): - return # Group is up-to-date + group_desired = get_group_desired_replica_count(run_model, group) + state = get_group_rollout_state(run_model, group) + if not state.has_out_of_date_replicas: + return - # Calculate max replicas (allow surge during deployment) group_max_replica_count = group_desired + ROLLING_DEPLOYMENT_MAX_SURGE - # Count non-terminated replicas for this group only - - non_terminated_replica_count = len( - { - j.replica_num - for j in run_model.jobs - if not j.status.is_finished() - and group.name is not None - and job_belongs_to_group(job=j, group_name=group.name) - } - ) - # Start new up-to-date replicas if needed - if non_terminated_replica_count < group_max_replica_count: - active_replicas, inactive_replicas = build_replica_lists( - run_model=run_model, - group_filter=group.name, - ) - + if state.non_terminated_replica_count < group_max_replica_count: await scale_run_replicas_for_group( session=session, run_model=run_model, group=group, - replicas_diff=group_max_replica_count - non_terminated_replica_count, 
+ replicas_diff=group_max_replica_count - state.non_terminated_replica_count, run_spec=run_spec, - active_replicas=active_replicas, - inactive_replicas=inactive_replicas, + active_replicas=state.active_replicas, + inactive_replicas=state.inactive_replicas, ) + state = get_group_rollout_state(run_model, group) - # Stop out-of-date replicas that are not registered - replicas_to_stop_count = 0 - for _, jobs in group_jobs_by_replica_latest(run_model.jobs): - assert group.name is not None, "Group name is always set" - if not job_belongs_to_group(jobs[0], group.name): - continue - # Check if replica is out-of-date and not registered - if ( - any(j.deployment_num < run_model.deployment_num for j in jobs) - and any( - j.status not in [JobStatus.TERMINATING] + JobStatus.finished_statuses() - for j in jobs - ) - and not is_replica_registered(jobs) - ): - replicas_to_stop_count += 1 - - # Stop excessive registered out-of-date replicas - non_terminating_registered_replicas_count = 0 - for _, jobs in group_jobs_by_replica_latest(run_model.jobs): - assert group.name is not None, "Group name is always set" - if not job_belongs_to_group(jobs[0], group.name): - continue - - if is_replica_registered(jobs) and all(j.status != JobStatus.TERMINATING for j in jobs): - non_terminating_registered_replicas_count += 1 - - replicas_to_stop_count += max(0, non_terminating_registered_replicas_count - group_desired) + replicas_to_stop_count = state.unregistered_out_of_date_replica_count + replicas_to_stop_count += max( + 0, + state.registered_non_terminating_replica_count - group_desired, + ) if replicas_to_stop_count > 0: - # Build lists again to get current state - active_replicas, inactive_replicas = build_replica_lists( - run_model=run_model, - group_filter=group.name, - ) - await scale_run_replicas_for_group( session=session, run_model=run_model, group=group, replicas_diff=-replicas_to_stop_count, run_spec=run_spec, - active_replicas=active_replicas, - inactive_replicas=inactive_replicas, 
+ active_replicas=state.active_replicas, + inactive_replicas=state.inactive_replicas, ) + + +def _terminate_removed_replica_groups( + session: AsyncSession, run_model: RunModel, replica_groups: List[ReplicaGroup] +): + existing_group_names = set() + for job in run_model.jobs: + if job.status.is_finished(): + continue + job_spec = get_job_spec(job) + existing_group_names.add(job_spec.replica_group) + new_group_names = {group.name for group in replica_groups} + removed_group_names = existing_group_names - new_group_names + for removed_group_name in removed_group_names: + active_replicas, inactive_replicas = build_replica_lists( + run_model=run_model, + group_filter=removed_group_name, + ) + total_replicas = len(active_replicas) + len(inactive_replicas) + if total_replicas > 0: + logger.info( + "%s: terminating %d replica(s) from removed group '%s'", + fmt(run_model), + total_replicas, + removed_group_name, + ) + if active_replicas: + scale_down_replicas(session, active_replicas, len(active_replicas)) + if inactive_replicas: + scale_down_replicas(session, inactive_replicas, len(inactive_replicas)) diff --git a/src/dstack/_internal/server/services/runs/replicas.py b/src/dstack/_internal/server/services/runs/replicas.py index ffb7bd216..0ef4daa24 100644 --- a/src/dstack/_internal/server/services/runs/replicas.py +++ b/src/dstack/_internal/server/services/runs/replicas.py @@ -1,4 +1,5 @@ import json +from dataclasses import dataclass from typing import List, Optional, Tuple from sqlalchemy.ext.asyncio import AsyncSession @@ -220,7 +221,66 @@ async def _scale_up_replicas( run_model.jobs.append(job_model) -async def scale_run_replicas_per_group( +@dataclass +class GroupRolloutState: + active_replicas: List[Tuple[int, bool, int, List[JobModel]]] + inactive_replicas: List[Tuple[int, bool, int, List[JobModel]]] + has_out_of_date_replicas: bool + non_terminated_replica_count: int + unregistered_out_of_date_replica_count: int + registered_non_terminating_replica_count: int + + 
+def get_group_desired_replica_count(run_model: RunModel, group: ReplicaGroup) -> int: + assert group.name is not None, "Group name is always set" + desired_replica_counts = ( + json.loads(run_model.desired_replica_counts) if run_model.desired_replica_counts else {} + ) + return desired_replica_counts.get(group.name, group.count.min or 0) + + +def get_group_rollout_state(run_model: RunModel, group: ReplicaGroup) -> GroupRolloutState: + assert group.name is not None, "Group name is always set" + active_replicas, inactive_replicas = build_replica_lists( + run_model=run_model, + group_filter=group.name, + ) + + non_terminated_replica_nums = set() + unregistered_out_of_date_replica_count = 0 + registered_non_terminating_replica_count = 0 + + for _, jobs in group_jobs_by_replica_latest(run_model.jobs): + if not job_belongs_to_group(jobs[0], group.name): + continue + + if any(not j.status.is_finished() for j in jobs): + non_terminated_replica_nums.add(jobs[0].replica_num) + + if ( + any(j.deployment_num < run_model.deployment_num for j in jobs) + and any( + j.status not in [JobStatus.TERMINATING] + JobStatus.finished_statuses() + for j in jobs + ) + and not is_replica_registered(jobs) + ): + unregistered_out_of_date_replica_count += 1 + + if is_replica_registered(jobs) and all(j.status != JobStatus.TERMINATING for j in jobs): + registered_non_terminating_replica_count += 1 + + return GroupRolloutState( + active_replicas=active_replicas, + inactive_replicas=inactive_replicas, + has_out_of_date_replicas=has_out_of_date_replicas(run_model, group_filter=group.name), + non_terminated_replica_count=len(non_terminated_replica_nums), + unregistered_out_of_date_replica_count=unregistered_out_of_date_replica_count, + registered_non_terminating_replica_count=registered_non_terminating_replica_count, + ) + + +async def scale_run_replicas_for_all_groups( session: AsyncSession, run_model: RunModel, replicas: List[ReplicaGroup], @@ -229,31 +289,17 @@ async def 
scale_run_replicas_per_group( if not replicas: return - desired_replica_counts = ( - json.loads(run_model.desired_replica_counts) if run_model.desired_replica_counts else {} - ) + run_spec = get_run_spec(run_model) for group in replicas: - assert group.name is not None, "Group name is always set" - group_desired = desired_replica_counts.get(group.name, group.count.min or 0) - - # Build replica lists filtered by this group - active_replicas, inactive_replicas = build_replica_lists( - run_model=run_model, group_filter=group.name - ) - - # Count active replicas - active_group_count = len(active_replicas) - group_diff = group_desired - active_group_count + group_desired = get_group_desired_replica_count(run_model, group) + state = get_group_rollout_state(run_model, group) + group_diff = group_desired - len(state.active_replicas) if group_diff != 0: - # Check if rolling deployment is in progress for THIS GROUP - - group_has_out_of_date = has_out_of_date_replicas(run_model, group_filter=group.name) - # During rolling deployment, don't scale down old replicas # Let rolling deployment handle stopping old replicas - if group_diff < 0 and group_has_out_of_date: + if group_diff < 0 and state.has_out_of_date_replicas: # Skip scaling down during rolling deployment continue await scale_run_replicas_for_group( @@ -261,9 +307,9 @@ async def scale_run_replicas_per_group( run_model=run_model, group=group, replicas_diff=group_diff, - run_spec=get_run_spec(run_model), - active_replicas=active_replicas, - inactive_replicas=inactive_replicas, + run_spec=run_spec, + active_replicas=state.active_replicas, + inactive_replicas=state.inactive_replicas, ) diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_runs.py b/src/tests/_internal/server/background/scheduled_tasks/test_runs.py index ffb63de35..eb45ebcb6 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_runs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_runs.py @@ -11,11 
+11,18 @@ import dstack._internal.server.background.scheduled_tasks.runs as process_runs from dstack._internal.core.models.configurations import ( ProbeConfig, + ReplicaGroup, ServiceConfiguration, TaskConfiguration, ) from dstack._internal.core.models.instances import InstanceStatus -from dstack._internal.core.models.profiles import Profile, ProfileRetry, RetryEvent, Schedule +from dstack._internal.core.models.profiles import ( + Profile, + ProfileRetry, + RetryEvent, + Schedule, + StopCriteria, +) from dstack._internal.core.models.resources import Range from dstack._internal.core.models.runs import ( JobSpec, @@ -232,6 +239,43 @@ async def test_retry_running_to_failed(self, test_db, session: AsyncSession): assert run.status == RunStatus.TERMINATING assert run.termination_reason == RunTerminationReason.JOB_FAILED + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_retry_running_to_retry_limit_exceeded(self, test_db, session: AsyncSession): + run = await make_run( + session, + status=RunStatus.RUNNING, + retry=ProfileRetry(duration=180, on_events=[RetryEvent.NO_CAPACITY]), + ) + now = run.submitted_at + datetime.timedelta(minutes=10) + await create_job( + session=session, + run=run, + status=JobStatus.FAILED, + termination_reason=JobTerminationReason.EXECUTOR_ERROR, + last_processed_at=now - datetime.timedelta(minutes=4), + replica_num=0, + job_provisioning_data=get_job_provisioning_data(), + ) + await create_job( + session=session, + run=run, + status=JobStatus.FAILED, + termination_reason=JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY, + replica_num=0, + submission_num=1, + last_processed_at=now - datetime.timedelta(minutes=2), + job_provisioning_data=None, + ) + + with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock: + datetime_mock.return_value = now + await process_runs.process_runs() + + await session.refresh(run) + assert run.status == RunStatus.TERMINATING 
+ assert run.termination_reason == RunTerminationReason.RETRY_LIMIT_EXCEEDED + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_calculates_retry_duration_since_last_successful_submission( @@ -283,6 +327,53 @@ async def test_pending_to_submitted(self, test_db, session: AsyncSession): assert run.jobs[0].status == JobStatus.FAILED assert run.jobs[1].status == JobStatus.SUBMITTED + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_stops_when_master_job_done(self, test_db, session: AsyncSession): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=TaskConfiguration( + commands=["echo hello"], + nodes=2, + ), + profile=Profile( + name="test-profile", + stop_criteria=StopCriteria.MASTER_DONE, + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=run_spec, + status=RunStatus.RUNNING, + ) + await create_job( + session=session, + run=run, + status=JobStatus.DONE, + termination_reason=JobTerminationReason.DONE_BY_RUNNER, + replica_num=0, + job_num=0, + ) + await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=0, + job_num=1, + ) + + await process_runs.process_runs() + + await session.refresh(run) + assert run.status == RunStatus.TERMINATING + assert run.termination_reason == RunTerminationReason.ALL_JOBS_DONE + class TestProcessRunsReplicas: @pytest.mark.asyncio @@ -478,6 +569,64 @@ async def test_considers_replicas_inactive_only_when_all_jobs_done( # Should not create new replica with new jobs assert len(run.jobs) == 2 + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def 
test_retrying_multinode_replica_terminates_active_sibling_jobs( + self, + test_db, + session: AsyncSession, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=TaskConfiguration( + commands=["echo hello"], + nodes=2, + ), + profile=Profile( + name="test-profile", + retry=ProfileRetry(duration="10m", on_events=[RetryEvent.NO_CAPACITY]), + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=run_spec, + status=RunStatus.RUNNING, + ) + failed_job = await create_job( + session=session, + run=run, + status=JobStatus.FAILED, + termination_reason=JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY, + replica_num=0, + job_num=0, + ) + running_job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=0, + job_num=1, + job_provisioning_data=get_job_provisioning_data(), + ) + + with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock: + datetime_mock.return_value = run.submitted_at + datetime.timedelta(minutes=1) + await process_runs.process_runs() + + await session.refresh(run) + await session.refresh(failed_job) + await session.refresh(running_job) + assert run.status == RunStatus.PENDING + assert failed_job.status == JobStatus.FAILED + assert running_job.status == JobStatus.TERMINATING + assert running_job.termination_reason == JobTerminationReason.TERMINATED_BY_SERVER + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_pending_to_submitted_adds_replicas(self, test_db, session: AsyncSession): @@ -708,6 +857,82 @@ async def test_not_updates_deployment_num_in_place_for_finished_replica( assert len(events[0].targets) == 1 assert events[0].targets[0].entity_id == run.jobs[0].id + async def 
test_terminates_replicas_from_removed_group(self, test_db, session: AsyncSession): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=ServiceConfiguration( + port=8000, + image="ubuntu:latest", + replicas=[ + ReplicaGroup( + name="group-a", + count=parse_obj_as(Range[int], 1), + commands=["echo group-a"], + ), + ReplicaGroup( + name="group-b", + count=parse_obj_as(Range[int], 1), + commands=["echo group-b"], + ), + ], + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=run_spec, + status=RunStatus.RUNNING, + ) + group_a_job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=0, + registered=True, + ) + group_b_job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=1, + registered=True, + ) + + group_a_job_spec = cast(JobSpec, JobSpec.__response__.parse_raw(group_a_job.job_spec_data)) + group_a_job_spec.replica_group = "group-a" + group_a_job.job_spec_data = group_a_job_spec.json() + + group_b_job_spec = cast(JobSpec, JobSpec.__response__.parse_raw(group_b_job.job_spec_data)) + group_b_job_spec.replica_group = "group-b" + group_b_job.job_spec_data = group_b_job_spec.json() + + updated_run_spec: RunSpec = RunSpec.__response__.parse_raw(run.run_spec) + assert isinstance(updated_run_spec.configuration, ServiceConfiguration) + updated_run_spec.configuration.replicas = [ + ReplicaGroup( + name="group-a", + count=parse_obj_as(Range[int], 1), + commands=["echo group-a"], + ) + ] + run.run_spec = updated_run_spec.json() + await session.commit() + + await process_runs.process_runs() + + await session.refresh(run) + await session.refresh(group_a_job) + await session.refresh(group_b_job) + assert run.status == RunStatus.RUNNING + assert group_a_job.status == 
JobStatus.RUNNING + assert group_b_job.status == JobStatus.TERMINATING + assert group_b_job.termination_reason == JobTerminationReason.SCALED_DOWN + async def test_starts_new_replica(self, test_db, session: AsyncSession) -> None: run = await make_run(session, status=RunStatus.RUNNING, replicas=2, image="old") for replica_num in range(2):