diff --git a/runner/internal/common/types/types.go b/runner/internal/common/types/types.go index b7f6c6fd3..057c0248c 100644 --- a/runner/internal/common/types/types.go +++ b/runner/internal/common/types/types.go @@ -10,4 +10,5 @@ const ( TerminationReasonTerminatedByUser TerminationReason = "terminated_by_user" TerminationReasonTerminatedByServer TerminationReason = "terminated_by_server" TerminationReasonMaxDurationExceeded TerminationReason = "max_duration_exceeded" + TerminationReasonLogQuotaExceeded TerminationReason = "log_quota_exceeded" ) diff --git a/runner/internal/runner/executor/executor.go b/runner/internal/runner/executor/executor.go index 98289eb4e..3662a45aa 100644 --- a/runner/internal/runner/executor/executor.go +++ b/runner/internal/runner/executor/executor.go @@ -261,6 +261,17 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { default: } + if errors.Is(err, ErrLogQuotaExceeded) { + log.Error(ctx, "Log quota exceeded", "quota", ex.jobLogs.quota) + ex.SetJobStateWithTerminationReason( + ctx, + schemas.JobStateFailed, + types.TerminationReasonLogQuotaExceeded, + fmt.Sprintf("Job log output exceeded the hourly quota of %d bytes", ex.jobLogs.quota), + ) + return fmt.Errorf("log quota exceeded: %w", err) + } + // todo fail reason? 
log.Error(ctx, "Exec failed", "err", err) var exitError *exec.ExitError @@ -283,6 +294,7 @@ func (ex *RunExecutor) SetJob(body schemas.SubmitBody) { ex.clusterInfo = body.ClusterInfo ex.secrets = body.Secrets ex.repoCredentials = body.RepoCredentials + ex.jobLogs.SetQuota(body.LogQuotaHour) ex.state = WaitCode } @@ -586,11 +598,10 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error defer func() { _ = cmd.Wait() }() // release resources if copy fails stripper := ansistrip.NewWriter(ex.jobLogs, AnsiStripFlushInterval, AnsiStripMaxDelay, MaxBufferSize) - defer func() { _ = stripper.Close() }() logger := io.MultiWriter(jobLogFile, ex.jobWsLogs, stripper) - _, err = io.Copy(logger, ptm) - if err != nil && !isPtyError(err) { - return fmt.Errorf("copy command output: %w", err) + + if err := ex.copyOutputWithQuota(cmd, ptm, stripper, logger); err != nil { + return err } if err = cmd.Wait(); err != nil { return fmt.Errorf("wait for command: %w", err) @@ -598,6 +609,40 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error return nil } +// copyOutputWithQuota streams process output through the log pipeline and +// monitors for log quota exceeded. The quota signal is out-of-band (via channel) +// because the ansistrip writer is async and swallows downstream write errors. +func (ex *RunExecutor) copyOutputWithQuota(cmd *exec.Cmd, ptm io.Reader, stripper io.Closer, logger io.Writer) error { + copyDone := make(chan error, 1) + go func() { + _, err := io.Copy(logger, ptm) + copyDone <- err + }() + + // Wait for either io.Copy to finish or quota to be exceeded. + var copyErr error + select { + case copyErr = <-copyDone: + case <-ex.jobLogs.QuotaExceeded(): + _ = cmd.Process.Kill() + <-copyDone + } + + // Flush the ansistrip buffer — may also trigger quota exceeded. 
+ _ = stripper.Close() + + select { + case <-ex.jobLogs.QuotaExceeded(): + return ErrLogQuotaExceeded + default: + } + + if copyErr != nil && !isPtyError(copyErr) { + return fmt.Errorf("copy command output: %w", copyErr) + } + return nil +} + // setupGitCredentials must be called from Run after setJobUser func (ex *RunExecutor) setupGitCredentials(ctx context.Context) (func(), error) { if ex.repoCredentials == nil { diff --git a/runner/internal/runner/executor/executor_test.go b/runner/internal/runner/executor/executor_test.go index 915cca35a..2330cd6f3 100644 --- a/runner/internal/runner/executor/executor_test.go +++ b/runner/internal/runner/executor/executor_test.go @@ -141,6 +141,27 @@ func TestExecutor_MaxDuration(t *testing.T) { assert.ErrorContains(t, err, "killed") } +func TestExecutor_LogQuota(t *testing.T) { + if testing.Short() { + t.Skip() + } + + ex := makeTestExecutor(t) + ex.killDelay = 500 * time.Millisecond + // Output >100 bytes to trigger the quota + ex.jobSpec.Commands = append(ex.jobSpec.Commands, "for i in $(seq 1 20); do echo 'This line is long enough to exceed the quota easily'; done") + ex.jobLogs.SetQuota(100) + makeCodeTar(t, ex) + + err := ex.Run(t.Context()) + assert.ErrorContains(t, err, "log quota exceeded") + + // Verify the termination state was set + history := ex.GetHistory(0) + lastState := history.JobStates[len(history.JobStates)-1] + assert.Equal(t, schemas.JobStateFailed, lastState.State) +} + func TestExecutor_RemoteRepo(t *testing.T) { if testing.Short() { t.Skip() diff --git a/runner/internal/runner/executor/logs.go b/runner/internal/runner/executor/logs.go index 808fc84b1..54b087e32 100644 --- a/runner/internal/runner/executor/logs.go +++ b/runner/internal/runner/executor/logs.go @@ -1,29 +1,65 @@ package executor import ( + "errors" + "math" "sync" + "time" "github.com/dstackai/dstack/runner/internal/runner/schemas" ) +var ErrLogQuotaExceeded = errors.New("log quota exceeded") + type appendWriter struct { mu *sync.RWMutex 
// shares with executor history []schemas.LogEvent timestamp *MonotonicTimestamp // shares with executor + + quota int // bytes per hour, 0 = unlimited + bytesInHour int // bytes written in current hour bucket + currentHour int // monotonic hour bucket index since timeStarted + timeStarted time.Time // monotonic reference point for hour buckets + quotaExceeded chan struct{} // closed when quota is exceeded (out-of-band signal) + exceededOnce sync.Once } func newAppendWriter(mu *sync.RWMutex, timestamp *MonotonicTimestamp) *appendWriter { return &appendWriter{ - mu: mu, - history: make([]schemas.LogEvent, 0), - timestamp: timestamp, + mu: mu, + history: make([]schemas.LogEvent, 0), + timestamp: timestamp, + quotaExceeded: make(chan struct{}), } } +func (w *appendWriter) SetQuota(quota int) { + w.quota = quota + w.timeStarted = time.Now() +} + +// QuotaExceeded returns a channel that is closed when the log quota is exceeded. +func (w *appendWriter) QuotaExceeded() <-chan struct{} { + return w.quotaExceeded +} + func (w *appendWriter) Write(p []byte) (n int, err error) { w.mu.Lock() defer w.mu.Unlock() + if w.quota > 0 { + hour := int(math.Floor(time.Since(w.timeStarted).Hours())) + if hour != w.currentHour { + w.bytesInHour = 0 + w.currentHour = hour + } + if w.bytesInHour+len(p) > w.quota { + w.exceededOnce.Do(func() { close(w.quotaExceeded) }) + return 0, ErrLogQuotaExceeded + } + w.bytesInHour += len(p) + } + pCopy := make([]byte, len(p)) copy(pCopy, p) w.history = append(w.history, schemas.LogEvent{Message: pCopy, Timestamp: w.timestamp.Next()}) diff --git a/runner/internal/runner/schemas/schemas.go b/runner/internal/runner/schemas/schemas.go index ca707db76..47706228c 100644 --- a/runner/internal/runner/schemas/schemas.go +++ b/runner/internal/runner/schemas/schemas.go @@ -36,6 +36,7 @@ type SubmitBody struct { ClusterInfo ClusterInfo `json:"cluster_info"` Secrets map[string]string `json:"secrets"` RepoCredentials *RepoCredentials `json:"repo_credentials"` + 
LogQuotaHour int `json:"log_quota_hour"` // bytes per hour, 0 = unlimited } type PullResponse struct { diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index bd1307df7..fdb7b58cd 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -151,6 +151,7 @@ class JobTerminationReason(str, Enum): CREATING_CONTAINER_ERROR = "creating_container_error" EXECUTOR_ERROR = "executor_error" MAX_DURATION_EXCEEDED = "max_duration_exceeded" + LOG_QUOTA_EXCEEDED = "log_quota_exceeded" def to_status(self) -> JobStatus: mapping = { @@ -173,6 +174,7 @@ def to_status(self) -> JobStatus: self.CREATING_CONTAINER_ERROR: JobStatus.FAILED, self.EXECUTOR_ERROR: JobStatus.FAILED, self.MAX_DURATION_EXCEEDED: JobStatus.TERMINATED, + self.LOG_QUOTA_EXCEEDED: JobStatus.FAILED, } return mapping[self] @@ -205,6 +207,7 @@ def to_error(self) -> Optional[str]: JobTerminationReason.CREATING_CONTAINER_ERROR: "runner error", JobTerminationReason.EXECUTOR_ERROR: "executor error", JobTerminationReason.MAX_DURATION_EXCEEDED: "max duration exceeded", + JobTerminationReason.LOG_QUOTA_EXCEEDED: "log quota exceeded", } return error_mapping.get(self) diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index 43e9ddbb8..549ff7914 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -103,6 +103,8 @@ class SubmitBody(CoreModel): cluster_info: Annotated[Optional[ClusterInfo], Field(include=True)] secrets: Annotated[Optional[Dict[str, str]], Field(include=True)] repo_credentials: Annotated[Optional[RemoteRepoCreds], Field(include=True)] + log_quota_hour: Annotated[Optional[int], Field(include=True)] = None + """Maximum bytes of log output per hour. None means unlimited.""" # TODO: remove `run_spec` once instances deployed with 0.19.8 or earlier are no longer supported. 
run_spec: Annotated[ RunSpec, diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index c31726e76..4b78eefee 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -15,6 +15,7 @@ from dstack._internal.core.models.resources import Memory from dstack._internal.core.models.runs import ClusterInfo, Job, Run from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint +from dstack._internal.server import settings as server_settings from dstack._internal.server.schemas.instances import InstanceCheck from dstack._internal.server.schemas.runner import ( ComponentInfo, @@ -93,6 +94,7 @@ def submit_job( merged_env.update(job_spec.env) job_spec = job_spec.copy(deep=True) job_spec.env = merged_env + quota = server_settings.SERVER_LOG_QUOTA_PER_JOB_HOUR body = SubmitBody( run=run, job_spec=job_spec, @@ -100,6 +102,7 @@ def submit_job( cluster_info=cluster_info, secrets=secrets, repo_credentials=repo_credentials, + log_quota_hour=quota if quota > 0 else None, run_spec=run.run_spec, ) resp = requests.post( diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index 71a43a30b..01216cff3 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -133,6 +133,11 @@ SERVER_TEMPLATES_REPO = os.getenv("DSTACK_SERVER_TEMPLATES_REPO") +# Per-job log quota: maximum bytes of log output per hour, counted in hourly windows from job submission (not clock hours). 0 = unlimited. +SERVER_LOG_QUOTA_PER_JOB_HOUR = int( + os.getenv("DSTACK_SERVER_LOG_QUOTA_PER_JOB_HOUR", 50 * 1024 * 1024) # 50 MB +) + # Development settings SQL_ECHO_ENABLED = os.getenv("DSTACK_SQL_ECHO_ENABLED") is not None