diff --git a/pkg/config/context.go b/pkg/config/context.go index 4183bc2..4e64c2f 100644 --- a/pkg/config/context.go +++ b/pkg/config/context.go @@ -1,9 +1,12 @@ package config import ( + "crypto/ed25519" + "crypto/rand" "errors" "fmt" "log/slog" + "net" "net/url" "os" "os/user" @@ -278,7 +281,12 @@ func (c *Context) DialSSH() (*ssh.Client, error) { Timeout: 5 * time.Second, } - sshAddr := fmt.Sprintf("%s:%d", c.SSHHostname, c.SSHPort) + sshAddr := net.JoinHostPort(c.SSHHostname, strconv.Itoa(int(c.SSHPort))) + sshConfig.HostKeyAlgorithms = knownHostKeyAlgorithms(hostKeyCallback, sshAddr) + if len(sshConfig.HostKeyAlgorithms) > 0 { + slog.Debug("Restricting SSH host key algorithms from known_hosts", "host", sshAddr, "algorithms", sshConfig.HostKeyAlgorithms) + } + slog.Debug("Dialing " + sshAddr) client, err := ssh.Dial("tcp", sshAddr, sshConfig) if err != nil { @@ -302,6 +310,45 @@ func (c *Context) DialSSH() (*ssh.Client, error) { return client, nil } +func knownHostKeyAlgorithms(hostKeyCallback ssh.HostKeyCallback, hostWithPort string) []string { + if hostKeyCallback == nil { + return nil + } + + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil + } + + signer, err := ssh.NewSignerFromKey(privateKey) + if err != nil { + return nil + } + + err = hostKeyCallback(hostWithPort, &net.TCPAddr{}, signer.PublicKey()) + if err == nil { + return nil + } + + var keyErr *knownhosts.KeyError + if !errors.As(err, &keyErr) || len(keyErr.Want) == 0 { + return nil + } + + algorithms := make([]string, 0, len(keyErr.Want)) + seen := make(map[string]struct{}, len(keyErr.Want)) + for _, knownKey := range keyErr.Want { + algorithm := knownKey.Key.Type() + if _, ok := seen[algorithm]; ok { + continue + } + seen[algorithm] = struct{}{} + algorithms = append(algorithms, algorithm) + } + + return algorithms +} + func defaultKnownHostsPath() (string, error) { home := os.Getenv("HOME") if strings.TrimSpace(home) == "" { diff --git a/pkg/config/context_test.go b/pkg/config/context_test.go index 35c34f3..ee4ad44 100644 --- a/pkg/config/context_test.go +++ b/pkg/config/context_test.go @@ -1,6 +1,8 @@ package config import ( + "crypto/ed25519" + "crypto/rand" "errors" "os" "path/filepath" @@ -8,6 +10,8 @@ import ( "testing" "github.com/spf13/pflag" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" "gopkg.in/yaml.v3" ) @@ -338,6 +342,69 @@ func TestDefaultKnownHostsPathUsesHomeSSHDirectory(t *testing.T) { } } +func TestKnownHostKeyAlgorithmsReturnsKnownAlgorithms(t *testing.T) { + tempDir := t.TempDir() + knownHostsPath := filepath.Join(tempDir, "known_hosts") + + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("failed to generate host key: %v", err) + } + + signer, err := ssh.NewSignerFromKey(privateKey) + if err != nil { + t.Fatalf("failed to create signer: %v", err) + } + + line := knownhosts.Line([]string{"example.com"}, signer.PublicKey()) + if err := os.WriteFile(knownHostsPath, []byte(line), 0600); err != nil { + t.Fatalf("failed to write known_hosts file: %v", err) + } + + callback, err := knownhosts.New(knownHostsPath) + if err != nil { + t.Fatalf("failed to create known_hosts callback: %v", err) + } + + algorithms := knownHostKeyAlgorithms(callback, "example.com:22") + if len(algorithms) != 1 { + t.Fatalf("expected 1 algorithm, got %d: %v", len(algorithms), algorithms) + } + if algorithms[0] != ssh.KeyAlgoED25519 { + t.Fatalf("expected %q, got %q", ssh.KeyAlgoED25519, algorithms[0]) + } +} + +func TestKnownHostKeyAlgorithmsReturnsNilForUnknownHost(t *testing.T) { + tempDir := t.TempDir() + knownHostsPath := filepath.Join(tempDir, "known_hosts") + + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("failed to generate host key: %v", err) + } + + signer, err := ssh.NewSignerFromKey(privateKey) + if err != nil { + t.Fatalf("failed to create signer: %v", err) + } + + line := knownhosts.Line([]string{"example.com"}, signer.PublicKey()) + if err := os.WriteFile(knownHostsPath, []byte(line), 0600); err != nil { + t.Fatalf("failed to write known_hosts file: %v", err) + } + + callback, err := knownhosts.New(knownHostsPath) + if err != nil { + t.Fatalf("failed to create known_hosts callback: %v", err) + } + + algorithms := knownHostKeyAlgorithms(callback, "other.example.com:22") + if algorithms != nil { + t.Fatalf("expected nil algorithms for unknown host, got %v", algorithms) + } +} + func TestProjectDirExistsLocal(t *testing.T) { tempHome := t.TempDir() t.Setenv("HOME", tempHome)