diff --git a/cli/cmd/tcping.go b/cli/cmd/tcping.go
new file mode 100644
index 0000000000..5574ef6506
--- /dev/null
+++ b/cli/cmd/tcping.go
@@ -0,0 +1,206 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+package cmd
+
+import (
+	"context"
+	"fmt"
+	"math"
+	"net"
+	"os"
+	"os/signal"
+	"syscall"
+	"time"
+
+	"github.com/spf13/cobra"
+)
+
+// Flag values for the tcping command; init binds them to the command's flags.
+var (
+	tcpingCount    int
+	tcpingInterval time.Duration
+	tcpingTimeout  time.Duration
+)
+
+// tcpingCmd is the "tcping" subcommand; init registers it on Retina.
+var tcpingCmd = &cobra.Command{
+	Use:   "tcping HOST PORT",
+	Short: "Ping a TCP port on a remote host",
+	Long: `Probe a TCP port by performing a TCP handshake and measuring round-trip latency.
+
+This is useful for verifying TCP connectivity and measuring connection setup time
+without needing additional tools like curl in a loop or nmap.
+
+The command runs continuously until interrupted (Ctrl+C) or the specified count
+is reached, then prints summary statistics.`,
+	Example: `  # Continuously ping a web server on port 443
+  kubectl retina tcping example.com 443
+
+  # Send exactly 10 probes with a 500ms interval
+  kubectl retina tcping example.com 80 --count 10 --interval 500ms
+
+  # Use a 5-second connection timeout
+  kubectl retina tcping 10.0.0.1 8080 --timeout 5s`,
+	Args: cobra.ExactArgs(2),
+	RunE: runTCPing,
+}
+
+// init registers tcping on the Retina root command and binds its flags.
+func init() {
+	Retina.AddCommand(tcpingCmd)
+	tcpingCmd.Flags().IntVarP(&tcpingCount, "count", "c", 0, "Number of probes to send (0 = unlimited)")
+	tcpingCmd.Flags().DurationVarP(&tcpingInterval, "interval", "i", 1*time.Second, "Interval between probes")
+	tcpingCmd.Flags().DurationVarP(&tcpingTimeout, "timeout", "t", 2*time.Second, "TCP connection timeout per probe")
+}
+
+// tcpingStats accumulates probe counts and round-trip times for the summary.
+type tcpingStats struct {
+	sent      int
+	succeeded int
+	minRTT    time.Duration
+	maxRTT    time.Duration
+	totalRTT  time.Duration
+}
+
+// record counts one probe. rtt only feeds the min/max/total figures when ok
+// is true; a failed probe changes nothing but the sent counter.
+func (s *tcpingStats) record(rtt time.Duration, ok bool) {
+	s.sent++
+	if !ok {
+		return
+	}
+	s.succeeded++
+	s.totalRTT += rtt
+	// The first success seeds minRTT; otherwise its zero value would always win.
+	if s.succeeded == 1 || rtt < s.minRTT {
+		s.minRTT = rtt
+	}
+	if rtt > s.maxRTT {
+		s.maxRTT = rtt
+	}
+}
+
+// avgRTT returns the mean RTT of the successful probes, rounded to the nearest
+// nanosecond, or 0 when no probe has succeeded.
+func (s *tcpingStats) avgRTT() time.Duration {
+	if s.succeeded == 0 {
+		return 0
+	}
+	return time.Duration(math.Round(float64(s.totalRTT) / float64(s.succeeded)))
+}
+
+// lossPercent returns the share of failed probes as a percentage, or 0 when
+// nothing has been sent.
+func (s *tcpingStats) lossPercent() float64 {
+	if s.sent == 0 {
+		return 0
+	}
+	return float64(s.sent-s.succeeded) / float64(s.sent) * 100
+}
+
+// validateTCPingFlags rejects negative values, which are meaningless for a
+// probe count, a wait interval or a dial timeout.
+func validateTCPingFlags(count int, interval, timeout time.Duration) error {
+	switch {
+	case count < 0:
+		return fmt.Errorf("invalid --count %d: must not be negative", count)
+	case interval < 0:
+		return fmt.Errorf("invalid --interval %v: must not be negative", interval)
+	case timeout < 0:
+		return fmt.Errorf("invalid --timeout %v: must not be negative", timeout)
+	}
+	return nil
+}
+
+// sleepContext blocks for d and reports true, or reports false as soon as ctx
+// is cancelled.
+func sleepContext(ctx context.Context, d time.Duration) bool {
+	timer := time.NewTimer(d)
+	defer timer.Stop()
+	select {
+	case <-ctx.Done():
+		return false
+	case <-timer.C:
+		return true
+	}
+}
+
+// runTCPing implements the tcping command. args holds HOST and PORT. It dials
+// the target until --count probes have been sent or SIGINT/SIGTERM arrives,
+// then prints the summary. Failed probes are reported in the output, not as a
+// returned error; only invalid flag values make it return non-nil.
+func runTCPing(cmd *cobra.Command, args []string) error {
+	if err := validateTCPingFlags(tcpingCount, tcpingInterval, tcpingTimeout); err != nil {
+		return err
+	}
+
+	addr := net.JoinHostPort(args[0], args[1])
+	out := cmd.OutOrStdout()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	// Turn the first SIGINT/SIGTERM into a context cancellation. The goroutine
+	// also exits on ctx.Done so it cannot outlive this function.
+	sigCh := make(chan os.Signal, 1)
+	signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
+	defer signal.Stop(sigCh)
+	go func() {
+		select {
+		case <-sigCh:
+			cancel()
+		case <-ctx.Done():
+		}
+	}()
+
+	fmt.Fprintf(out, "TCPing %s: tcp connect ...\n", addr)
+
+	stats := &tcpingStats{}
+	dialer := &net.Dialer{Timeout: tcpingTimeout}
+	for seq := 1; tcpingCount == 0 || seq <= tcpingCount; seq++ {
+		// Wait between probes, but not before the first one.
+		if seq > 1 && !sleepContext(ctx, tcpingInterval) {
+			break
+		}
+		if ctx.Err() != nil {
+			break
+		}
+
+		start := time.Now()
+		conn, err := dialer.DialContext(ctx, "tcp", addr)
+		rtt := time.Since(start)
+
+		if err != nil {
+			// A dial aborted by the interrupt is not a lost probe.
+			if ctx.Err() != nil {
+				break
+			}
+			stats.record(rtt, false)
+			fmt.Fprintf(out, "seq=%d %s - timeout/error: %v\n", seq, addr, err)
+			continue
+		}
+		// Only the handshake matters; a close error carries no information here.
+		_ = conn.Close()
+		stats.record(rtt, true)
+		fmt.Fprintf(out, "seq=%d %s rtt=%v\n", seq, addr, rtt.Round(time.Microsecond))
+	}
+
+	printSummary(cmd, addr, stats)
+	return nil
+}
+
+// printSummary writes the final statistics for addr to cmd's output stream.
+// The RTT line is omitted when no probe succeeded.
+func printSummary(cmd *cobra.Command, addr string, stats *tcpingStats) {
+	out := cmd.OutOrStdout()
+	fmt.Fprintln(out)
+	fmt.Fprintf(out, "--- %s tcping statistics ---\n", addr)
+	fmt.Fprintf(out, "%d probes sent, %d successful, %.1f%% loss\n",
+		stats.sent, stats.succeeded, stats.lossPercent())
+	if stats.succeeded > 0 {
+		fmt.Fprintf(out, "rtt min/avg/max = %v/%v/%v\n",
+			stats.minRTT.Round(time.Microsecond),
+			stats.avgRTT().Round(time.Microsecond),
+			stats.maxRTT.Round(time.Microsecond))
+	}
+}
diff --git a/cli/cmd/tcping_test.go b/cli/cmd/tcping_test.go
new file mode 100644
index 0000000000..9792fee75b
--- /dev/null
+++ b/cli/cmd/tcping_test.go
@@ -0,0 +1,159 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+package cmd
+
+import (
+	"bytes"
+	"fmt"
+	"net"
+	"testing"
+	"time"
+
+	"github.com/spf13/cobra"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func runTCPingForTest(t *testing.T, args []string) (string, error) {
+	t.Helper()
+	var buf bytes.Buffer
+	Retina.SetOut(&buf)
+	Retina.SetErr(&buf)
+	Retina.SetArgs(append([]string{"tcping"}, args...))
+	t.Cleanup(func() {
+		Retina.SetArgs(nil)
+		Retina.SetOut(nil)
+		Retina.SetErr(nil)
+		tcpingCount = 0
+		tcpingInterval = 1 * time.Second
+		tcpingTimeout = 2 * time.Second
+	})
+	err := Retina.Execute()
+	return buf.String(), err
+}
+
+func TestTCPingSuccess(t *testing.T) {
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer ln.Close()
+
+	go func() {
+		for {
+			conn, err := ln.Accept()
+			if err != nil {
+				return
+			}
+			conn.Close()
+		}
+	}()
+
+	_, port, err := net.SplitHostPort(ln.Addr().String())
+	require.NoError(t, err)
+
+	output, err := runTCPingForTest(t, []string{"127.0.0.1", port, "--count", "3", "--interval", "50ms"})
+	require.NoError(t, err)
+
+	assert.Contains(t, output, "TCPing 127.0.0.1:")
+	assert.Contains(t, output, "seq=1")
+	assert.Contains(t, output, "seq=3")
+	assert.Contains(t, output, "3 probes sent, 3 successful, 0.0% loss")
+	assert.Contains(t, output, "rtt min/avg/max")
+}
+
+func TestTCPingFailure(t *testing.T) {
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+	_, port, err := net.SplitHostPort(ln.Addr().String())
+	require.NoError(t, err)
+	ln.Close()
+
+	output, err := runTCPingForTest(t, []string{"127.0.0.1", port, "--count", "2", "--interval", "50ms", "--timeout", "200ms"})
+	require.NoError(t, err)
+
+	assert.Contains(t, output, "2 probes sent, 0 successful, 100.0% loss")
+}
+
+func TestTCPingStats(t *testing.T) {
+	s := &tcpingStats{}
+
+	s.record(10*time.Millisecond, true)
+	s.record(20*time.Millisecond, true)
+	s.record(30*time.Millisecond, false)
+	s.record(5*time.Millisecond, true)
+
+	assert.Equal(t, 4, s.sent)
+	assert.Equal(t, 3, s.succeeded)
+	assert.Equal(t, 5*time.Millisecond, s.minRTT)
+	assert.Equal(t, 20*time.Millisecond, s.maxRTT)
+	assert.InDelta(t, 25.0, s.lossPercent(), 0.1)
+
+	avg := s.avgRTT()
+	expected := (10*time.Millisecond + 20*time.Millisecond + 5*time.Millisecond) / 3
+	assert.InDelta(t, float64(expected), float64(avg), float64(time.Millisecond))
+}
+
+func TestTCPingHelp(t *testing.T) {
+	output, err := runTCPingForTest(t, []string{"--help"})
+	require.NoError(t, err)
+
+	assert.Contains(t, output, "tcping HOST PORT")
+	assert.Contains(t, output, "Probe a TCP port by performing a TCP handshake")
+	assert.Contains(t, output, "--count")
+	assert.Contains(t, output, "--interval")
+	assert.Contains(t, output, "--timeout")
+}
+
+func TestTCPingMissingArgs(t *testing.T) {
+	err := tcpingCmd.Args(tcpingCmd, []string{})
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "accepts 2 arg(s)")
+
+	err = tcpingCmd.Args(tcpingCmd, []string{"only-host"})
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "accepts 2 arg(s)")
+
+	err = tcpingCmd.Args(tcpingCmd, []string{"host", "port"})
+	require.NoError(t, err)
+}
+
+func TestTCPingStatsEmpty(t *testing.T) {
+	s := &tcpingStats{}
+	assert.Equal(t, time.Duration(0), s.avgRTT())
+	assert.Equal(t, float64(0), s.lossPercent())
+}
+
+func TestTCPingStatsAllFail(t *testing.T) {
+	s := &tcpingStats{}
+	s.record(10*time.Millisecond, false)
+	s.record(10*time.Millisecond, false)
+
+	assert.Equal(t, 2, s.sent)
+	assert.Equal(t, 0, s.succeeded)
+	assert.Equal(t, 100.0, s.lossPercent())
+	assert.Equal(t, time.Duration(0), s.avgRTT())
+}
+
+func TestPrintSummary(t *testing.T) {
+	s := &tcpingStats{}
+	s.record(10*time.Millisecond, true)
+	s.record(20*time.Millisecond, true)
+
+	var buf bytes.Buffer
+	c := &cobra.Command{}
+	c.SetOut(&buf)
+
+	printSummary(c, "example.com:80", s)
+	output := buf.String()
+	assert.Contains(t, output, "example.com:80 tcping statistics")
+	assert.Contains(t, output, fmt.Sprintf("2 probes sent, 2 successful, %.1f%% loss", 0.0))
+	assert.Contains(t, output, "rtt min/avg/max")
+}
+
+func TestTCPingValidateFlags(t *testing.T) {
+	require.NoError(t, validateTCPingFlags(0, 0, 0))
+	require.NoError(t, validateTCPingFlags(3, time.Second, 2*time.Second))
+
+	require.Error(t, validateTCPingFlags(-1, time.Second, time.Second))
+	require.Error(t, validateTCPingFlags(1, -time.Second, time.Second))
+	require.Error(t, validateTCPingFlags(1, time.Second, -time.Second))
+}