diff --git a/Makefile b/Makefile index ca63ed6..0273485 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ BINARY_NAME=sitectl DOCS_PORT ?= 3000 -INSTALL_DIR ?= /usr/local/bin +INSTALL_DIR ?= $(or $(dir $(shell which $(BINARY_NAME) 2>/dev/null)),/usr/local/bin/) deps: go get . @@ -12,7 +12,7 @@ build: deps go build -o $(BINARY_NAME) . install: build - sudo cp $(BINARY_NAME) $(INSTALL_DIR)/$(BINARY_NAME) + sudo cp $(BINARY_NAME) $(INSTALL_DIR)$(BINARY_NAME) @if [ -d ../sitectl-isle ]; then $(MAKE) -C ../sitectl-isle install; fi @if [ -d ../sitectl-drupal ]; then $(MAKE) -C ../sitectl-drupal install; fi diff --git a/cmd/component.go b/cmd/component.go index 8958257..b8d8c89 100644 --- a/cmd/component.go +++ b/cmd/component.go @@ -27,15 +27,21 @@ var ( componentSetDisposition string componentSetTLSMode string componentSetYolo bool - invokePluginCommand = func(pluginName string, args []string) error { + invokePluginCommand = func(pluginName, contextName string, args []string) error { installed, ok := plugin.FindInstalled(pluginName) if !ok { return fmt.Errorf("plugin %q is not installed", pluginName) } - _, err := pluginSDK.InvokePluginCommand(installed.Name, args, plugin.CommandExecOptions{ - Stdin: RootCmd.InOrStdin(), - Stdout: RootCmd.OutOrStdout(), - Stderr: RootCmd.ErrOrStderr(), + invocation := make([]string, 0, len(args)+2) + if strings.TrimSpace(contextName) != "" { + invocation = append(invocation, "--context", contextName) + } + invocation = append(invocation, args...) + _, err := pluginSDK.InvokePluginCommand(installed.Name, invocation, plugin.CommandExecOptions{ + Context: RootCmd.Context(), + Stdin: RootCmd.InOrStdin(), + Stdout: RootCmd.OutOrStdout(), + Stderr: RootCmd.ErrOrStderr(), }) return err } @@ -51,7 +57,7 @@ var componentDescribeCmd = &cobra.Command{ Aliases: []string{"status"}, Short: "Describe the current component state", RunE: func(cmd *cobra.Command, args []string) error { - owner, name, err := resolveComponentOwner(cmd, componentDescribeName) + contextName, owner, name, err := resolveComponentOwner(cmd, componentDescribeName) if err != nil { return err } @@ -73,7 +79,7 @@ var componentDescribeCmd = &cobra.Command{ invocation = append(invocation, "--format", componentDescribeFormat) } - return invokePluginCommand(owner, invocation) + return invokePluginCommand(owner, contextName, invocation) }, } @@ -82,7 +88,7 @@ var componentReconcileCmd = &cobra.Command{ Aliases: []string{"review", "align"}, Short: "Review and reconcile component state", RunE: func(cmd *cobra.Command, args []string) error { - owner, name, err := resolveComponentOwner(cmd, componentReconcileName) + contextName, owner, name, err := resolveComponentOwner(cmd, componentReconcileName) if err != nil { return err } @@ -107,7 +113,7 @@ var componentReconcileCmd = &cobra.Command{ invocation = append(invocation, "--format", componentReconcileFormat) } - return invokePluginCommand(owner, invocation) + return invokePluginCommand(owner, contextName, invocation) }, } @@ -116,7 +122,7 @@ var componentSetCmd = &cobra.Command{ Short: "Set a component disposition", Args: cobra.RangeArgs(1, 2), RunE: func(cmd *cobra.Command, args []string) error { - owner, name, err := resolveComponentOwner(cmd, args[0]) + contextName, owner, name, err := resolveComponentOwner(cmd, args[0]) if err != nil { return err } @@ -144,7 +150,7 @@ var componentSetCmd = &cobra.Command{ invocation = append(invocation, "--yolo") } - return invokePluginCommand(owner, invocation) + return invokePluginCommand(owner, contextName, invocation) }, } @@ -179,15 +185,15 @@ func init() { var pluginSDK *plugin.SDK -func resolveComponentOwner(cmd *cobra.Command, raw string) (string, string, error) { +func resolveComponentOwner(cmd *cobra.Command, raw string) (string, string, string, error) { contextName, err := config.ResolveCurrentContextName(cmd.Flags()) if err != nil { - return "", "", err + return "", "", "", err } ctx, err := config.GetContext(contextName) if err != nil { - return "", "", err + return "", "", "", err } owner := ctx.Plugin @@ -197,12 +203,12 @@ func resolveComponentOwner(cmd *cobra.Command, raw string) (string, string, erro name = componentName } if strings.TrimSpace(owner) == "" { - return "", "", fmt.Errorf("context %q does not define a plugin owner", ctx.Name) + return "", "", "", fmt.Errorf("context %q does not define a plugin owner", ctx.Name) } if owner == "core" { - return "", "", fmt.Errorf("context %q uses plugin %q; component commands require a stack plugin such as isle", ctx.Name, owner) + return "", "", "", fmt.Errorf("context %q uses plugin %q; component commands require a stack plugin such as isle", ctx.Name, owner) } - return owner, name, nil + return contextName, owner, name, nil } func splitNamespacedComponent(raw string) (string, string, bool) { diff --git a/cmd/component_test.go b/cmd/component_test.go index 3840090..f2851e9 100644 --- a/cmd/component_test.go +++ b/cmd/component_test.go @@ -38,10 +38,13 @@ func TestResolveComponentOwnerUsesNamespace(t *testing.T) { t.Fatalf("Set(context) error = %v", err) } - owner, name, err := resolveComponentOwner(cmd, "drupal/modules") + contextName, owner, name, err := resolveComponentOwner(cmd, "drupal/modules") if err != nil { t.Fatalf("resolveComponentOwner() error = %v", err) } + if contextName != "museum" { + t.Fatalf("unexpected context name: %q", contextName) + } if owner != "drupal" || name != "modules" { t.Fatalf("unexpected owner/name: %q %q", owner, name) } @@ -68,10 +71,13 @@ func TestResolveComponentOwnerFallsBackToContextPlugin(t *testing.T) { t.Fatalf("Set(context) error = %v", err) } - owner, name, err := resolveComponentOwner(cmd, "fcrepo") + contextName, owner, name, err := resolveComponentOwner(cmd, "fcrepo") if err != nil { t.Fatalf("resolveComponentOwner() error = %v", err) } + if contextName != "museum" { + t.Fatalf("unexpected context name: %q", contextName) + } if owner != "isle" || name != "fcrepo" { t.Fatalf("unexpected owner/name: %q %q", owner, name) } diff --git a/cmd/compose.go b/cmd/compose.go index 512265e..561f5c2 100644 --- a/cmd/compose.go +++ b/cmd/compose.go @@ -124,7 +124,7 @@ Examples: cmdArgs = append(cmdArgs, filteredArgs...) c := exec.Command("docker", cmdArgs...) c.Dir = context.ProjectDir - _, err = context.RunCommand(c) + _, err = context.RunCommandContext(cmd.Context(), c) if err != nil { return err } diff --git a/cmd/config.go b/cmd/config.go index 690bd82..ca88036 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -774,8 +774,6 @@ func promptRemoteEnvironmentContext(localCtx, previousRemote *config.Context) (* localCtx.ComposeNetwork, localCtx.EffectiveComposeNetwork(), ) - runSudo := remoteContextBool(previousRemote, func(ctx *config.Context) bool { return ctx.RunSudo }, localCtx.RunSudo) - for { hostname, err = promptRequiredValueWithDefault("Remote hostname/domain (e.g. stage.example.com)", hostname) if err != nil { @@ -808,7 +806,6 @@ func promptRemoteEnvironmentContext(localCtx, previousRemote *config.Context) (* SSHUser: sshUser, SSHPort: sshPort, SSHKeyPath: sshKey, - RunSudo: localCtx.RunSudo, DockerSocket: dockerSocket, ComposeFile: append([]string{}, localCtx.ComposeFile...), EnvFile: append([]string{}, localCtx.EnvFile...), @@ -820,7 +817,6 @@ func promptRemoteEnvironmentContext(localCtx, previousRemote *config.Context) (* remoteCtx.ProjectName = projectName remoteCtx.ComposeProjectName = composeProjectName remoteCtx.ComposeNetwork = composeNetwork - remoteCtx.RunSudo = runSudo if detected := config.DetectContextComposeNetwork(remoteCtx); detected != "" { remoteCtx.ComposeNetwork = detected } @@ -911,37 +907,6 @@ func promptUintWithDefault(label string, defaultValue uint) (uint, error) { return uint(parsed), nil } -func promptBooleanChoice(label string, defaultValue bool) (bool, error) { - defaultChoice := "no" - if defaultValue { - defaultChoice = "yes" - } - value, err := createConfigPromptChoice( - strings.ToLower(strings.ReplaceAll(label, " ", "-")), - []corecomponent.Choice{ - { - Value: "yes", - Label: "yes", - Help: label, - Aliases: []string{"y", "1"}, - }, - { - Value: "no", - Label: "no", - Help: "Do not use sudo.", - Aliases: []string{"n", "2"}, - }, - }, - defaultChoice, - createConfigInput, - strings.Split(corecomponent.RenderSection(label, label+"?"), "\n")..., - ) - if err != nil { - return false, err - } - return strings.TrimSpace(value) == "yes", nil -} - func validateRemoteDockerAccess(ctx *config.Context) error { if ctx == nil || ctx.DockerHostType != config.ContextRemote { return nil @@ -988,16 +953,11 @@ func validateRemoteDockerAccess(ctx *config.Context) error { if promptErr != nil { return promptErr } - runSudo, promptErr := promptBooleanChoice("Run Docker commands with sudo", ctx.RunSudo) - if promptErr != nil { - return promptErr - } ctx.ProjectDir = projectDir ctx.ProjectName = projectName ctx.ComposeProjectName = firstNonEmptyString(ctx.ComposeProjectName, projectName) ctx.ComposeNetwork = firstNonEmptyString(config.DetectContextComposeNetwork(ctx), ctx.ComposeNetwork, ctx.EffectiveComposeNetwork()) ctx.DockerSocket = dockerSocket - ctx.RunSudo = runSudo continue } return nil @@ -1034,13 +994,6 @@ func remoteContextUint(ctx *config.Context, getter func(*config.Context) uint, f return fallback } -func remoteContextBool(ctx *config.Context, getter func(*config.Context) bool, fallback bool) bool { - if ctx == nil || getter == nil { - return fallback - } - return getter(ctx) -} - func suggestedEnvironmentContextName(localCtx *config.Context, environment string) string { base := strings.TrimSpace(environment) if localCtx == nil { diff --git a/cmd/config_create_test.go b/cmd/config_create_test.go index 71b8ece..3bc22db 100644 --- a/cmd/config_create_test.go +++ b/cmd/config_create_test.go @@ -740,9 +740,6 @@ func TestRunCreateConfigRepromptsDockerSettingsAfterComposePSFailure(t *testing. if ctx.DockerSocket != "/run/user/1000/docker.sock" { t.Fatalf("expected updated docker socket, got %q", ctx.DockerSocket) } - if !ctx.RunSudo { - t.Fatal("expected updated run sudo true") - } if ctx.ProjectName != "museum-prod" { t.Fatalf("expected updated project name museum-prod, got %q", ctx.ProjectName) } @@ -762,8 +759,6 @@ func TestRunCreateConfigRepromptsDockerSettingsAfterComposePSFailure(t *testing. return "no", nil case "update-environment-context": return "update", nil - case "run-docker-commands-with-sudo": - return "yes", nil default: t.Fatalf("unexpected choice prompt: %s", name) return "", nil @@ -811,9 +806,6 @@ func TestRunCreateConfigRepromptsDockerSettingsAfterComposePSFailure(t *testing. if remoteCtx.Plugin != "core" { t.Fatalf("expected saved plugin core, got %q", remoteCtx.Plugin) } - if !remoteCtx.RunSudo { - t.Fatal("expected saved run sudo true") - } } func TestInheritNewContextDefaultsFromActive(t *testing.T) { diff --git a/cmd/debug.go b/cmd/debug.go index 9d01157..f0981fe 100644 --- a/cmd/debug.go +++ b/cmd/debug.go @@ -2,15 +2,15 @@ package cmd import ( "context" - "errors" "fmt" + "io" "log/slog" "os" - "os/exec" "regexp" "sort" "strconv" "strings" + "sync" "time" "charm.land/lipgloss/v2" @@ -18,7 +18,6 @@ import ( "github.com/docker/docker/api/types/filters" dockerimage "github.com/docker/docker/api/types/image" "github.com/docker/docker/client" - "github.com/kballard/go-shellquote" "github.com/libops/sitectl/pkg/config" "github.com/libops/sitectl/pkg/docker" "github.com/libops/sitectl/pkg/plugin" @@ -28,6 +27,7 @@ import ( var debugOutputPath string var debugVerbose bool +var debugProgressUIActive bool var ansiPattern = regexp.MustCompile(`\x1b\[[0-9;]*m`) @@ -72,37 +72,67 @@ var debugCmd = &cobra.Command{ if err != nil { return err } - - var body strings.Builder - body.WriteString(renderCoreDebug(ctx)) - - if pluginName := strings.TrimSpace(ctx.Plugin); pluginName != "" && pluginName != "core" { - pluginArgs := []string{"__debug"} - if debugVerbose { - pluginArgs = append(pluginArgs, "--verbose") - } - output, err := pluginSDK.InvokePluginCommand(pluginName, pluginArgs, plugin.CommandExecOptions{Capture: true}) + reporter := debugProgressReporter(nil) + if stderrFile, ok := cmd.ErrOrStderr().(*os.File); ok && term.IsTerminal(int(stderrFile.Fd())) { + report, err := runDebugCollectionWithProgress(cmd, contextName, ctx) if err != nil { return err } - if trimmed := strings.TrimSpace(output); trimmed != "" { - body.WriteString("\n\n") - body.WriteString(trimmed) - } + return writeDebugReport(cmd, report) } - if strings.TrimSpace(debugOutputPath) != "" { - report := renderPlainDebugReport(body.String()) - if err := os.WriteFile(debugOutputPath, []byte(report+"\n"), 0o644); err != nil { - return err - } - _, err = fmt.Fprintf(cmd.OutOrStdout(), "wrote debug bundle to %s\n", debugOutputPath) + report, err := collectDebugReport(cmd.Context(), contextName, ctx, reporter) + if err != nil { return err } + return writeDebugReport(cmd, report) + }, +} - _, err = fmt.Fprintln(cmd.OutOrStdout(), body.String()) +func writeDebugReport(cmd *cobra.Command, report string) error { + if strings.TrimSpace(debugOutputPath) != "" { + report = renderPlainDebugReport(report) + if err := os.WriteFile(debugOutputPath, []byte(report+"\n"), 0o644); err != nil { + return err + } + _, err := fmt.Fprintf(cmd.OutOrStdout(), "wrote debug bundle to %s\n", debugOutputPath) return err - }, + } + + _, err := fmt.Fprintln(cmd.OutOrStdout(), report) + return err +} + +func collectDebugReport(runCtx context.Context, contextName string, ctx config.Context, reporter debugProgressReporter) (string, error) { + if err := runCtx.Err(); err != nil { + return "", err + } + var body strings.Builder + reportProgress(reporter, "Collecting Core Diagnostics", "Inspecting Docker configuration, logs, and images") + body.WriteString(renderCoreDebug(runCtx, ctx)) + + if pluginName := strings.TrimSpace(ctx.Plugin); pluginName != "" && pluginName != "core" { + if err := runCtx.Err(); err != nil { + return "", err + } + pluginArgs := []string{"--context", contextName, "__debug"} + if debugVerbose { + pluginArgs = append(pluginArgs, "--verbose") + } + reportProgress(reporter, "Collecting Plugin Diagnostics", fmt.Sprintf("Running %s debug collectors", pluginName)) + slog.Debug("handing off debug to plugin", "context", contextName, "plugin", pluginName, "args", pluginArgs) + output, err := pluginSDK.InvokePluginCommand(pluginName, pluginArgs, plugin.CommandExecOptions{Context: runCtx, Capture: true, LiveStderr: !progressEnabled()}) + if err != nil { + return "", err + } + slog.Debug("plugin debug completed", "context", contextName, "plugin", pluginName) + if trimmed := strings.TrimSpace(output); trimmed != "" { + body.WriteString("\n\n") + body.WriteString(trimmed) + } + } + + return body.String(), nil } func init() { @@ -111,7 +141,26 @@ func init() { RootCmd.AddCommand(debugCmd) } -func renderCoreDebug(ctx config.Context) string { +func runDebugCollectionWithProgress(cmd *cobra.Command, contextName string, ctx config.Context) (string, error) { + debugProgressUIActive = true + defer func() { debugProgressUIActive = false }() + progress := newDebugProgressLine(cmd.ErrOrStderr()) + defer progress.Close() + return collectDebugReport(cmd.Context(), contextName, ctx, progress.Report) +} + +func reportProgress(reporter debugProgressReporter, title, detail string) { + if reporter != nil { + reporter(title, detail) + } +} + +func progressEnabled() bool { + return debugProgressUIActive +} + +func renderCoreDebug(runCtx context.Context, ctx config.Context) string { + slog.Debug("starting core debug", "context", ctx.Name, "docker_host_type", ctx.DockerHostType) meta := []debugRow{ {Label: "Generated", Value: time.Now().UTC().Format(time.RFC3339)}, {Label: "Context", Value: ctx.Name}, @@ -138,31 +187,35 @@ func renderCoreDebug(ctx config.Context) string { "", formatDebugRows(meta), } - if diagnostics, err := collectLogDiagnostics(&ctx); err == nil { - coreBody = append(coreBody, "", debugDivider(), "", debugTitleStyle.Render("Log Summary"), "", formatDebugRows(logSummaryRows(diagnostics))) + logDiagnostics, logErr, imageDiagnostics, imageErr := collectCoreDockerDiagnostics(runCtx, &ctx) + if logErr == nil { + slog.Debug("collected log diagnostics", "context", ctx.Name, "containers", len(logDiagnostics.Containers)) + coreBody = append(coreBody, "", debugDivider(), "", debugTitleStyle.Render("Log Summary"), "", formatDebugRows(logSummaryRows(logDiagnostics))) if debugVerbose { - coreBody = append(coreBody, "", debugDivider(), "", debugTitleStyle.Render("Log Details"), "", renderLogDetailsBody(diagnostics)) + coreBody = append(coreBody, "", debugDivider(), "", debugTitleStyle.Render("Log Details"), "", renderLogDetailsBody(logDiagnostics)) } } else { + slog.Debug("log diagnostics failed", "context", ctx.Name, "error", logErr) coreBody = append(coreBody, "", debugDivider(), "", debugTitleStyle.Render("Log Summary"), "", formatDebugRows([]debugRow{ {Label: "Log status", Value: renderStatus("warning")}, - {Label: "Log diagnostics", Value: err.Error()}, + {Label: "Log diagnostics", Value: logErr.Error()}, })) } - if diagnostics, err := collectImageDiagnostics(&ctx); err == nil { - coreBody = append(coreBody, "", debugDivider(), "", debugTitleStyle.Render("Image Summary"), "", formatDebugRows(imageSummaryRows(diagnostics))) + if imageErr == nil { + slog.Debug("collected image diagnostics", "context", ctx.Name, "images", imageDiagnostics.ImageCount, "total_bytes", imageDiagnostics.TotalBytes) + coreBody = append(coreBody, "", debugDivider(), "", debugTitleStyle.Render("Image Summary"), "", formatDebugRows(imageSummaryRows(imageDiagnostics))) } else { + slog.Debug("image diagnostics failed", "context", ctx.Name, "error", imageErr) coreBody = append(coreBody, "", debugDivider(), "", debugTitleStyle.Render("Image Summary"), "", formatDebugRows([]debugRow{ {Label: "Image status", Value: renderStatus("warning")}, - {Label: "Image diagnostics", Value: err.Error()}, + {Label: "Image diagnostics", Value: imageErr.Error()}, })) } + slog.Debug("finished core debug", "context", ctx.Name) return renderDebugPanel("sitectl", strings.Join(coreBody, "\n")) } type logDiagnostics struct { - TotalBytes int64 - KnownSize bool Containers []containerLogDiagnostics UnboundedCount int ExternalDriverCount int @@ -172,9 +225,6 @@ type containerLogDiagnostics struct { Service string Container string Driver string - LogPath string - SizeBytes int64 - HasSize bool Rotated bool External bool RotationHint string @@ -185,33 +235,125 @@ type imageDiagnostics struct { ImageCount int } -func collectLogDiagnostics(ctxCfg *config.Context) (logDiagnostics, error) { - cli, err := docker.GetDockerCli(ctxCfg) - if err != nil { - return logDiagnostics{}, err +type debugProgressReporter func(title, detail string) + +type debugProgressLine struct { + out *os.File + frames []string + index int + title string + detail string + mu sync.Mutex + done chan struct{} + once sync.Once +} + +func newDebugProgressLine(w io.Writer) *debugProgressLine { + file, ok := w.(*os.File) + if !ok { + return &debugProgressLine{frames: []string{".", "o", "O", "o"}} } - defer cli.Close() + progress := &debugProgressLine{ + out: file, + frames: []string{"-", "\\", "|", "/"}, + title: "Preparing Debug Bundle", + detail: "Starting diagnostic collection", + done: make(chan struct{}), + } + go progress.animate(120 * time.Millisecond) + return progress +} + +func (p *debugProgressLine) Report(title, detail string) { + if p == nil { + return + } + p.mu.Lock() + p.title = strings.TrimSpace(title) + p.detail = strings.TrimSpace(detail) + p.renderLocked() + p.mu.Unlock() +} + +func (p *debugProgressLine) Close() { + if p == nil || p.out == nil { + return + } + p.once.Do(func() { + close(p.done) + p.mu.Lock() + defer p.mu.Unlock() + fmt.Fprint(p.out, "\r\033[2K") + }) +} + +func (p *debugProgressLine) animate(interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + p.mu.Lock() + p.renderLocked() + p.mu.Unlock() + case <-p.done: + return + } + } +} +func (p *debugProgressLine) renderLocked() { + if p.out == nil { + return + } + frame := p.frames[p.index%len(p.frames)] + p.index++ + line := fmt.Sprintf("\r%s %s", frame, strings.TrimSpace(strings.Join([]string{p.title, p.detail}, " - "))) + fmt.Fprint(p.out, truncateDebugProgress(line)) +} + +func truncateDebugProgress(line string) string { + width := debugPanelWidth() + if width <= 0 { + return line + } + plain := ansiPattern.ReplaceAllString(line, "") + if lipgloss.Width(plain) <= width { + return line + } + runes := []rune(plain) + if len(runes) <= width { + return string(runes) + } + if width <= 1 { + return string(runes[:width]) + } + return string(runes[:width-1]) + "…" +} + +func collectLogDiagnosticsWithClient(runCtx context.Context, ctxCfg *config.Context, cli *docker.DockerClient) (logDiagnostics, error) { filterArgs := filters.NewArgs() filterArgs.Add("label", "com.docker.compose.project="+ctxCfg.EffectiveComposeProjectName()) - containers, err := cli.CLI.ContainerList(context.Background(), dockercontainer.ListOptions{ + containers, err := cli.CLI.ContainerList(runCtx, dockercontainer.ListOptions{ All: true, Filters: filterArgs, }) if err != nil { return logDiagnostics{}, err } + slog.Debug("listed containers for log diagnostics", "context", ctxCfg.Name, "count", len(containers)) diagnostics := logDiagnostics{ - KnownSize: true, Containers: make([]containerLogDiagnostics, 0, len(containers)), } - remotePaths := make([]string, 0, len(containers)) for _, summary := range containers { + if err := runCtx.Err(); err != nil { + return logDiagnostics{}, err + } name := trimContainerName(summary.Names) service := firstNonEmpty(summary.Labels["com.docker.compose.service"], name) - inspect, err := cli.CLI.ContainerInspect(context.Background(), name) + inspect, err := cli.CLI.ContainerInspect(runCtx, name) if err != nil { return logDiagnostics{}, err } @@ -223,60 +365,9 @@ func collectLogDiagnostics(ctxCfg *config.Context) (logDiagnostics, error) { if !item.Rotated && !item.External { diagnostics.UnboundedCount++ } - if item.LogPath != "" && ctxCfg.DockerHostType != config.ContextLocal { - remotePaths = append(remotePaths, item.LogPath) - } diagnostics.Containers = append(diagnostics.Containers, item) } - if ctxCfg.DockerHostType == config.ContextLocal { - for i := range diagnostics.Containers { - item := &diagnostics.Containers[i] - if item.LogPath == "" { - diagnostics.KnownSize = false - continue - } - size, hasSize, err := logFileSizeLocal(item.LogPath) - if err != nil { - item.RotationHint = appendHint(item.RotationHint, fmt.Sprintf("unable to stat log file: %v", err)) - diagnostics.KnownSize = false - continue - } - item.SizeBytes = size - item.HasSize = hasSize - if hasSize { - diagnostics.TotalBytes += size - } else { - diagnostics.KnownSize = false - } - } - } else if len(remotePaths) > 0 { - sizes, err := logFileSizesRemote(ctxCfg, remotePaths) - if err != nil { - diagnostics.KnownSize = false - for i := range diagnostics.Containers { - if diagnostics.Containers[i].LogPath == "" { - continue - } - diagnostics.Containers[i].RotationHint = appendHint(diagnostics.Containers[i].RotationHint, fmt.Sprintf("unable to stat log file: %v", err)) - } - } else { - for i := range diagnostics.Containers { - item := &diagnostics.Containers[i] - if item.LogPath != "" { - size, ok := sizes[item.LogPath] - if ok { - item.SizeBytes = size - item.HasSize = true - diagnostics.TotalBytes += size - continue - } - } - diagnostics.KnownSize = false - } - } - } - sort.Slice(diagnostics.Containers, func(i, j int) bool { return diagnostics.Containers[i].Service < diagnostics.Containers[j].Service }) @@ -284,22 +375,17 @@ func collectLogDiagnostics(ctxCfg *config.Context) (logDiagnostics, error) { return diagnostics, nil } -func collectImageDiagnostics(ctxCfg *config.Context) (imageDiagnostics, error) { - cli, err := docker.GetDockerCli(ctxCfg) - if err != nil { - return imageDiagnostics{}, err - } - defer cli.Close() - +func collectImageDiagnosticsWithClient(runCtx context.Context, ctxCfg *config.Context, cli *docker.DockerClient) (imageDiagnostics, error) { apiClient, ok := cli.CLI.(*client.Client) if !ok { return imageDiagnostics{}, fmt.Errorf("docker client does not support image listing") } - images, err := apiClient.ImageList(context.Background(), dockerimage.ListOptions{All: true}) + images, err := apiClient.ImageList(runCtx, dockerimage.ListOptions{All: true}) if err != nil { return imageDiagnostics{}, err } + slog.Debug("listed images", "context", ctxCfg.Name, "count", len(images)) diagnostics := imageDiagnostics{ImageCount: len(images)} for _, image := range images { @@ -311,6 +397,21 @@ func collectImageDiagnostics(ctxCfg *config.Context) (imageDiagnostics, error) { return diagnostics, nil } +func collectCoreDockerDiagnostics(runCtx context.Context, ctxCfg *config.Context) (logDiagnostics, error, imageDiagnostics, error) { + slog.Debug("opening shared docker client for core diagnostics", "context", ctxCfg.Name) + cli, err := docker.GetDockerCli(ctxCfg) + if err != nil { + return logDiagnostics{}, err, imageDiagnostics{}, err + } + defer cli.Close() + + slog.Debug("collecting log diagnostics", "context", ctxCfg.Name) + logs, logErr := collectLogDiagnosticsWithClient(runCtx, ctxCfg, cli) + slog.Debug("collecting image diagnostics", "context", ctxCfg.Name) + images, imageErr := collectImageDiagnosticsWithClient(runCtx, ctxCfg, cli) + return logs, logErr, images, imageErr +} + func imageSummaryRows(diagnostics imageDiagnostics) []debugRow { state := "ok" rows := []debugRow{ @@ -333,7 +434,6 @@ func describeContainerLogs(service, containerName string, inspect dockercontaine item := containerLogDiagnostics{ Service: service, Container: containerName, - LogPath: strings.TrimSpace(inspect.LogPath), } if inspect.HostConfig != nil { item.Driver = strings.TrimSpace(inspect.HostConfig.LogConfig.Type) @@ -367,114 +467,8 @@ func evaluateLogConfig(driver string, options map[string]string) (rotated bool, } } -func logFileSizeLocal(path string) (int64, bool, error) { - if strings.TrimSpace(path) == "" { - return 0, false, nil - } - slog.Debug("logFileSizeLocal", "path", path) - - info, err := os.Stat(path) - if err != nil { - if os.IsNotExist(err) { - return 0, false, nil - } - if errors.Is(err, os.ErrPermission) { - size, sudoErr := logFileSizeLocalSudo(path) - if sudoErr == nil { - return size, true, nil - } - return 0, false, fmt.Errorf("%w; sudo stat failed: %v", err, sudoErr) - } - return 0, false, err - } - return info.Size(), true, nil -} - -func logFileSizeLocalSudo(path string) (int64, error) { - cmd := exec.Command("sudo", "-n", "sh", "-lc", fmt.Sprintf("test -f %s && wc -c < %s || true", shellquote.Join(path), shellquote.Join(path))) - slog.Debug(cmd.String()) - output, err := cmd.CombinedOutput() - if err != nil { - return 0, err - } - text := strings.TrimSpace(string(output)) - if text == "" { - return 0, nil - } - return strconv.ParseInt(text, 10, 64) -} - -func logFileSizesRemote(ctxCfg *config.Context, paths []string) (map[string]int64, error) { - uniquePaths := make([]string, 0, len(paths)) - seen := map[string]bool{} - for _, path := range paths { - path = strings.TrimSpace(path) - if path == "" || seen[path] { - continue - } - seen[path] = true - uniquePaths = append(uniquePaths, path) - } - if len(uniquePaths) == 0 { - return map[string]int64{}, nil - } - - client, err := ctxCfg.DialSSH() - if err != nil { - return nil, err - } - defer client.Close() - - session, err := client.NewSession() - if err != nil { - return nil, err - } - defer session.Close() - - parts := make([]string, 0, len(uniquePaths)) - for _, path := range uniquePaths { - quoted := shellquote.Join(path) - parts = append(parts, fmt.Sprintf("if test -f %s; then printf '%%s\\t' %s; stat -c %%s %s; fi", quoted, quoted, quoted)) - } - cmd := strings.Join(parts, "; ") - if ctxCfg.RunSudo { - cmd = "sudo -n sh -lc " + shellquote.Join(cmd) - } - output, err := session.CombinedOutput(cmd) - if err != nil { - return nil, err - } - - sizes := map[string]int64{} - for _, line := range strings.Split(strings.TrimSpace(string(output)), "\n") { - line = strings.TrimSpace(line) - if line == "" { - continue - } - path, rawSize, ok := strings.Cut(line, "\t") - if !ok { - continue - } - size, err := strconv.ParseInt(strings.TrimSpace(rawSize), 10, 64) - if err != nil { - return nil, err - } - sizes[strings.TrimSpace(path)] = size - } - return sizes, nil -} - func logSummaryRows(diagnostics logDiagnostics) []debugRow { - totalLine := "unknown" - if diagnostics.KnownSize { - totalLine = humanBytes(diagnostics.TotalBytes) - } totalState := "ok" - if !diagnostics.KnownSize { - totalState = "warning" - } else if diagnostics.TotalBytes >= 1<<30 { - totalState = "warning" - } logHandling := "file-backed container logs appear capped" if diagnostics.UnboundedCount == 0 { @@ -491,18 +485,12 @@ func logSummaryRows(diagnostics logDiagnostics) []debugRow { rows := []debugRow{ {Label: "Log status", Value: renderStatus(totalState)}, - {Label: "Total logs", Value: totalLine}, {Label: "Log handling", Value: logHandling}, } - if !diagnostics.KnownSize { - rows = append(rows, debugRow{Label: "Note", Value: "unable to determine one or more container log file sizes"}) - } else if diagnostics.TotalBytes >= 1<<30 { - rows = append(rows, debugRow{Label: "Note", Value: "aggregate container logs exceed 1 GiB"}) - } if diagnostics.UnboundedCount > 0 { rows = append(rows, debugRow{ Label: "Recommendation", - Value: `configure Docker log rotation with max-size and max-file, or ship logs to syslog, journald, or another central driver + Value: `for non-local environments, configure Docker log rotation with max-size and max-file, or ship logs to syslog, journald, or another central driver https://docs.docker.com/engine/logging/configure/`}) } @@ -513,9 +501,6 @@ func renderLogDetailsBody(diagnostics logDiagnostics) string { lines := []string{"Log details:"} for _, item := range diagnostics.Containers { line := fmt.Sprintf(" %s: driver=%s", item.Service, item.Driver) - if item.HasSize { - line += fmt.Sprintf(", size=%s", humanBytes(item.SizeBytes)) - } if item.External { line += ", external" } else if item.Rotated { @@ -650,16 +635,3 @@ func firstNonEmpty(values ...string) string { } return "" } - -func appendHint(current, next string) string { - current = strings.TrimSpace(current) - next = strings.TrimSpace(next) - switch { - case current == "": - return next - case next == "": - return current - default: - return current + "; " + next - } -} diff --git a/cmd/debug_test.go b/cmd/debug_test.go index 9822f69..a39c0f9 100644 --- a/cmd/debug_test.go +++ b/cmd/debug_test.go @@ -20,17 +20,15 @@ func TestEvaluateLogConfigDetectsUnboundedJSONFileLogs(t *testing.T) { func TestLogSummaryRowsIncludeRecommendationWhenLogsNeedAttention(t *testing.T) { rows := logSummaryRows(logDiagnostics{ - KnownSize: true, - TotalBytes: 25 * 1024 * 1024, UnboundedCount: 1, Containers: []containerLogDiagnostics{ - {Service: "drupal", Driver: "json-file", SizeBytes: 25 * 1024 * 1024, HasSize: true, Rotated: false, RotationHint: "file-backed logs are not capped; set max-size and max-file"}, + {Service: "drupal", Driver: "json-file", Rotated: false, RotationHint: "file-backed logs are not capped; set max-size and max-file"}, }, }) rendered := formatDebugRows(rows) - if !strings.Contains(rendered, "Total logs") || !strings.Contains(rendered, "25.0MiB") { - t.Fatalf("expected total log size, got:\n%s", rendered) + if strings.Contains(rendered, "Total logs") { + t.Fatalf("expected log summary without total log size, got:\n%s", rendered) } if !strings.Contains(rendered, "Recommendation") { t.Fatalf("expected recommendation guidance, got:\n%s", rendered) @@ -39,26 +37,22 @@ func TestLogSummaryRowsIncludeRecommendationWhenLogsNeedAttention(t *testing.T) func TestRenderLogDetailsBodyIncludesPerContainerRows(t *testing.T) { rendered := renderLogDetailsBody(logDiagnostics{ - KnownSize: true, - TotalBytes: 25 * 1024 * 1024, UnboundedCount: 1, Containers: []containerLogDiagnostics{ - {Service: "drupal", Driver: "json-file", SizeBytes: 25 * 1024 * 1024, HasSize: true, Rotated: false, RotationHint: "file-backed logs are not capped; set max-size and max-file"}, + {Service: "drupal", Driver: "json-file", Rotated: false, RotationHint: "file-backed logs are not capped; set max-size and max-file"}, }, }) - if !strings.Contains(rendered, "drupal: driver=json-file, size=25.0MiB, not rotated") { + if !strings.Contains(rendered, "drupal: driver=json-file, not rotated") { t.Fatalf("expected per-container detail, got:\n%s", rendered) } } func TestLogSummaryRowsStayCompact(t *testing.T) { rows := logSummaryRows(logDiagnostics{ - KnownSize: true, - TotalBytes: 25 * 1024 * 1024, UnboundedCount: 1, Containers: []containerLogDiagnostics{ - {Service: "drupal", Driver: "json-file", SizeBytes: 25 * 1024 * 1024, HasSize: true, Rotated: false}, + {Service: "drupal", Driver: "json-file", Rotated: false}, }, }) diff --git a/cmd/make.go b/cmd/make.go deleted file mode 100644 index d258bf1..0000000 --- a/cmd/make.go +++ /dev/null @@ -1,34 +0,0 @@ -package cmd - -import ( - "os/exec" - - "github.com/libops/sitectl/pkg/config" - "github.com/libops/sitectl/pkg/helpers" - "github.com/spf13/cobra" -) - -// makeCmd support deprecated custom make commands -var makeCmd = &cobra.Command{ - Use: "make", - Short: "Run custom make commands", - Args: cobra.ArbitraryArgs, - Run: func(cmd *cobra.Command, args []string) { - f := cmd.Flags() - context, err := config.CurrentContext(f) - if err != nil { - helpers.ExitOnError(err) - } - - c := exec.Command("make", args...) - c.Dir = context.ProjectDir - _, err = context.RunCommand(c) - if err != nil { - helpers.ExitOnError(err) - } - }, -} - -func init() { - RootCmd.AddCommand(makeCmd) -} diff --git a/cmd/port-forward.go b/cmd/port-forward.go index 8514fac..7113082 100644 --- a/cmd/port-forward.go +++ b/cmd/port-forward.go @@ -63,6 +63,7 @@ Be sure to run Ctrl+c in your terminal when you are done to close the connection listeners := make([]net.Listener, 0, len(args)) done := make(chan os.Signal, 1) signal.Notify(done, os.Interrupt, syscall.SIGHUP, syscall.SIGTERM) + defer signal.Stop(done) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -77,10 +78,16 @@ Be sure to run Ctrl+c in your terminal when you are done to close the connection if err != nil { return fmt.Errorf("invalid local port '%s': must be an integer", localPortStr) } + if localPort < 1 || localPort > 65535 { + return fmt.Errorf("invalid local port '%s': must be between 1 and 65535", localPortStr) + } remotePort, err := strconv.Atoi(remotePortStr) if err != nil { return fmt.Errorf("invalid remote port '%s': must be an integer", remotePortStr) } + if remotePort < 1 || remotePort > 65535 { + return fmt.Errorf("invalid remote port '%s': must be between 1 and 65535", remotePortStr) + } addr := fmt.Sprintf("localhost:%d", localPort) listener, err := net.Listen("tcp", addr) @@ -89,7 +96,7 @@ Be sure to run Ctrl+c in your terminal when you are done to close the connection } listeners = append(listeners, listener) - containerName, err := cli.GetContainerName(c, service) + containerName, err := cli.GetContainerNameContext(ctx, c, service) if err != nil { return err } @@ -100,7 +107,6 @@ Be sure to run Ctrl+c in your terminal when you are done to close the connection remoteEndpoint := fmt.Sprintf("%s:%d", serviceIp, remotePort) go func(listener net.Listener, lp, remoteAddr string) { - defer listener.Close() fmt.Printf("Forwarding localhost:%s -> %s via SSH\n", lp, remoteAddr) for { localConn, err := listener.Accept() diff --git a/cmd/root.go b/cmd/root.go index 2c10fc3..5ecc591 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "os" + "os/signal" "strings" "syscall" @@ -48,8 +49,10 @@ var RootCmd = &cobra.Command{ } func Execute() { + runCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() err := fang.Execute( - context.Background(), + runCtx, RootCmd, fang.WithVersion(RootCmd.Version), ) diff --git a/index.mdx b/index.mdx deleted file mode 100644 index 0dc52d2..0000000 --- a/index.mdx +++ /dev/null @@ -1,68 +0,0 @@ ---- -title: sitectl -description: Command line utility to interact with your local and remote Docker Compose sites. ---- - -import { Compose } from "/docs/snippets/compose-tooltip.mdx"; -import { TUI } from "/docs/snippets/tui-tooltip.mdx"; - -## Overview - -`sitectl` was made with LAC-GLAM institutions at top of mind. `sitectl` is a command line utility to operate your local and remote sites. - - -## Scaling Human Operators - -The philosophy behind `sitectl` is not to help scale operations in the traditional technological sense of the word, but rather to scale *human operators*. -As more institutions run their own instances of an OSS project, the resulting increase in contributors triggers a -virtuous cycle of growth. - -By making the operation of the Docker containers needed to run an application well-defined through common, repeatable patterns using the spec, `sitectl`'s value prop is: -* **Empower institutions:** Giving organizations the capability and confidence to reliably host the software they depend on without relying soley on a dedicated DevOps team. -* **Empower individual contributors:** Providing teams with solid, standardized tooling that eliminates environmental toil and lets them focus on the work that matters. - -## `sitectl` Features - - - - Use the for routine site setup, monitoring, and operator workflows. - - - Track local and remote environments so sitectl can understand where a site lives and how to reach it. - - - Add stack-specific behavior for common technologies without abandoning the core workflow. - - - Model reviewed stack defaults and operator choices in a more structured way than ad hoc notes. - - - -## Development - -See the [contributing guide](/docs/concepts/contributing.mdx) for the local core/plugin development workflow, including the chained `make install` target used during plugin development. - -## Why not just use Docker Contexts? - -While [Docker's native context feature](https://docs.docker.com/engine/manage-resources/contexts/) handles basic docker daemon connections, `sitectl` is purpose-built for projects and adds: - - - - SFTP file operations, sudo support, and clearer SSH error handling. - - - General helpers to do things like resolve service names to containers, extract secrets and env vars for `exec` commands, and inspect container network details. - - -first design} icon="code"> - Automatically set the equivalent of `DOCKER_HOST`, `COMPOSE_PROJECT_NAME`, `COMPOSE_FILE`, and `COMPOSE_ENV_FILES` from the active sitectl context. - - - -## Why not make kube operators? - -Though isn't designed for massive-scale orchestration, the applications hosted by most LAC-GLAM institutions rarely require more than modest scaling. -The real advantage of is the developer experience. Because the exact same orchestration runs in both development and production - with only minor environmental tweaks - you can reliably mirror production on your local machine. -This provides built-in deployment safety long before your CI pipeline runs a single test. - -We could have spent our resources building Kubernetes operators for various LAC-GLAM stacks instead of creating sitectl. -But sitectl was a deliberate choice: it empowers institutions to adopt open-source projects without the hurdle of hiring a k8s admin or absorbing the heavy operational overhead of a Kubernetes cluster. diff --git a/pkg/config/cmd.go b/pkg/config/cmd.go index d6cfc81..79b52a3 100644 --- a/pkg/config/cmd.go +++ b/pkg/config/cmd.go @@ -9,6 +9,7 @@ import ( "os" "os/exec" "strings" + "sync" "github.com/kballard/go-shellquote" "golang.org/x/crypto/ssh" @@ -16,19 +17,27 @@ import ( ) func (c *Context) RunCommand(cmd *exec.Cmd) (string, error) { - return c.runCommand(cmd, true) + return c.runCommandContext(context.Background(), cmd, true) } func (c *Context) RunQuietCommand(cmd *exec.Cmd) (string, error) { - return c.runCommand(cmd, false) + return c.runCommandContext(context.Background(), cmd, false) } -func (c *Context) runCommand(cmd *exec.Cmd, printOutput bool) (string, error) { - var output string +func (c *Context) RunCommandContext(ctx context.Context, cmd *exec.Cmd) (string, error) { + return c.runCommandContext(ctx, cmd, true) +} + +func (c *Context) RunQuietCommandContext(ctx context.Context, cmd *exec.Cmd) (string, error) { + return c.runCommandContext(ctx, cmd, false) +} + +func (c *Context) runCommandContext(ctx context.Context, cmd *exec.Cmd, printOutput bool) (string, error) { + runCtx, cancel := context.WithCancel(ctx) + defer cancel() + var output strings.Builder if c.DockerHostType == ContextLocal { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - cmd = exec.CommandContext(ctx, cmd.Path, cmd.Args[1:]...) + cmd = exec.CommandContext(runCtx, cmd.Path, cmd.Args[1:]...) cmd.Env = os.Environ() if printOutput { cmd.Stdin = os.Stdin @@ -52,7 +61,8 @@ func (c *Context) runCommand(cmd *exec.Cmd, printOutput bool) (string, error) { if printOutput { fmt.Println(line) } - output = strings.TrimSpace(line) + output.WriteString(line) + output.WriteString("\n") } if err := scanner.Err(); err != nil { slog.Error("Error reading stdout", "err", err) @@ -60,58 +70,64 @@ func (c *Context) runCommand(cmd *exec.Cmd, printOutput bool) (string, error) { if err := cmd.Wait(); err != nil { return "", fmt.Errorf("error waiting for command %s: %v", cmd.String(), err) } - return output, nil + return strings.TrimRight(output.String(), "\n"), nil } sshClient, err := c.DialSSH() if err != nil { return "", fmt.Errorf("error establishing SSH connection: %v", err) } - defer sshClient.Close() remoteCmd := fmt.Sprintf("cd %s && ", shellquote.Join(c.ProjectDir)) - if c.RunSudo { - remoteCmd += "sudo " - } remoteCmd += shellquote.Join(cmd.Args...) slog.Info("Running remote command", "host", c.SSHHostname, "cmd", remoteCmd) session, err := sshClient.NewSession() if err != nil { + _ = sshClient.Close() return "", fmt.Errorf("error creating SSH session: %v", err) } - defer session.Close() - modes := ssh.TerminalModes{ - ssh.ECHO: 0, - ssh.TTY_OP_ISPEED: 14400, - ssh.TTY_OP_OSPEED: 14400, - } - width, height, err := term.GetSize(int(os.Stdin.Fd())) - if err != nil { - width = 80 - height = 40 - } - if err := session.RequestPty("xterm", width, height, modes); err != nil { - return "", fmt.Errorf("error requesting pseudo terminal: %w", err) + // closeOnce ensures session and client are closed exactly once, + // whether by the watchdog goroutine (context cancellation) or by deferred cleanup. + var closeOnce sync.Once + closeResources := func() { + _ = session.Close() + _ = sshClient.Close() } + defer closeOnce.Do(closeResources) + go func() { + <-runCtx.Done() + closeOnce.Do(closeResources) + }() - // set terminal to raw for easier stdin/out/err handling - // between the os and ssh session - if printOutput && term.IsTerminal(int(os.Stdin.Fd())) { - oldState, err := term.MakeRaw(int(os.Stdin.Fd())) + if printOutput { + modes := ssh.TerminalModes{ + ssh.ECHO: 0, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + } + width, height, err := term.GetSize(int(os.Stdin.Fd())) if err != nil { - return "", fmt.Errorf("failed to set terminal to raw mode: %v", err) + width = 80 + height = 40 } - defer func() { - if err := term.Restore(int(os.Stdin.Fd()), oldState); err != nil { - slog.Error("Unable to return terminal to original state.", "err", err) + if err := session.RequestPty("xterm", width, height, modes); err != nil { + return "", fmt.Errorf("error requesting pseudo terminal: %w", err) + } + + if term.IsTerminal(int(os.Stdin.Fd())) { + oldState, err := term.MakeRaw(int(os.Stdin.Fd())) + if err != nil { + return "", fmt.Errorf("failed to set terminal to raw mode: %v", err) } - }() - } + defer func() { + if err := term.Restore(int(os.Stdin.Fd()), oldState); err != nil { + slog.Error("Unable to return terminal to original state.", "err", err) + } + }() + } - // setup some stdout/err pipes so we can capture output - if printOutput { session.Stdin = os.Stdin } stdoutPipe, err := session.StdoutPipe() @@ -138,7 +154,7 @@ func (c *Context) runCommand(cmd *exec.Cmd, printOutput bool) (string, error) { if printOutput { fmt.Print(chunk) } - output = chunk + output.WriteString(chunk) } if err != nil { if err == io.EOF { @@ -152,10 +168,10 @@ func (c *Context) runCommand(cmd *exec.Cmd, printOutput bool) (string, error) { if err = session.Wait(); err != nil { // do not mark error on sigint if exitErr, ok := err.(*ssh.ExitError); ok && exitErr.ExitStatus() == 130 { - return output, nil + return output.String(), nil } return "", fmt.Errorf("error waiting for remote command %q: %v", remoteCmd, err) } - return output, nil + return output.String(), nil } diff --git a/pkg/config/cmd_test.go b/pkg/config/cmd_test.go index 09a8752..57c0a32 100644 --- a/pkg/config/cmd_test.go +++ b/pkg/config/cmd_test.go @@ -19,3 +19,19 @@ func TestRunCommandLocal(t *testing.T) { t.Fatalf("expected output to contain 'hello', got %v", output) } } + +func TestRunCommandRemoteSudoUnsupported(t *testing.T) { + ctx := &Context{ + DockerHostType: ContextRemote, + SSHUser: "deploy", + SSHHostname: "example.org", + } + + _, err := ctx.RunCommand(exec.Command("docker", "ps")) + if err == nil { + t.Fatal("expected remote ssh error") + } + if !strings.Contains(err.Error(), "error establishing SSH connection") { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/pkg/config/context.go b/pkg/config/context.go index d943a47..3382319 100644 --- a/pkg/config/context.go +++ b/pkg/config/context.go @@ -3,8 +3,6 @@ package config import ( "errors" "fmt" - "io" - "log" "log/slog" "net/url" "os" @@ -14,7 +12,6 @@ import ( "strings" "time" - "github.com/pkg/sftp" "github.com/spf13/pflag" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/knownhosts" @@ -46,7 +43,6 @@ type Context struct { SSHKeyPath string `yaml:"ssh-key,omitempty"` EnvFile []string `yaml:"env-file"` ComposeFile []string `yaml:"compose-file,omitempty"` - RunSudo bool `yaml:"sudo"` // Database connection configuration DatabaseService string `yaml:"database-service,omitempty"` @@ -54,7 +50,7 @@ type Context struct { DatabasePasswordSecret string `yaml:"database-password-secret,omitempty"` DatabaseName string `yaml:"database-name,omitempty"` - ReadSmallFileFunc func(filename string) string `yaml:"-"` + ReadSmallFileFunc func(filename string) (string, error) `yaml:"-"` // Extra holds plugin-specific configuration. // Each plugin uses its own key (e.g., "drupal", "isle", "wordpress"). @@ -193,7 +189,7 @@ func ResolveCurrentContextName(f *pflag.FlagSet) (string, error) { return c, nil } -func (c *Context) ReadSmallFile(filename string) string { +func (c *Context) ReadSmallFile(filename string) (string, error) { if c.ReadSmallFileFunc != nil { return c.ReadSmallFileFunc(filename) } @@ -201,41 +197,22 @@ func (c *Context) ReadSmallFile(filename string) string { if c.DockerHostType == ContextLocal { data, err := os.ReadFile(filename) if err != nil { - slog.Error("Error reading file", "file", filename, "err", err) - return "" + return "", fmt.Errorf("read file %q: %w", filename, err) } - - return string(data) - } - client, err := c.DialSSH() - if err != nil { - slog.Error("Error establishing SSH connection", "err", err) - return "" + return string(data), nil } - defer client.Close() - sftpClient, err := sftp.NewClient(client) + accessor, err := c.NewFileAccessor() if err != nil { - slog.Error("Error creating SFTP client", "err", err) - return "" + return "", fmt.Errorf("create file accessor: %w", err) } - defer sftpClient.Close() + defer accessor.Close() - // Use SFTP to read the file securely - remoteFile, err := sftpClient.Open(filename) + data, err := accessor.ReadFile(filename) if err != nil { - slog.Error("Error opening remote file", "file", filename, "err", err) - return "" - } - defer remoteFile.Close() - - data, err := io.ReadAll(remoteFile) - if err != nil { - slog.Error("Error reading remote file", "file", filename, "err", err) - return "" + return "", fmt.Errorf("read remote file %q: %w", filename, err) } - - return string(data) + return string(data), nil } func (c Context) EffectiveComposeProjectName() string { @@ -258,6 +235,9 @@ func (c *Context) DialSSH() (*ssh.Client, error) { // Check if the error is due to encryption (passphrase required) var ppErr *ssh.PassphraseMissingError if errors.As(err, &ppErr) { + if !term.IsTerminal(int(os.Stdin.Fd())) { + return nil, fmt.Errorf("ssh key %s requires a passphrase, but no interactive terminal is available", c.SSHKeyPath) + } // Key is encrypted, prompt for passphrase fmt.Printf("Enter passphrase for SSH key %s: ", c.SSHKeyPath) passphrase, err := term.ReadPassword(int(os.Stdin.Fd())) @@ -276,7 +256,10 @@ func (c *Context) DialSSH() (*ssh.Client, error) { } } - knownHostsPath := filepath.Join(filepath.Dir(c.SSHKeyPath), "known_hosts") + knownHostsPath, err := defaultKnownHostsPath() + if err != nil { + return nil, fmt.Errorf("error resolving known_hosts path: %w", err) + } slog.Debug("Setting known_hosts", "known_hosts", knownHostsPath) hostKeyCallback, err := knownhosts.New(knownHostsPath) if err != nil { @@ -316,6 +299,21 @@ func (c *Context) DialSSH() (*ssh.Client, error) { return client, nil } +func defaultKnownHostsPath() (string, error) { + home := os.Getenv("HOME") + if strings.TrimSpace(home) == "" { + u, err := user.Current() + if err != nil { + return "", err + } + home = u.HomeDir + } + if strings.TrimSpace(home) == "" { + return "", fmt.Errorf("unable to determine user home directory") + } + return filepath.Join(home, ".ssh", "known_hosts"), nil +} + func (c *Context) ProjectDirExists() (bool, error) { if c.DockerHostType == ContextLocal { _, err := os.Stat(c.ProjectDir) @@ -329,21 +327,13 @@ func (c *Context) ProjectDirExists() (bool, error) { return !os.IsNotExist(err), nil } - client, err := c.DialSSH() + accessor, err := c.NewFileAccessor() if err != nil { slog.Error("Error establishing SSH connection", "err", err) return false, err } - defer client.Close() - - sftpClient, err := sftp.NewClient(client) - if err != nil { - slog.Error("Error creating SFTP client", "err", err) - return false, err - } - defer sftpClient.Close() - - _, err = sftpClient.Stat(c.ProjectDir) + defer accessor.Close() + _, err = accessor.Stat(c.ProjectDir) if err != nil { return false, nil } @@ -459,37 +449,13 @@ func (cc *Context) VerifyRemoteInput(existingSite bool) error { } func (c *Context) UploadFile(source, destination string) error { - client, err := c.DialSSH() + accessor, err := c.NewFileAccessor() if err != nil { slog.Error("Error establishing SSH connection", "err", err) return err } - defer client.Close() - - sftpClient, err := sftp.NewClient(client) - if err != nil { - log.Fatal(err) - } - defer sftpClient.Close() - - localFile, err := os.Open(source) - if err != nil { - log.Fatal(err) - } - defer localFile.Close() - - remoteFile, err := sftpClient.Create(destination) - if err != nil { - return err - } - defer remoteFile.Close() - - _, err = remoteFile.ReadFrom(localFile) - if err != nil { - return err - } - - return nil + defer accessor.Close() + return accessor.UploadFile(source, destination) } // GetSshUri returns an SSH connection URI diff --git a/pkg/config/context_test.go b/pkg/config/context_test.go index dbf97a4..cc037b3 100644 --- a/pkg/config/context_test.go +++ b/pkg/config/context_test.go @@ -111,7 +111,6 @@ func contextsEqual(a, b Context) bool { a.SSHKeyPath == b.SSHKeyPath && len(a.EnvFile) == len(b.EnvFile) && len(a.ComposeFile) == len(b.ComposeFile) && - a.RunSudo == b.RunSudo && a.DatabaseService == b.DatabaseService && a.DatabaseUser == b.DatabaseUser && a.DatabasePasswordSecret == b.DatabasePasswordSecret && @@ -291,7 +290,10 @@ func TestReadSmallFileLocal(t *testing.T) { ctx := &Context{ DockerHostType: ContextLocal, } - readContent := ctx.ReadSmallFile(filePath) + readContent, err := ctx.ReadSmallFile(filePath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } if readContent != content { t.Fatalf("expected %q, got %q", content, readContent) } @@ -317,6 +319,21 @@ func TestDialSSHError(t *testing.T) { } } +func TestDefaultKnownHostsPathUsesHomeSSHDirectory(t *testing.T) { + tempHome := t.TempDir() + t.Setenv("HOME", tempHome) + + path, err := defaultKnownHostsPath() + if err != nil { + t.Fatalf("defaultKnownHostsPath() error = %v", err) + } + + want := filepath.Join(tempHome, ".ssh", "known_hosts") + if path != want { + t.Fatalf("defaultKnownHostsPath() = %q, want %q", path, want) + } +} + func TestProjectDirExistsLocal(t *testing.T) { tempHome := t.TempDir() t.Setenv("HOME", tempHome) diff --git a/pkg/config/discovery.go b/pkg/config/discovery.go index fabf1d3..2ecb9d7 100644 --- a/pkg/config/discovery.go +++ b/pkg/config/discovery.go @@ -137,8 +137,8 @@ func readComposeDiscoveryDocForContext(ctx *Context) (composeDiscoveryDoc, bool) if err != nil || !exists { continue } - data := ctx.ReadSmallFile(path) - if strings.TrimSpace(data) == "" { + data, err := ctx.ReadSmallFile(path) + if err != nil || strings.TrimSpace(data) == "" { continue } var doc composeDiscoveryDoc diff --git a/pkg/config/file_accessor.go b/pkg/config/file_accessor.go new file mode 100644 index 0000000..49988d0 --- /dev/null +++ b/pkg/config/file_accessor.go @@ -0,0 +1,337 @@ +package config + +import ( + "context" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +const maxRemoteReadBytes int64 = 4 << 20 +const remoteReadConcurrency = 8 + +type FileAccessor struct { + ctx *Context + ssh *ssh.Client + sftp *sftp.Client + ownsSSH bool +} + +func (c *Context) NewFileAccessor() (*FileAccessor, error) { + return NewFileAccessor(c) +} + +func NewFileAccessor(ctx *Context) (*FileAccessor, error) { + return NewFileAccessorWithSSH(ctx, nil, true) +} + +func NewFileAccessorWithSSH(ctx *Context, sshClient *ssh.Client, ownsSSH bool) (*FileAccessor, error) { + accessor := &FileAccessor{ctx: ctx, ownsSSH: ownsSSH} + if ctx == nil || ctx.DockerHostType == ContextLocal { + return accessor, nil + } + if sshClient == nil { + var err error + sshClient, err = ctx.DialSSH() + if err != nil { + return nil, err + } + } + sftpClient, err := sftp.NewClient(sshClient) + if err != nil { + if ownsSSH { + sshClient.Close() + } + return nil, err + } + accessor.ssh = sshClient + accessor.sftp = sftpClient + return accessor, nil +} + +func (a *FileAccessor) Close() error { + if a == nil { + return nil + } + if a.sftp != nil { + _ = a.sftp.Close() + } + if a.ssh != nil && a.ownsSSH { + return a.ssh.Close() + } + return nil +} + +func (a *FileAccessor) ReadFile(filename string) ([]byte, error) { + return a.ReadFileContext(context.Background(), filename) +} + +func (a *FileAccessor) ReadFileContext(ctx context.Context, filename string) ([]byte, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if a == nil || a.ctx == nil || a.ctx.DockerHostType == ContextLocal { + return os.ReadFile(filename) + } + remoteFile, err := a.sftp.Open(filename) + if err != nil { + return nil, err + } + defer remoteFile.Close() + return readAllLimited(remoteFile, maxRemoteReadBytes) +} + +func (a *FileAccessor) ReadFiles(paths []string) (map[string][]byte, error) { + return a.ReadFilesContext(context.Background(), paths) +} + +func (a *FileAccessor) ReadFilesContext(ctx context.Context, paths []string) (map[string][]byte, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + results := make(map[string][]byte, len(paths)) + missing := make([]string, 0, len(paths)) + + for _, path := range paths { + if strings.TrimSpace(path) == "" { + continue + } + missing = append(missing, path) + } + if len(missing) == 0 { + return results, nil + } + + if a == nil || a.ctx == nil || a.ctx.DockerHostType == ContextLocal { + for _, path := range missing { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + results[path] = data + } + return results, nil + } + + type readResult struct { + path string + data []byte + err error + } + + workers := remoteReadConcurrency + if len(missing) < workers { + workers = len(missing) + } + if workers < 1 { + workers = 1 + } + + jobs := make(chan string, len(missing)) + out := make(chan readResult, len(missing)) + var wg sync.WaitGroup + + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + case path, ok := <-jobs: + if !ok { + return + } + if err := ctx.Err(); err != nil { + out <- readResult{path: path, err: err} + return + } + remoteFile, err := a.sftp.Open(path) + if err != nil { + out <- readResult{path: path, err: err} + cancel() + return + } + data, err := readAllLimited(remoteFile, maxRemoteReadBytes) + remoteFile.Close() + out <- readResult{path: path, data: data, err: err} + if err != nil { + cancel() + return + } + } + } + }() + } + +enqueue: + for _, path := range missing { + if err := ctx.Err(); err != nil { + break + } + select { + case <-ctx.Done(): + break enqueue + case jobs <- path: + } + } + close(jobs) + + go func() { + wg.Wait() + close(out) + }() + + var firstErr error + for result := range out { + if result.err != nil && firstErr == nil { + firstErr = result.err + cancel() + continue + } + if result.err != nil { + continue + } + results[result.path] = result.data + } + if firstErr != nil { + return nil, firstErr + } + return results, nil +} + +func (a *FileAccessor) WriteFile(filename string, data []byte) error { + if a == nil || a.ctx == nil || a.ctx.DockerHostType == ContextLocal { + if err := os.MkdirAll(filepath.Dir(filename), 0o755); err != nil { + return err + } + return os.WriteFile(filename, data, 0o644) + } + if err := mkdirAllRemote(a.sftp, filepath.Dir(filename)); err != nil { + return err + } + remoteFile, err := a.sftp.Create(filename) + if err != nil { + return err + } + defer remoteFile.Close() + _, err = remoteFile.Write(data) + return err +} + +func (a *FileAccessor) RemoveFile(filename string) error { + if a == nil || a.ctx == nil || a.ctx.DockerHostType == ContextLocal { + if err := os.Remove(filename); err != nil && !os.IsNotExist(err) { + return err + } + return nil + } + if err := a.sftp.Remove(filename); err != nil && !isSFTPNotExist(err) { + return err + } + return nil +} + +func (a *FileAccessor) ListFiles(root string) ([]string, error) { + if a == nil || a.ctx == nil || a.ctx.DockerHostType == ContextLocal { + return listLocalFiles(root) + } + walker := a.sftp.Walk(root) + files := []string{} + for walker.Step() { + if err := walker.Err(); err != nil { + return nil, err + } + if walker.Stat().IsDir() { + continue + } + rel, err := filepath.Rel(root, walker.Path()) + if err != nil { + return nil, err + } + files = append(files, filepath.ToSlash(rel)) + } + return files, nil +} + +func (a *FileAccessor) FileExists(path string) (bool, error) { + if a == nil || a.ctx == nil { + return false, nil + } + if strings.TrimSpace(path) == "" { + return false, nil + } + if a.ctx.DockerHostType == ContextLocal { + _, err := os.Stat(path) + if os.IsNotExist(err) { + return false, nil + } + return err == nil, err + } + _, err := a.sftp.Stat(path) + if err != nil { + return false, nil + } + return true, nil +} + +func (a *FileAccessor) Stat(path string) (fs.FileInfo, error) { + if a == nil || a.ctx == nil || a.ctx.DockerHostType == ContextLocal { + return os.Stat(path) + } + return a.sftp.Stat(path) +} + +func (a *FileAccessor) UploadFile(source, destination string) error { + localFile, err := os.Open(source) + if err != nil { + return err + } + defer localFile.Close() + + if a == nil || a.ctx == nil || a.ctx.DockerHostType == ContextLocal { + if err := os.MkdirAll(filepath.Dir(destination), 0o755); err != nil { + return err + } + remoteFile, err := os.Create(destination) + if err != nil { + return err + } + defer remoteFile.Close() + _, err = io.Copy(remoteFile, localFile) + return err + } + + if err := mkdirAllRemote(a.sftp, filepath.Dir(destination)); err != nil { + return err + } + remoteFile, err := a.sftp.Create(destination) + if err != nil { + return err + } + defer remoteFile.Close() + _, err = io.Copy(remoteFile, localFile) + return err +} + +func readAllLimited(r io.Reader, limit int64) ([]byte, error) { + limited := io.LimitReader(r, limit+1) + data, err := io.ReadAll(limited) + if err != nil { + return nil, err + } + if int64(len(data)) > limit { + return nil, fmt.Errorf("remote file exceeds %d bytes", limit) + } + return data, nil +} diff --git a/pkg/config/files.go b/pkg/config/files.go index 078f98a..5510cec 100644 --- a/pkg/config/files.go +++ b/pkg/config/files.go @@ -2,9 +2,7 @@ package config import ( "fmt" - "io" "io/fs" - "os" "path/filepath" "strings" @@ -13,136 +11,55 @@ import ( // ReadFile reads a file from the context, supporting local and remote paths. func (c *Context) ReadFile(filename string) ([]byte, error) { - if c.DockerHostType == ContextLocal { - return os.ReadFile(filename) - } - - client, err := c.DialSSH() - if err != nil { - return nil, fmt.Errorf("dial ssh: %w", err) - } - defer client.Close() - - sftpClient, err := sftp.NewClient(client) - if err != nil { - return nil, fmt.Errorf("create sftp client: %w", err) - } - defer sftpClient.Close() - - remoteFile, err := sftpClient.Open(filename) + accessor, err := c.NewFileAccessor() if err != nil { - return nil, fmt.Errorf("open remote file %q: %w", filename, err) + return nil, fmt.Errorf("create file accessor: %w", err) } - defer remoteFile.Close() - - data, err := io.ReadAll(remoteFile) + defer accessor.Close() + data, err := accessor.ReadFile(filename) if err != nil { - return nil, fmt.Errorf("read remote file %q: %w", filename, err) + return nil, fmt.Errorf("read file %q: %w", filename, err) } - return data, nil } // WriteFile writes a file to the context, creating parent directories as needed. func (c *Context) WriteFile(filename string, data []byte) error { - if c.DockerHostType == ContextLocal { - if err := os.MkdirAll(filepath.Dir(filename), 0o755); err != nil { - return fmt.Errorf("create parent directories for %q: %w", filename, err) - } - return os.WriteFile(filename, data, 0o644) - } - - client, err := c.DialSSH() - if err != nil { - return fmt.Errorf("dial ssh: %w", err) - } - defer client.Close() - - sftpClient, err := sftp.NewClient(client) + accessor, err := c.NewFileAccessor() if err != nil { - return fmt.Errorf("create sftp client: %w", err) + return fmt.Errorf("create file accessor: %w", err) } - defer sftpClient.Close() - - if err := mkdirAllRemote(sftpClient, filepath.Dir(filename)); err != nil { - return err + defer accessor.Close() + if err := accessor.WriteFile(filename, data); err != nil { + return fmt.Errorf("write file %q: %w", filename, err) } - - remoteFile, err := sftpClient.Create(filename) - if err != nil { - return fmt.Errorf("create remote file %q: %w", filename, err) - } - defer remoteFile.Close() - - if _, err := remoteFile.Write(data); err != nil { - return fmt.Errorf("write remote file %q: %w", filename, err) - } - return nil } // RemoveFile removes a file from the context. func (c *Context) RemoveFile(filename string) error { - if c.DockerHostType == ContextLocal { - if err := os.Remove(filename); err != nil && !os.IsNotExist(err) { - return err - } - return nil - } - - client, err := c.DialSSH() - if err != nil { - return fmt.Errorf("dial ssh: %w", err) - } - defer client.Close() - - sftpClient, err := sftp.NewClient(client) + accessor, err := c.NewFileAccessor() if err != nil { - return fmt.Errorf("create sftp client: %w", err) + return fmt.Errorf("create file accessor: %w", err) } - defer sftpClient.Close() - - if err := sftpClient.Remove(filename); err != nil && !isSFTPNotExist(err) { - return fmt.Errorf("remove remote file %q: %w", filename, err) + defer accessor.Close() + if err := accessor.RemoveFile(filename); err != nil { + return fmt.Errorf("remove file %q: %w", filename, err) } - return nil } // ListFiles lists files under a directory relative to the directory root. func (c *Context) ListFiles(root string) ([]string, error) { - if c.DockerHostType == ContextLocal { - return listLocalFiles(root) - } - - client, err := c.DialSSH() + accessor, err := c.NewFileAccessor() if err != nil { - return nil, fmt.Errorf("dial ssh: %w", err) + return nil, fmt.Errorf("create file accessor: %w", err) } - defer client.Close() - - sftpClient, err := sftp.NewClient(client) + defer accessor.Close() + files, err := accessor.ListFiles(root) if err != nil { - return nil, fmt.Errorf("create sftp client: %w", err) + return nil, fmt.Errorf("list files under %q: %w", root, err) } - defer sftpClient.Close() - - walker := sftpClient.Walk(root) - files := []string{} - for walker.Step() { - if err := walker.Err(); err != nil { - return nil, fmt.Errorf("walk remote path %q: %w", walker.Path(), err) - } - if walker.Stat().IsDir() { - continue - } - rel, err := filepath.Rel(root, walker.Path()) - if err != nil { - return nil, fmt.Errorf("get relative path for %q: %w", walker.Path(), err) - } - files = append(files, filepath.ToSlash(rel)) - } - return files, nil } diff --git a/pkg/config/utils.go b/pkg/config/utils.go index cd54ebc..ddb9675 100644 --- a/pkg/config/utils.go +++ b/pkg/config/utils.go @@ -136,7 +136,6 @@ func contextHasStoredValues(context Context) bool { context.SSHKeyPath != "" || len(context.EnvFile) > 0 || len(context.ComposeFile) > 0 || - context.RunSudo || context.DatabaseService != "" || context.DatabaseUser != "" || context.DatabasePasswordSecret != "" || diff --git a/pkg/config/utils_test.go b/pkg/config/utils_test.go index ac18f04..4a70ca8 100644 --- a/pkg/config/utils_test.go +++ b/pkg/config/utils_test.go @@ -106,9 +106,6 @@ func TestLoadFromFlags(t *testing.T) { if ctx.Environment != "staging" { t.Errorf("Expected environment 'staging', got %q", ctx.Environment) } - if ctx.RunSudo != true { - t.Errorf("Expected site 'true', got %t", ctx.RunSudo) - } expectedSlice := []string{".env", "/tmp/.env"} if !reflect.DeepEqual(ctx.EnvFile, expectedSlice) { t.Errorf("expected env-file slice %v but got %v", expectedSlice, ctx.EnvFile) diff --git a/pkg/config/validation_helpers.go b/pkg/config/validation_helpers.go index dfcde8c..178c916 100644 --- a/pkg/config/validation_helpers.go +++ b/pkg/config/validation_helpers.go @@ -8,7 +8,6 @@ import ( "strings" "github.com/kballard/go-shellquote" - "github.com/pkg/sftp" ) func IsDockerSocketAlive(socket string) bool { @@ -31,19 +30,12 @@ func (c *Context) FileExists(path string) (bool, error) { return err == nil, err } - client, err := c.DialSSH() + accessor, err := c.NewFileAccessor() if err != nil { return false, err } - defer client.Close() - - sftpClient, err := sftp.NewClient(client) - if err != nil { - return false, err - } - defer sftpClient.Close() - - _, err = sftpClient.Stat(path) + defer accessor.Close() + _, err = accessor.Stat(path) if err != nil { return false, nil } diff --git a/pkg/docker/docker.go b/pkg/docker/docker.go index a052341..00be6dc 100644 --- a/pkg/docker/docker.go +++ b/pkg/docker/docker.go @@ -32,13 +32,16 @@ type DockerClient struct { CLI DockerAPI SshCli *ssh.Client httpClient *http.Client + ownsSSH bool } func (d *DockerClient) Close() error { var firstErr error if d.SshCli != nil { - if err := d.SshCli.Close(); err != nil && firstErr == nil { - firstErr = err + if d.ownsSSH { + if err := d.SshCli.Close(); err != nil && firstErr == nil { + firstErr = err + } } } if d.httpClient != nil { @@ -62,6 +65,23 @@ func GetDockerCli(activeCtx *config.Context) (*DockerClient, error) { if err != nil { return nil, fmt.Errorf("error establishing SSH connection: %v", err) } + return GetDockerCliWithSSH(activeCtx, sshConn, true) +} + +func GetDockerCliWithSSH(activeCtx *config.Context, sshConn *ssh.Client, ownsSSH bool) (*DockerClient, error) { + if activeCtx.DockerHostType == config.ContextLocal { + cli, err := client.NewClientWithOpts( + client.WithHost("unix://"+activeCtx.DockerSocket), + client.WithAPIVersionNegotiation(), + ) + if err != nil { + return nil, fmt.Errorf("error creating local Docker client: %v", err) + } + return &DockerClient{CLI: cli}, nil + } + if sshConn == nil { + return nil, fmt.Errorf("ssh client is required for remote docker context") + } transport := &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return sshConn.Dial("unix", activeCtx.DockerSocket) @@ -76,13 +96,16 @@ func GetDockerCli(activeCtx *config.Context) (*DockerClient, error) { client.WithAPIVersionNegotiation(), ) if err != nil { - sshConn.Close() + if ownsSSH { + sshConn.Close() + } return nil, fmt.Errorf("error creating Docker client over SSH: %v", err) } return &DockerClient{ CLI: cli, SshCli: sshConn, httpClient: httpClient, + ownsSSH: ownsSSH, }, nil } @@ -95,7 +118,11 @@ func GetSecret(ctx context.Context, cli DockerAPI, c *config.Context, containerN for _, mount := range containerJSON.Mounts { if mount.Destination == expectedTarget { secretFilePath := filepath.Join(c.ProjectDir, "secrets", secretName) - return c.ReadSmallFile(secretFilePath), nil + secret, err := c.ReadSmallFile(secretFilePath) + if err != nil { + return "", fmt.Errorf("read secret %q: %w", secretName, err) + } + return secret, nil } } return GetConfigEnv(ctx, cli, containerName, secretName) @@ -138,8 +165,10 @@ func (d *DockerClient) GetServiceIp(ctx context.Context, c *config.Context, cont } func (d *DockerClient) GetContainerName(c *config.Context, service string) (string, error) { - ctx := context.Background() + return d.GetContainerNameContext(context.Background(), c, service) +} +func (d *DockerClient) GetContainerNameContext(ctx context.Context, c *config.Context, service string) (string, error) { // Define the filters based on the Docker Compose labels. filterArgs := filters.NewArgs() filterArgs.Add("label", "com.docker.compose.project="+c.EffectiveComposeProjectName()) @@ -315,7 +344,7 @@ func getDatabaseURIsWithClient(ctx context.Context, dockerCli *DockerClient, c * dbHost := "127.0.0.1" // Get the database container name - containerName, err := dockerCli.GetContainerName(c, c.DatabaseService) + containerName, err := dockerCli.GetContainerNameContext(ctx, c, c.DatabaseService) if err != nil { return "", "", fmt.Errorf("failed to get %s container: %w", c.DatabaseService, err) } diff --git a/pkg/docker/docker_test.go b/pkg/docker/docker_test.go index e66580f..20caba5 100644 --- a/pkg/docker/docker_test.go +++ b/pkg/docker/docker_test.go @@ -127,11 +127,11 @@ func TestGetSecret_MountedSecret(t *testing.T) { fakeConfig := &config.Context{ ProjectDir: "/tmp/project", ProjectName: "test", - ReadSmallFileFunc: func(path string) string { + ReadSmallFileFunc: func(path string) (string, error) { if strings.HasSuffix(path, filepath.Join("secrets", "secretName")) { - return "fileSecret" + return "fileSecret", nil } - return "" + return "", nil }, } secret, err := GetSecret(context.Background(), fake, fakeConfig, "dummyContainer", "secretName") diff --git a/pkg/docker/summary.go b/pkg/docker/summary.go index c47d634..68bcc6b 100644 --- a/pkg/docker/summary.go +++ b/pkg/docker/summary.go @@ -12,9 +12,7 @@ import ( dockercontainer "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/filters" - "github.com/kballard/go-shellquote" "github.com/libops/sitectl/pkg/config" - "golang.org/x/crypto/ssh" ) type ServiceSummary struct { @@ -136,33 +134,7 @@ func runComposePS(ctxCfg *config.Context) (string, error) { output, err := cmd.CombinedOutput() return string(output), err } - - remoteCmd := fmt.Sprintf("cd %s && ", shellquote.Join(ctxCfg.ProjectDir)) - if ctxCfg.RunSudo { - remoteCmd += "sudo " - } - remoteCmd += shellquote.Join(append([]string{"docker"}, args...)...) - - client, err := ctxCfg.DialSSH() - if err != nil { - return "", err - } - defer client.Close() - - session, err := client.NewSession() - if err != nil { - return "", err - } - defer session.Close() - - output, err := session.CombinedOutput(remoteCmd) - if err != nil { - if _, ok := err.(*ssh.ExitError); ok && len(output) > 0 { - return string(output), nil - } - return string(output), err - } - return string(output), nil + return ctxCfg.RunQuietCommand(exec.Command("docker", args...)) } func runDockerStats(ctxCfg *config.Context) (string, error) { @@ -173,33 +145,7 @@ func runDockerStats(ctxCfg *config.Context) (string, error) { output, err := cmd.CombinedOutput() return string(output), err } - - remoteCmd := fmt.Sprintf("cd %s && ", shellquote.Join(ctxCfg.ProjectDir)) - if ctxCfg.RunSudo { - remoteCmd += "sudo " - } - remoteCmd += shellquote.Join(append([]string{"docker"}, args...)...) - - client, err := ctxCfg.DialSSH() - if err != nil { - return "", err - } - defer client.Close() - - session, err := client.NewSession() - if err != nil { - return "", err - } - defer session.Close() - - output, err := session.CombinedOutput(remoteCmd) - if err != nil { - if _, ok := err.(*ssh.ExitError); ok && len(output) > 0 { - return string(output), nil - } - return string(output), err - } - return string(output), nil + return ctxCfg.RunQuietCommand(exec.Command("docker", args...)) } func runHostMetrics(ctxCfg *config.Context) (string, error) { @@ -259,33 +205,7 @@ printf '{"load1":"%s","cpu_count":"%s","disk_total_kb":"%s","disk_avail_kb":"%s" output, err := cmd.CombinedOutput() return string(output), err } - - client, err := ctxCfg.DialSSH() - if err != nil { - return "", err - } - defer client.Close() - - session, err := client.NewSession() - if err != nil { - return "", err - } - defer session.Close() - - remoteCmd := fmt.Sprintf("cd %s && ", shellquote.Join(ctxCfg.ProjectDir)) - if ctxCfg.RunSudo { - remoteCmd += "sudo " - } - remoteCmd += shellquote.Join("sh", "-lc", script) - - output, err := session.CombinedOutput(remoteCmd) - if err != nil { - if _, ok := err.(*ssh.ExitError); ok && len(output) > 0 { - return string(output), nil - } - return string(output), err - } - return string(output), nil + return ctxCfg.RunQuietCommand(exec.Command("sh", "-lc", script)) } func composePSArgs(ctxCfg config.Context) []string { diff --git a/pkg/plugin/files.go b/pkg/plugin/files.go new file mode 100644 index 0000000..df360b0 --- /dev/null +++ b/pkg/plugin/files.go @@ -0,0 +1,439 @@ +package plugin + +import ( + "context" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "sort" + "sync" + + "github.com/libops/sitectl/pkg/config" + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +const maxRemoteReadBytes int64 = 4 << 20 +const remoteReadConcurrency = 8 + +type FileAccessor struct { + ctx *config.Context + ssh *ssh.Client + sftp *sftp.Client + ownsSSH bool + mu sync.Mutex + readFileCache map[string][]byte + readDirCache map[string][]os.FileInfo + listFilesCache map[string][]string +} + +func (s *SDK) GetFileAccessor() (*FileAccessor, error) { + ctx, err := s.GetContext() + if err != nil { + return nil, err + } + if ctx == nil || ctx.DockerHostType == config.ContextLocal { + return NewFileAccessor(ctx) + } + sshClient, err := s.getSSHClient() + if err != nil { + return nil, err + } + return NewFileAccessorWithSSH(ctx, sshClient, false) +} + +func NewFileAccessor(ctx *config.Context) (*FileAccessor, error) { + return NewFileAccessorWithSSH(ctx, nil, true) +} + +func NewFileAccessorWithSSH(ctx *config.Context, client *ssh.Client, ownsSSH bool) (*FileAccessor, error) { + accessor := &FileAccessor{ + ctx: ctx, + ownsSSH: ownsSSH, + readFileCache: map[string][]byte{}, + readDirCache: map[string][]os.FileInfo{}, + listFilesCache: map[string][]string{}, + } + if ctx == nil || ctx.DockerHostType == config.ContextLocal { + return accessor, nil + } + if client == nil { + var err error + client, err = ctx.DialSSH() + if err != nil { + return nil, err + } + } + sftpClient, err := sftp.NewClient(client) + if err != nil { + if ownsSSH { + client.Close() + } + return nil, err + } + accessor.ssh = client + accessor.sftp = sftpClient + return accessor, nil +} + +func (a *FileAccessor) Close() error { + if a == nil { + return nil + } + if a.sftp != nil { + _ = a.sftp.Close() + } + if a.ssh != nil && a.ownsSSH { + return a.ssh.Close() + } + return nil +} + +func (a *FileAccessor) ReadFile(path string) ([]byte, error) { + return a.ReadFileContext(context.Background(), path) +} + +func (a *FileAccessor) ReadFileContext(ctx context.Context, path string) ([]byte, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if data, ok := a.cachedFile(path); ok { + return data, nil + } + + var data []byte + var err error + if a == nil || a.ctx == nil || a.ctx.DockerHostType == config.ContextLocal { + data, err = os.ReadFile(path) + } else { + file, openErr := a.sftp.Open(path) + if openErr != nil { + return nil, openErr + } + defer file.Close() + data, err = readAllLimited(file, maxRemoteReadBytes) + } + if err != nil { + return nil, err + } + a.storeFile(path, data) + return append([]byte(nil), data...), nil +} + +func (a *FileAccessor) ReadFiles(paths []string) (map[string][]byte, error) { + return a.ReadFilesContext(context.Background(), paths) +} + +func (a *FileAccessor) ReadFilesContext(ctx context.Context, paths []string) (map[string][]byte, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + results := make(map[string][]byte, len(paths)) + missing := make([]string, 0, len(paths)) + + for _, path := range paths { + if path == "" { + continue + } + if data, ok := a.cachedFile(path); ok { + results[path] = data + continue + } + missing = append(missing, path) + } + + if len(missing) == 0 { + return results, nil + } + + if a == nil || a.ctx == nil || a.ctx.DockerHostType == config.ContextLocal { + for _, path := range missing { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + a.storeFile(path, data) + results[path] = append([]byte(nil), data...) + } + return results, nil + } + + type readResult struct { + path string + data []byte + err error + } + + workers := remoteReadConcurrency + if len(missing) < workers { + workers = len(missing) + } + if workers < 1 { + workers = 1 + } + + jobs := make(chan string, len(missing)) + out := make(chan readResult, len(missing)) + var wg sync.WaitGroup + + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + case path, ok := <-jobs: + if !ok { + return + } + if err := ctx.Err(); err != nil { + out <- readResult{path: path, err: err} + return + } + file, err := a.sftp.Open(path) + if err != nil { + out <- readResult{path: path, err: err} + cancel() + return + } + data, err := readAllLimited(file, maxRemoteReadBytes) + file.Close() + out <- readResult{path: path, data: data, err: err} + if err != nil { + cancel() + return + } + } + } + }() + } + +enqueue: + for _, path := range missing { + if err := ctx.Err(); err != nil { + break + } + select { + case <-ctx.Done(): + break enqueue + case jobs <- path: + } + } + close(jobs) + + go func() { + wg.Wait() + close(out) + }() + + var firstErr error + for result := range out { + if result.err != nil && firstErr == nil { + firstErr = result.err + cancel() + continue + } + if result.err != nil { + continue + } + a.storeFile(result.path, result.data) + results[result.path] = append([]byte(nil), result.data...) + } + if firstErr != nil { + return nil, firstErr + } + return results, nil +} + +func (a *FileAccessor) ListFiles(root string) ([]string, error) { + if files, ok := a.cachedList(root); ok { + return files, nil + } + + if a == nil || a.ctx == nil || a.ctx.DockerHostType == config.ContextLocal { + files := []string{} + err := filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + return nil + } + rel, err := filepath.Rel(root, path) + if err != nil { + return err + } + files = append(files, filepath.ToSlash(rel)) + return nil + }) + if err == nil { + a.storeList(root, files) + } + return files, err + } + files := []string{} + walker := a.sftp.Walk(root) + for walker.Step() { + if err := walker.Err(); err != nil { + return nil, err + } + if walker.Stat().IsDir() { + continue + } + rel, err := filepath.Rel(root, walker.Path()) + if err != nil { + return nil, err + } + files = append(files, filepath.ToSlash(rel)) + } + a.storeList(root, files) + return files, nil +} + +func (a *FileAccessor) MatchFiles(root, pattern string) ([]string, error) { + files, err := a.ListFiles(root) + if err != nil { + return nil, err + } + matches := []string{} + for _, rel := range files { + ok, err := filepath.Match(pattern, filepath.Base(rel)) + if err != nil { + return nil, err + } + if ok { + matches = append(matches, filepath.Join(root, rel)) + } + } + sort.Strings(matches) + return matches, nil +} + +func (a *FileAccessor) MatchFilesInDir(root, pattern string) ([]string, error) { + matches := []string{} + + entries, err := a.readDir(root) + if err != nil { + return nil, err + } + for _, entry := range entries { + if entry.IsDir() { + continue + } + ok, err := filepath.Match(pattern, entry.Name()) + if err != nil { + return nil, err + } + if ok { + matches = append(matches, filepath.Join(root, entry.Name())) + } + } + sort.Strings(matches) + return matches, nil +} + +func (a *FileAccessor) readDir(root string) ([]os.FileInfo, error) { + if a == nil { + return nil, os.ErrInvalid + } + a.mu.Lock() + if entries, ok := a.readDirCache[root]; ok { + a.mu.Unlock() + return entries, nil + } + a.mu.Unlock() + + var entries []os.FileInfo + if a.ctx == nil || a.ctx.DockerHostType == config.ContextLocal { + dirEntries, err := os.ReadDir(root) + if err != nil { + return nil, err + } + fileInfos := make([]os.FileInfo, 0, len(dirEntries)) + for _, entry := range dirEntries { + info, infoErr := entry.Info() + if infoErr != nil { + return nil, infoErr + } + fileInfos = append(fileInfos, info) + } + entries = fileInfos + } else { + var err error + entries, err = a.sftp.ReadDir(root) + if err != nil { + return nil, err + } + } + + a.mu.Lock() + a.readDirCache[root] = entries + a.mu.Unlock() + return entries, nil +} + +func (a *FileAccessor) cachedFile(path string) ([]byte, bool) { + if a == nil { + return nil, false + } + a.mu.Lock() + defer a.mu.Unlock() + data, ok := a.readFileCache[path] + if !ok { + return nil, false + } + return append([]byte(nil), data...), true +} + +func (a *FileAccessor) storeFile(path string, data []byte) { + if a == nil { + return + } + a.mu.Lock() + a.readFileCache[path] = append([]byte(nil), data...) + a.mu.Unlock() +} + +func (a *FileAccessor) cachedList(root string) ([]string, bool) { + if a == nil { + return nil, false + } + a.mu.Lock() + defer a.mu.Unlock() + files, ok := a.listFilesCache[root] + if !ok { + return nil, false + } + out := make([]string, len(files)) + copy(out, files) + return out, true +} + +func (a *FileAccessor) storeList(root string, files []string) { + if a == nil { + return + } + out := make([]string, len(files)) + copy(out, files) + a.mu.Lock() + a.listFilesCache[root] = out + a.mu.Unlock() +} + +func readAllLimited(r io.Reader, limit int64) ([]byte, error) { + limited := io.LimitReader(r, limit+1) + data, err := io.ReadAll(limited) + if err != nil { + return nil, err + } + if int64(len(data)) > limit { + return nil, fmt.Errorf("remote file exceeds %d bytes", limit) + } + return data, nil +} diff --git a/pkg/plugin/files_test.go b/pkg/plugin/files_test.go new file mode 100644 index 0000000..e997239 --- /dev/null +++ b/pkg/plugin/files_test.go @@ -0,0 +1,107 @@ +package plugin + +import ( + "os" + "path/filepath" + "slices" + "testing" + + "github.com/libops/sitectl/pkg/config" +) + +func TestNewFileAccessorLocalReadListAndMatch(t *testing.T) { + root := t.TempDir() + + if err := os.MkdirAll(filepath.Join(root, "nested"), 0o755); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + if err := os.WriteFile(filepath.Join(root, "alpha.yml"), []byte("alpha"), 0o644); err != nil { + t.Fatalf("WriteFile(alpha) error = %v", err) + } + if err := os.WriteFile(filepath.Join(root, "nested", "beta.yml"), []byte("beta"), 0o644); err != nil { + t.Fatalf("WriteFile(beta) error = %v", err) + } + if err := os.WriteFile(filepath.Join(root, "nested", "gamma.txt"), []byte("gamma"), 0o644); err != nil { + t.Fatalf("WriteFile(gamma) error = %v", err) + } + + ctx := &config.Context{DockerHostType: config.ContextLocal} + accessor, err := NewFileAccessor(ctx) + if err != nil { + t.Fatalf("NewFileAccessor() error = %v", err) + } + defer accessor.Close() + + got, err := accessor.ReadFile(filepath.Join(root, "alpha.yml")) + if err != nil { + t.Fatalf("ReadFile() error = %v", err) + } + if string(got) != "alpha" { + t.Fatalf("expected alpha content, got %q", string(got)) + } + + files, err := accessor.ListFiles(root) + if err != nil { + t.Fatalf("ListFiles() error = %v", err) + } + wantFiles := []string{"alpha.yml", "nested/beta.yml", "nested/gamma.txt"} + if !slices.Equal(files, wantFiles) { + t.Fatalf("unexpected files: got %v want %v", files, wantFiles) + } + + matches, err := accessor.MatchFiles(root, "*.yml") + if err != nil { + t.Fatalf("MatchFiles() error = %v", err) + } + wantMatches := []string{ + filepath.Join(root, "alpha.yml"), + filepath.Join(root, "nested", "beta.yml"), + } + if !slices.Equal(matches, wantMatches) { + t.Fatalf("unexpected matches: got %v want %v", matches, wantMatches) + } + + flatMatches, err := accessor.MatchFilesInDir(root, "*.yml") + if err != nil { + t.Fatalf("MatchFilesInDir() error = %v", err) + } + wantFlatMatches := []string{ + filepath.Join(root, "alpha.yml"), + } + if !slices.Equal(flatMatches, wantFlatMatches) { + t.Fatalf("unexpected flat matches: got %v want %v", flatMatches, wantFlatMatches) + } +} + +func TestSDKGetFileAccessorUsesResolvedContext(t *testing.T) { + tempHome := t.TempDir() + t.Setenv("HOME", tempHome) + + ctx := config.Context{ + Name: "museum", + Site: "museum", + Plugin: "isle", + DockerHostType: config.ContextLocal, + DockerSocket: "/var/run/docker.sock", + ProjectDir: tempHome, + } + if err := config.SaveContext(&ctx, true); err != nil { + t.Fatalf("SaveContext() error = %v", err) + } + + sdk := NewSDK(Metadata{Name: "drupal"}) + sdk.Config.Context = "museum" + + accessor, err := sdk.GetFileAccessor() + if err != nil { + t.Fatalf("GetFileAccessor() error = %v", err) + } + defer accessor.Close() + + if accessor.ctx == nil { + t.Fatal("expected accessor context to be set") + } + if accessor.ctx.Name != "museum" { + t.Fatalf("unexpected accessor context %q", accessor.ctx.Name) + } +} diff --git a/pkg/plugin/sdk.go b/pkg/plugin/sdk.go index 108f4b1..901799e 100644 --- a/pkg/plugin/sdk.go +++ b/pkg/plugin/sdk.go @@ -8,8 +8,10 @@ import ( "log/slog" "os" "os/exec" + "os/signal" "strconv" "strings" + "syscall" "charm.land/fang/v2" "github.com/libops/sitectl/pkg/component" @@ -18,6 +20,7 @@ import ( "github.com/libops/sitectl/pkg/helpers" "github.com/libops/sitectl/pkg/validate" "github.com/spf13/cobra" + "golang.org/x/crypto/ssh" "golang.org/x/term" ) @@ -49,6 +52,8 @@ type SDK struct { Config Config RootCmd *cobra.Command contextValidators []validate.Validator + contextCache *config.Context + sshClient *ssh.Client } // NewSDK creates a new plugin SDK instance @@ -65,6 +70,9 @@ func NewSDK(metadata Metadata) *SDK { PersistentPreRunE: func(cmd *cobra.Command, args []string) error { return sdk.setupLogging(cmd) }, + PersistentPostRun: func(cmd *cobra.Command, args []string) { + sdk.Close() + }, Annotations: map[string]string{ cobra.CommandDisplayNameAnnotation: fmt.Sprintf("sitectl %s", metadata.Name), }, @@ -94,17 +102,13 @@ func (s *SDK) setupLogging(cmd *cobra.Command) error { opts := &slog.HandlerOptions{ Level: level, } - handler := slog.New(slog.NewTextHandler(os.Stdout, opts)) + handler := slog.New(slog.NewTextHandler(os.Stderr, opts)) slog.SetDefault(handler) - // Store config for plugin use + // Store config for plugin use. s.Config.LogLevel = ll if s.RootCmd.PersistentFlags().Lookup("context") != nil { - if cmd.Flags().Changed("context") { - s.Config.Context, _ = cmd.Flags().GetString("context") - } else { - s.Config.Context = "" - } + s.Config.Context, _ = cmd.Flags().GetString("context") } return nil @@ -132,8 +136,14 @@ func (s *SDK) AddCommand(cmd *cobra.Command) { // Execute runs the plugin func (s *SDK) Execute() { + runCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + go func() { + <-runCtx.Done() + _ = s.Close() + }() if err := fang.Execute( - context.Background(), + runCtx, s.RootCmd, fang.WithVersion(s.RootCmd.Version), ); err != nil { @@ -179,14 +189,27 @@ func (s *SDK) GetDockerClient() (*docker.DockerClient, error) { if err != nil { return nil, fmt.Errorf("failed to get context: %w", err) } + if ctx.DockerHostType == config.ContextLocal { + return docker.GetDockerCli(ctx) + } + sshClient, err := s.getSSHClient() + if err != nil { + return nil, err + } + return docker.GetDockerCliWithSSH(ctx, sshClient, false) +} - return docker.GetDockerCli(ctx) +func (s *SDK) GetSSHClient() (*ssh.Client, error) { + return s.getSSHClient() } // GetContext loads the sitectl context configuration // This is useful for plugins that need to access context-specific settings // If no context is specified, returns the current context from config func (s *SDK) GetContext() (*config.Context, error) { + if s.contextCache != nil { + return s.contextCache, nil + } // Load the config cfg, err := config.Load() if err != nil { @@ -212,13 +235,44 @@ func (s *SDK) GetContext() (*config.Context, error) { if err := validateContextPlugin(ctx.Plugin, s.Metadata.Name); err != nil { return nil, fmt.Errorf("context %q is not supported by plugin %q: %w", ctx.Name, s.Metadata.Name, err) } - return &ctx, nil + s.contextCache = &ctx + return s.contextCache, nil } } return nil, fmt.Errorf("context %q not found", contextName) } +func (s *SDK) getSSHClient() (*ssh.Client, error) { + if s.sshClient != nil { + return s.sshClient, nil + } + ctx, err := s.GetContext() + if err != nil { + return nil, err + } + if ctx == nil || ctx.DockerHostType == config.ContextLocal { + return nil, nil + } + s.sshClient, err = ctx.DialSSH() + if err != nil { + return nil, err + } + return s.sshClient, nil +} + +func (s *SDK) Close() error { + if s == nil { + return nil + } + if s.sshClient != nil { + err := s.sshClient.Close() + s.sshClient = nil + return err + } + return nil +} + func validateContextPlugin(contextPlugin, requestedPlugin string) error { contextPlugin = strings.TrimSpace(contextPlugin) requestedPlugin = strings.TrimSpace(requestedPlugin) @@ -312,35 +366,48 @@ func (s *SDK) PromptAndSaveLocalContext(opts config.LocalContextCreateOptions) ( return config.PromptAndSaveLocalContext(opts) } -// ExecInContainer executes a command in a Docker container -// This is a convenience wrapper for plugins -func (s *SDK) ExecInContainer(ctx context.Context, containerID string, cmd []string) (int, error) { +// ExecContainer executes a command in a Docker container using the shared SDK Docker path. +func (s *SDK) ExecContainer(ctx context.Context, opts docker.ExecOptions) (int, error) { cli, err := s.GetDockerClient() if err != nil { return -1, fmt.Errorf("failed to create Docker client: %w", err) } defer cli.Close() - return cli.ExecSimple(ctx, containerID, cmd) + return cli.Exec(ctx, opts) } -// ExecInContainerInteractive executes an interactive command in a Docker container with TTY -// This is a convenience wrapper for plugins -func (s *SDK) ExecInContainerInteractive(ctx context.Context, containerID string, cmd []string) (int, error) { - cli, err := s.GetDockerClient() - if err != nil { - return -1, fmt.Errorf("failed to create Docker client: %w", err) - } - defer cli.Close() +// ExecInContainer executes a command in a Docker container. +// This is a convenience wrapper for plugins. +func (s *SDK) ExecInContainer(ctx context.Context, containerID string, cmd []string) (int, error) { + return s.ExecContainer(ctx, docker.ExecOptions{ + Container: containerID, + Cmd: cmd, + AttachStdout: true, + AttachStderr: true, + }) +} - return cli.ExecInteractive(ctx, containerID, cmd) +// ExecInContainerInteractive executes an interactive command in a Docker container with TTY. +// This is a convenience wrapper for plugins. +func (s *SDK) ExecInContainerInteractive(ctx context.Context, containerID string, cmd []string) (int, error) { + return s.ExecContainer(ctx, docker.ExecOptions{ + Container: containerID, + Cmd: cmd, + AttachStdin: true, + AttachStdout: true, + AttachStderr: true, + Tty: true, + }) } type CommandExecOptions struct { - Stdin io.Reader - Stdout io.Writer - Stderr io.Writer - Capture bool + Context context.Context + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer + Capture bool + LiveStderr bool } func (s *SDK) InvokePluginCommand(pluginName string, args []string, opts CommandExecOptions) (string, error) { @@ -357,8 +424,13 @@ func (s *SDK) InvokePluginCommand(pluginName string, args []string, opts Command invocation = append(invocation, "--log-level", s.Config.LogLevel) } invocation = append(invocation, args...) + slog.Debug("invoking plugin command", "plugin", pluginName, "path", installed.Path, "args", invocation, "capture", opts.Capture) - cmd := exec.Command(installed.Path, invocation...) + execCtx := opts.Context + if execCtx == nil { + execCtx = context.Background() + } + cmd := exec.CommandContext(execCtx, installed.Path, invocation...) cmd.Env = os.Environ() if width, ok := terminalColumns(); ok { cmd.Env = append(cmd.Env, fmt.Sprintf("COLUMNS=%d", width)) @@ -370,7 +442,15 @@ func (s *SDK) InvokePluginCommand(pluginName string, args []string, opts Command var stdout bytes.Buffer var stderr bytes.Buffer cmd.Stdout = &stdout - cmd.Stderr = &stderr + var stderrSink io.Writer + if opts.Stderr != nil && opts.LiveStderr { + stderrSink = io.MultiWriter(opts.Stderr, &stderr) + } else if opts.LiveStderr { + stderrSink = io.MultiWriter(os.Stderr, &stderr) + } else { + stderrSink = &stderr + } + cmd.Stderr = stderrSink if err := cmd.Run(); err != nil { detail := strings.TrimSpace(stderr.String()) if detail == "" { @@ -381,6 +461,7 @@ func (s *SDK) InvokePluginCommand(pluginName string, args []string, opts Command } return "", fmt.Errorf("run plugin %q: %w", pluginName, err) } + slog.Debug("plugin command completed", "plugin", pluginName, "path", installed.Path) return stdout.String(), nil } @@ -396,6 +477,7 @@ func (s *SDK) InvokePluginCommand(pluginName string, args []string, opts Command if err := cmd.Run(); err != nil { return "", fmt.Errorf("run plugin %q: %w", pluginName, err) } + slog.Debug("plugin command completed", "plugin", pluginName, "path", installed.Path) return "", nil } diff --git a/pkg/plugin/sdk_test.go b/pkg/plugin/sdk_test.go index e43225a..b7099e7 100644 --- a/pkg/plugin/sdk_test.go +++ b/pkg/plugin/sdk_test.go @@ -1,9 +1,15 @@ package plugin import ( + "bytes" + "os" + "path/filepath" + "reflect" + "strings" "testing" "github.com/libops/sitectl/pkg/config" + "github.com/libops/sitectl/pkg/validate" ) func TestGetContextAllowsIncludedPlugin(t *testing.T) { @@ -70,3 +76,188 @@ func TestContextPluginSupportsBuiltinHierarchy(t *testing.T) { t.Fatal("did not expect drupal contexts to support isle") } } + +func TestPluginIncludesMergesBuiltinAndInstalledWithoutDuplicates(t *testing.T) { + dir := t.TempDir() + t.Setenv("PATH", dir) + + script := `#!/bin/sh +if [ "$1" = "plugin-info" ]; then + echo "Name: isle" + echo "Includes: drupal,libops" + exit 0 +fi +if [ "$1" = "create" ] && [ "$2" = "--help" ]; then + exit 0 +fi +exit 1 +` + writePluginScript(t, dir, "sitectl-isle", script) + + got := pluginIncludes("isle") + want := []string{"drupal", "libops"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("pluginIncludes() = %v, want %v", got, want) + } +} + +func TestInvokePluginCommandCapturePassesContextAndLogLevel(t *testing.T) { + dir := t.TempDir() + t.Setenv("PATH", dir) + t.Setenv("COLUMNS", "123") + + script := `#!/bin/sh +if [ "$1" = "plugin-info" ]; then + echo "Name: child" + exit 0 +fi +if [ "$1" = "create" ] && [ "$2" = "--help" ]; then + exit 1 +fi +printf 'ARGS=%s\n' "$*" +printf 'COLUMNS=%s\n' "$COLUMNS" +` + writePluginScript(t, dir, "sitectl-child", script) + + sdk := NewSDK(Metadata{Name: "isle"}) + sdk.Config.Context = "demo" + sdk.Config.LogLevel = "DEBUG" + + out, err := sdk.InvokePluginCommand("child", []string{"__debug", "--verbose"}, CommandExecOptions{Capture: true}) + if err != nil { + t.Fatalf("InvokePluginCommand() error = %v", err) + } + if !strings.Contains(out, "ARGS=--context demo --log-level DEBUG __debug --verbose") { + t.Fatalf("expected context/log-level args in output, got %q", out) + } + if !strings.Contains(out, "COLUMNS=123") { + t.Fatalf("expected COLUMNS env in output, got %q", out) + } +} + +func TestInvokePluginCommandCaptureReturnsStderrDetail(t *testing.T) { + dir := t.TempDir() + t.Setenv("PATH", dir) + + script := `#!/bin/sh +if [ "$1" = "plugin-info" ]; then + echo "Name: broken" + exit 0 +fi +echo "something went wrong" >&2 +exit 2 +` + writePluginScript(t, dir, "sitectl-broken", script) + + sdk := NewSDK(Metadata{Name: "isle"}) + _, err := sdk.InvokePluginCommand("broken", []string{"__debug"}, CommandExecOptions{Capture: true}) + if err == nil { + t.Fatal("expected InvokePluginCommand() error") + } + if !strings.Contains(err.Error(), "something went wrong") { + t.Fatalf("expected stderr detail in error, got %v", err) + } +} + +func TestInvokePluginCommandCaptureCanMirrorLiveStderr(t *testing.T) { + dir := t.TempDir() + t.Setenv("PATH", dir) + + script := `#!/bin/sh +if [ "$1" = "plugin-info" ]; then + echo "Name: noisy" + exit 0 +fi +echo "visible stderr" >&2 +echo "stdout payload" +` + writePluginScript(t, dir, "sitectl-noisy", script) + + sdk := NewSDK(Metadata{Name: "isle"}) + var stderr bytes.Buffer + out, err := sdk.InvokePluginCommand("noisy", []string{"__debug"}, CommandExecOptions{ + Capture: true, + LiveStderr: true, + Stderr: &stderr, + }) + if err != nil { + t.Fatalf("InvokePluginCommand() error = %v", err) + } + if !strings.Contains(stderr.String(), "visible stderr") { + t.Fatalf("expected mirrored stderr, got %q", stderr.String()) + } + if !strings.Contains(out, "stdout payload") { + t.Fatalf("expected stdout payload, got %q", out) + } +} + +func TestInvokeIncludedPluginCommandRejectsUnincludedPlugin(t *testing.T) { + sdk := NewSDK(Metadata{Name: "isle", Includes: []string{"drupal"}}) + + _, err := sdk.InvokeIncludedPluginCommand("libops", []string{"__debug"}, CommandExecOptions{Capture: true}) + if err == nil { + t.Fatal("expected included plugin validation error") + } + if !strings.Contains(err.Error(), `is not included by "isle"`) { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestInvokeIncludedPluginsCollectsTrimmedOutputs(t *testing.T) { + dir := t.TempDir() + t.Setenv("PATH", dir) + + writePluginScript(t, dir, "sitectl-drupal", `#!/bin/sh +if [ "$1" = "plugin-info" ]; then + echo "Name: drupal" + exit 0 +fi +echo " drupal output " +`) + writePluginScript(t, dir, "sitectl-libops", `#!/bin/sh +if [ "$1" = "plugin-info" ]; then + echo "Name: libops" + exit 0 +fi +echo "" +`) + + sdk := NewSDK(Metadata{Name: "isle", Includes: []string{"drupal", "libops"}}) + outputs, err := sdk.InvokeIncludedPlugins([]string{"__debug"}) + if err != nil { + t.Fatalf("InvokeIncludedPlugins() error = %v", err) + } + want := []string{"drupal output"} + if !reflect.DeepEqual(outputs, want) { + t.Fatalf("InvokeIncludedPlugins() = %v, want %v", outputs, want) + } +} + +func TestContextValidatorsReturnsCopy(t *testing.T) { + sdk := NewSDK(Metadata{Name: "isle"}) + first := validate.Validator(func(*config.Context) ([]validate.Result, error) { return nil, nil }) + second := validate.Validator(func(*config.Context) ([]validate.Result, error) { return nil, nil }) + + sdk.RegisterContextValidator(first) + sdk.RegisterContextValidator(nil) + sdk.RegisterContextValidator(second) + + got := sdk.ContextValidators() + if len(got) != 2 { + t.Fatalf("expected 2 validators, got %d", len(got)) + } + + got[0] = nil + again := sdk.ContextValidators() + if again[0] == nil { + t.Fatal("expected ContextValidators() to return a copy") + } +} + +func writePluginScript(t *testing.T, dir, name, script string) { + t.Helper() + path := filepath.Join(dir, name) + if err := os.WriteFile(path, []byte(script), 0o755); err != nil { + t.Fatalf("WriteFile(%s) error = %v", name, err) + } +} diff --git a/pkg/tui/dashboard.go b/pkg/tui/dashboard.go index 5ae0ef3..1895ab6 100644 --- a/pkg/tui/dashboard.go +++ b/pkg/tui/dashboard.go @@ -26,7 +26,6 @@ import ( "github.com/libops/sitectl/pkg/docker" "github.com/libops/sitectl/pkg/plugin" zone "github.com/lrstanley/bubblezone/v2" - "golang.org/x/crypto/ssh" ) type siteGroup struct { @@ -1286,33 +1285,7 @@ func fetchComposeLogs(ctx config.Context) (string, error) { output, err := cmd.CombinedOutput() return string(output), err } - - remoteCmd := fmt.Sprintf("cd %s && ", shellquote.Join(ctx.ProjectDir)) - if ctx.RunSudo { - remoteCmd += "sudo " - } - remoteCmd += shellquote.Join(append([]string{"docker"}, args...)...) - - client, err := ctx.DialSSH() - if err != nil { - return "", err - } - defer client.Close() - - session, err := client.NewSession() - if err != nil { - return "", err - } - defer session.Close() - - output, err := session.CombinedOutput(remoteCmd) - if err != nil { - if _, ok := err.(*ssh.ExitError); ok && len(output) > 0 { - return string(output), nil - } - return string(output), err - } - return string(output), nil + return ctx.RunQuietCommand(exec.Command("docker", args...)) } func composeArgs(ctx config.Context, subcommand ...string) []string { @@ -2083,25 +2056,5 @@ func fetchContainerLogs(ctx config.Context, containerName string) (string, error output, err := cmd.CombinedOutput() return string(output), err } - - remoteCmd := fmt.Sprintf("cd %s && ", shellquote.Join(ctx.ProjectDir)) - if ctx.RunSudo { - remoteCmd += "sudo " - } - remoteCmd += shellquote.Join(append([]string{"docker"}, args...)...) - - client, err := ctx.DialSSH() - if err != nil { - return "", err - } - defer client.Close() - - session, err := client.NewSession() - if err != nil { - return "", err - } - defer session.Close() - - output, err := session.CombinedOutput(remoteCmd) - return string(output), err + return ctx.RunQuietCommand(exec.Command("docker", args...)) }