diff --git a/.github/skills/configuration-management/SKILL.md b/.github/skills/configuration-management/SKILL.md index bca1410..bd23817 100644 --- a/.github/skills/configuration-management/SKILL.md +++ b/.github/skills/configuration-management/SKILL.md @@ -53,7 +53,6 @@ Validate configuration values after parsing. type Config struct { SOCKS5Addr string `env:"SOCKS5_ADDR" envDefault:"0.0.0.0:1080"` HTTPProxyAddr string `env:"HTTP_PROXY_ADDR" envDefault:"0.0.0.0:8080"` -Credentials []string `env:"CREDENTIALS"` TOREnabled bool `env:"TOR_ENABLED" envDefault:"false"` TORControlAddr string `env:"TOR_CONTROLLER_ADDR" envDefault:"127.0.0.1:9051"` Timezone string `env:"TIMEZONE" envDefault:"UTC"` @@ -136,7 +135,7 @@ logger.Fatal().Err(err).Msg("failed to start SOCKS5 server") ## Related skills -- Credential Management - Handling credentials in config +- Credential Management - Handling persisted proxy users and admin-auth flows - Docker Deployment - Passing config via environment - SOCKS5 Protocol - Using config in SOCKS5 setup - HTTP Proxy - Using config in HTTP proxy setup diff --git a/.github/workflows/build-release.yml b/.github/workflows/build-release.yml index fdd2c1a..91485d3 100644 --- a/.github/workflows/build-release.yml +++ b/.github/workflows/build-release.yml @@ -22,7 +22,7 @@ jobs: - uses: actions/setup-go@v2 with: - go-version: 1.25 + go-version: 1.26 - run: go mod download diff --git a/.github/workflows/test-coverage.yml b/.github/workflows/test-coverage.yml index 868ea84..516958b 100644 --- a/.github/workflows/test-coverage.yml +++ b/.github/workflows/test-coverage.yml @@ -32,7 +32,7 @@ jobs: unset GOCOVERDIR go version go tool -n covdata || true - go test ./... -coverprofile=./cover.out -covermode=atomic -coverpkg=./... + go test ./... -coverprofile=./cover.out -covermode=atomic - name: Check Test Coverage id: coverage diff --git a/.goreleaser.yml b/.goreleaser.yml index a81ef99..8d7526a 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -137,6 +137,7 @@ archives: - zip files: - README.md + - docs/CONFIG.md - LICENSE @@ -185,9 +186,6 @@ nfpms: - src: systemd/nanoproxy.service dst: /etc/systemd/system/nanoproxy.service type: "config|noreplace" - - src: config/nanoproxy - dst: /etc/nanoproxy/nanoproxy - type: "config|noreplace" formats: - deb - rpm diff --git a/AGENTS.md b/AGENTS.md index 73a733b..75e26b3 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -13,8 +13,8 @@ ## Key architecture and data flow - Config is env-driven via struct tags in `pkg/config/config.go` (`caarlos0/env/v10`), not via JSON/YAML files. -- Credentials are loaded from `CREDENTIALS` (`username:bcryptHash` list), stored in `pkg/credential/credentials.go`, and - validated with bcrypt. +- User credentials are managed through the Admin Console and stored in `USER_STORE_PATH` (BoltDB). Credentials are + validated with bcrypt in `pkg/credential/credentials.go`. - SOCKS5 flow (`pkg/socks5/socks5.go`): handshake -> auth negotiation -> request parse (`pkg/socks5/request.go`) -> optional DNS resolve -> relay. - HTTP flow (`pkg/httpproxy/httpproxy.go`): `ServeHTTP` dispatches `CONNECT` vs normal HTTP; hop-by-hop headers are @@ -28,8 +28,8 @@ `pkg/tor/controller.go`). - `Dockerfile-tor` + `supervisord.conf` run both Tor and NanoProxy in one container; this is the intended Tor deployment path. -- System package/service deployment uses `systemd/nanoproxy.service` and env file `config/nanoproxy` ( - `/etc/nanoproxy/nanoproxy`). +- System package/service deployment uses `systemd/nanoproxy.service` with inline `Environment=` values and optional + systemd drop-ins for overrides. ## Developer workflows that matter here diff --git a/Dockerfile b/Dockerfile index 5c1409d..7e6847d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,13 @@ -FROM alpine:3 +FROM alpine:3.21 ARG TARGETPLATFORM -COPY $TARGETPLATFORM/nanoproxy /usr/bin +COPY --chmod=0755 $TARGETPLATFORM/nanoproxy /usr/bin/nanoproxy + +ENV USER_STORE_PATH=/var/lib/nanoproxy/data.db + +VOLUME ["/var/lib/nanoproxy"] + EXPOSE 1080 EXPOSE 8080 +EXPOSE 9090 ENTRYPOINT ["nanoproxy"] \ No newline at end of file diff --git a/Dockerfile-tor b/Dockerfile-tor index 1f19c32..c015883 100644 --- a/Dockerfile-tor +++ b/Dockerfile-tor @@ -4,17 +4,23 @@ ARG TARGETPLATFORM RUN apk update && \ apk add --no-cache tor supervisor -RUN mkdir -p /var/log/supervisor +RUN mkdir -p /var/log/supervisor /var/log/tor -COPY $TARGETPLATFORM/nanoproxy /usr/bin +COPY --chmod=0755 $TARGETPLATFORM/nanoproxy /usr/bin/nanoproxy COPY supervisord.conf /etc/supervisord.conf RUN mkdir -p /etc/tor && \ echo -e "ControlPort 9051\nCookieAuthentication 0" > /etc/tor/torrc -RUN mkdir -p /var/lib/tor +RUN mkdir -p /var/lib/tor /var/lib/nanoproxy + + +ENV USER_STORE_PATH=/var/lib/nanoproxy/data.db + +VOLUME ["/var/lib/tor", "/var/lib/nanoproxy"] EXPOSE 1080 EXPOSE 8080 +EXPOSE 9090 ENTRYPOINT ["/usr/bin/supervisord", "-c", "/etc/supervisord.conf"] \ No newline at end of file diff --git a/Makefile b/Makefile index b8b5d6a..923b52f 100644 --- a/Makefile +++ b/Makefile @@ -6,12 +6,12 @@ install-go-test-coverage: .PHONY: check-coverage check-coverage: install-go-test-coverage - env -u GOCOVERDIR go test ./... -coverprofile=./cover.out -covermode=atomic -coverpkg=./... + env -u GOCOVERDIR go test ./... -coverprofile=./cover.out -covermode=atomic env -u GOCOVERDIR ${GOBIN}/go-test-coverage --config=./.testcoverage.yml .PHONY: coverage-only coverage-only: - env -u GOCOVERDIR go test ./... -coverprofile=./cover.out -covermode=atomic -coverpkg=./... + env -u GOCOVERDIR go test ./... -coverprofile=./cover.out -covermode=atomic clean: @echo "Cleaning up dist directory..." diff --git a/README.md b/README.md index e1b0e7b..d045fbc 100644 --- a/README.md +++ b/README.md @@ -115,10 +115,8 @@ NanoProxy provides the following features: - [x] **TOR support.** NanoProxy can be run with Tor support to provide anonymized network traffic (Docker only). - [x] **IP Rotation with Tor.** NanoProxy allows for IP rotation using the Tor network, providing enhanced anonymity and privacy by periodically changing exit nodes. -- [ ] **Authentication Management from Dashboard.** Easily manage user authentication settings and credentials via a +- [x] **Authentication Management from Dashboard.** Easily manage user authentication settings and credentials via a comprehensive and user-friendly web dashboard, ensuring secure access to proxy features. -- [ ] **Change IP via API.** Programmatically request IP changes through a robust API, facilitating automated and - dynamic IP management for different use cases. ## Installation @@ -208,64 +206,32 @@ nanoproxy You can also run NanoProxy using Docker. To do so, you can use the following command: ```shell -docker run -p 1080:1080 -p 8080:8080 ghcr.io/ryanbekhen/nanoproxy:latest +docker run -p 1080:1080 -p 8080:8080 -p 9090:9090 ghcr.io/ryanbekhen/nanoproxy:latest ``` You can also run NanoProxy behind Tor using the following command: ```shell -docker run --rm -e TOR_ENABLED=true -d --privileged --cap-add=NET_ADMIN --sysctl net.ipv6.conf.all.disable_ipv6=0 --sysctl net.ipv4.conf.all.src_valid_mark=1 -p 1080:1080 -p 8080:8080 ghcr.io/ryanbekhen/nanoproxy-tor:latest +docker run --rm -e TOR_ENABLED=true -d --privileged --cap-add=NET_ADMIN --sysctl net.ipv6.conf.all.disable_ipv6=0 --sysctl net.ipv4.conf.all.src_valid_mark=1 -p 1080:1080 -p 8080:8080 -p 9090:9090 ghcr.io/ryanbekhen/nanoproxy-tor:latest ``` ## Configuration -You can also set the configuration using environment variables. Create a file at `/etc/nanoproxy/nanoproxy` and add the -desired values: - -```text -ADDR=:1080 -ADDR_HTTP=:8080 -NETWORK=tcp -TZ=Asia/Jakarta -CLIENT_TIMEOUT=10s -DNS_TIMEOUT=10s -CREDENTIALS=username:passwordHash -``` - -For the creation of the password hash, you can use the `htpasswd -nB username` command, but you need to install the -`apache2-utils` package first. To install the package, run the following command: - -```shell -sudo apt install apache2-utils -``` +NanoProxy is configured entirely through environment variables. For detailed information about all available +configuration options, environment variable reference, and examples, see [docs/CONFIG.md](docs/CONFIG.md). -Then, you can use the `htpasswd` command to generate the password hash: +### Quick Start Configuration ```shell -htpasswd -nB username +export ADDR=:1080 +export ADDR_HTTP=:8080 +export ADDR_ADMIN=:9090 +export LOG_LEVEL=info +export USER_STORE_PATH=./nanoproxy-data.db ``` -This will prompt you to enter the password. After entering the password, the command will output the username and the -password hash. You can then use the output to set the `CREDENTIALS` environment variable. - -The following table lists the available configuration options: - -| Name | Description | Default Value | -|-----------------------|-----------------------------------------------------------------|---------------| -| ADDR | The address to listen on. | `:1080` | -| ADDR_HTTP | The address to listen on for HTTP requests. | `:8080` | -| NETWORK | The network to listen on. (tcp, tcp4, tcp6) | `tcp` | -| TZ | The timezone to use. | `Local` | -| CLIENT_TIMEOUT | The timeout for connecting to the destination Server. | `10s` | -| DNS_TIMEOUT | The timeout for DNS resolution. | `10s` | -| CREDENTIALS | The credentials to use for authentication. | `""` | -| TOR_ENABLED | Enable Tor support. (works only on Docker) | `false` | -| TOR_IDENTITY_INTERVAL | The interval to change the Tor identity. (works only on Docker) | `10m` | - -- **ADDR_HTTP**: By default, NanoProxy listens for HTTP proxy traffic on `:8080`. You can set this address to any host: - port combination for custom setups. -- **CREDENTIALS**: When enabled, both SOCKS5 and HTTP Proxy requests are authenticated using the credentials provided in - this field. This supports `username:password` pairs. +Then access the admin panel at `http://localhost:9090/admin/setup` to create your initial admin account and add proxy +users. ## Logging @@ -293,7 +259,8 @@ curl -x socks5://localhost:1080 https://google.com curl -x localhost:8080 https://google.com ``` -If credentials are enabled for HTTP Proxy, use the `-U` flag to supply the username and password: +If the HTTP proxy requires authentication, use the `-U` flag to supply the username and password for a proxy user +created in the Admin Console: ```shell curl -x http://localhost:8080 -U username:password https://example.com @@ -303,8 +270,9 @@ In both cases, replace `localhost:8080` with the actual address and port where y ## Authentication for HTTP Proxy -If authentication is enabled (via the `CREDENTIALS` configuration), the HTTP Proxy requires clients to include the -`Proxy-Authorization` header in their requests. The header must use the following format: +Proxy users are managed through the Admin Console. After setting up an admin account, you can create and manage proxy +users through the web interface. The HTTP Proxy requires clients to include the `Proxy-Authorization` header in their +requests with the correct format: ```http Proxy-Authorization: Basic diff --git a/config/nanoproxy b/config/nanoproxy deleted file mode 100644 index f0b8a81..0000000 --- a/config/nanoproxy +++ /dev/null @@ -1,3 +0,0 @@ -ADDR=:1080 -NETWORK=tcp -TZ=Local diff --git a/docs/CONFIG.md b/docs/CONFIG.md new file mode 100644 index 0000000..ab3617a --- /dev/null +++ b/docs/CONFIG.md @@ -0,0 +1,224 @@ +# NanoProxy Configuration + +NanoProxy is configured entirely through environment variables. There is no configuration file needed. + +## Environment Variables Reference + +### Network Configuration + +| Variable | Type | Default | Description | +|--------------|--------|---------|--------------------------------------------------------| +| `NETWORK` | string | `tcp` | Network protocol for listening (`tcp`, `tcp4`, `tcp6`) | +| `ADDR` | string | `:1080` | SOCKS5 server listen address (host:port) | +| `ADDR_HTTP` | string | `:8080` | HTTP proxy server listen address (host:port) | +| `ADDR_ADMIN` | string | `:9090` | Admin panel listen address (host:port) | + +### Timeout Configuration + +| Variable | Type | Default | Description | +|------------------|----------|---------|--------------------------------------------------| +| `CLIENT_TIMEOUT` | duration | `15s` | Client connection timeout (read/write deadlines) | +| `DEST_TIMEOUT` | duration | `15s` | Destination connection timeout | + +### Logging Configuration + +| Variable | Type | Default | Description | +|-------------|--------|---------|-----------------------------------------------------------------| +| `LOG_LEVEL` | string | `info` | Log level: `debug`, `info`, `warn`, `error` | +| `TZ` | string | `Local` | Timezone for logging timestamps (IANA timezone name or `Local`) | + +### User Storage Configuration + +| Variable | Type | Default | Description | +|-------------------|--------|---------------------|--------------------------------------------------------------------------| +| `USER_STORE_PATH` | string | `nanoproxy-data.db` | Path to BoltDB database for persistent user storage and traffic tracking | + +### Admin Panel Configuration + +| Variable | Type | Default | Description | +|----------------------------|--------------|---------|--------------------------------------------------------------------------------------------------| +| `ADMIN_COOKIE_SECURE` | bool | `false` | Enable secure cookie flag for HTTPS deployments (`true`/`false`) | +| `ADMIN_MAX_LOGIN_ATTEMPTS` | int | `5` | Maximum failed login attempts before account lockout | +| `ADMIN_LOGIN_WINDOW` | duration | `5m` | Time window for tracking failed login attempts | +| `ADMIN_LOCKOUT_DURATION` | duration | `10m` | Duration to lock account after max failed attempts | +| `ADMIN_ALLOWED_ORIGINS` | string (csv) | empty | Comma-separated list of allowed CORS origins for admin panel (e.g., `https://admin.example.com`) | + +### Tor Integration + +| Variable | Type | Default | Description | +|-------------------------|----------|---------|----------------------------------------------------------------| +| `TOR_ENABLED` | bool | `false` | Enable Tor integration for anonymous proxying (`true`/`false`) | +| `TOR_IDENTITY_INTERVAL` | duration | `10m` | Interval for switching Tor exit node identity | + +## Configuration Examples + +### Basic SOCKS5 + HTTP Proxy (No Auth) + +```bash +ADDR=:1080 +ADDR_HTTP=:8080 +LOG_LEVEL=info +``` + +### Basic Setup with User Storage + +```bash +ADDR=:1080 +ADDR_HTTP=:8080 +ADDR_ADMIN=:9090 +USER_STORE_PATH=/var/lib/nanoproxy/data.db +LOG_LEVEL=info +``` + +Then access the admin console to create users. + +### Production Deployment + +```bash +NETWORK=tcp +ADDR=0.0.0.0:1080 +ADDR_HTTP=0.0.0.0:8080 +ADDR_ADMIN=127.0.0.1:9090 +LOG_LEVEL=warn +CLIENT_TIMEOUT=30s +DEST_TIMEOUT=30s +USER_STORE_PATH=/var/lib/nanoproxy/data.db +ADMIN_COOKIE_SECURE=true +ADMIN_ALLOWED_ORIGINS=https://admin.example.com +``` + +### Tor Mode (Anonymized) + +```bash +ADDR=:1080 +ADDR_HTTP=:8080 +TOR_ENABLED=true +TOR_IDENTITY_INTERVAL=5m +LOG_LEVEL=info +``` + +### Debug Mode (Development) + +```bash +ADDR=127.0.0.1:1080 +ADDR_HTTP=127.0.0.1:8080 +LOG_LEVEL=debug +CLIENT_TIMEOUT=60s +DEST_TIMEOUT=60s +``` + +## Duration Format + +Durations use Go duration syntax: + +- `s` = seconds (e.g., `5s`) +- `m` = minutes (e.g., `5m`) +- `h` = hours (e.g., `1h`) + +Examples: `15s`, `5m`, `1h30m` + +## Timezone Format + +Use IANA timezone names (e.g., `America/New_York`, `Europe/London`, `Asia/Tokyo`) or `Local` for system timezone. + +## Log Levels + +- `debug` - Detailed debug information (request completion, connection details) +- `info` - General informational messages (server startup, important events) +- `warn` - Warning messages (non-critical issues) +- `error` - Error messages only (authentication failures, connection errors, etc.) + +## Docker/Compose Example + +```yaml +version: '3.8' +services: + nanoproxy: + image: nanoproxy:latest + ports: + - "1080:1080" + - "8080:8080" + - "9090:9090" + environment: + ADDR: 0.0.0.0:1080 + ADDR_HTTP: 0.0.0.0:8080 + ADDR_ADMIN: 0.0.0.0:9090 + LOG_LEVEL: info + USER_STORE_PATH: /data/nanoproxy.db + ADMIN_COOKIE_SECURE: "true" + volumes: + - nanoproxy_data:/data + +volumes: + nanoproxy_data: +``` + +## Systemd Service Example + +```ini +[Unit] +Description=NanoProxy +After=network.target + +[Service] +Type=simple +User=nanoproxy +WorkingDirectory=/opt/nanoproxy +ExecStart=/opt/nanoproxy/nanoproxy +Restart=on-failure +RestartSec=10s + +Environment="ADDR=0.0.0.0:1080" +Environment="ADDR_HTTP=0.0.0.0:8080" +Environment="LOG_LEVEL=info" +Environment="USER_STORE_PATH=/var/lib/nanoproxy/data.db" + +[Install] +WantedBy=multi-user.target +``` + +## Admin Console and Persistent Proxy Users + +NanoProxy starts an admin console on `ADDR_ADMIN`. + +- Visit `/` or `/admin` on the admin address. +- On first run (when no admin exists in `USER_STORE_PATH`), create the admin account at `/admin/setup`. +- After setup, log in using that admin account. +- Add or delete proxy users from the UI. +- Those proxy users are saved to `USER_STORE_PATH` and loaded again on restart. + +Example: + +```shell +export ADDR=:1080 +export ADDR_HTTP=:8080 +export ADDR_ADMIN=:9090 +export USER_STORE_PATH=/var/lib/nanoproxy/data.db +export ADMIN_COOKIE_SECURE=true +export ADMIN_ALLOWED_ORIGINS=https://admin.example.com +go run . +``` + +### Admin Console Notes + +- Admin-managed users are stored separately and reloaded automatically. +- Both HTTP and SOCKS5 reuse the same in-memory authentication view, so behavior stays aligned across protocols. + +### Admin Security Notes + +- Admin state-changing actions use CSRF tokens. +- CSRF tokens are rotated after successful state-changing actions. +- Login attempts are rate-limited (`ADMIN_MAX_LOGIN_ATTEMPTS`, `ADMIN_LOGIN_WINDOW`, `ADMIN_LOCKOUT_DURATION`). +- If `ADMIN_ALLOWED_ORIGINS` is configured, requests without allowed `Origin`/`Referer` are rejected. + +## Notes + +- All environment variables are optional and use sensible defaults +- SOCKS5 and HTTP proxy share the same credential store (database) +- All users are managed through the BoltDB database specified by `USER_STORE_PATH` +- Admin panel is accessible at `http://localhost:9090` (or configured `ADDR_ADMIN`) +- Initial admin account is set up through the `/admin/setup` web interface on first run +- Proxy users can be created and managed through the admin panel +- For production, always use `LOG_LEVEL=warn` or `error` to reduce log noise +- Database file location should be on persistent storage in containers/orchestration + diff --git a/go.mod b/go.mod index be91bfa..4b25f55 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/caarlos0/env/v10 v10.0.0 github.com/rs/zerolog v1.34.0 github.com/stretchr/testify v1.11.1 + go.etcd.io/bbolt v1.4.3 golang.org/x/crypto v0.49.0 golang.org/x/net v0.52.0 ) diff --git a/go.sum b/go.sum index 3b246a7..f594da7 100644 --- a/go.sum +++ b/go.sum @@ -19,19 +19,17 @@ github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= -golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +go.etcd.io/bbolt v1.4.3 h1:dEadXpI6G79deX5prL3QRNP6JB8UxVkqo4UPnHaNXJo= +go.etcd.io/bbolt v1.4.3/go.mod h1:tKQlpPaYCVFctUIgFKFnAlvbmB3tpy1vkTnDWohtc0E= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= -golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= -golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= -golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/nanoproxy.go b/nanoproxy.go index 06c375c..7e535c4 100644 --- a/nanoproxy.go +++ b/nanoproxy.go @@ -2,6 +2,7 @@ package main import ( "errors" + "fmt" "net" "net/http" "os" @@ -10,12 +11,14 @@ import ( "github.com/caarlos0/env/v10" "github.com/rs/zerolog" + "github.com/ryanbekhen/nanoproxy/pkg/admin" "github.com/ryanbekhen/nanoproxy/pkg/config" "github.com/ryanbekhen/nanoproxy/pkg/credential" "github.com/ryanbekhen/nanoproxy/pkg/httpproxy" "github.com/ryanbekhen/nanoproxy/pkg/resolver" "github.com/ryanbekhen/nanoproxy/pkg/socks5" "github.com/ryanbekhen/nanoproxy/pkg/tor" + "github.com/ryanbekhen/nanoproxy/pkg/traffic" _ "time/tzdata" ) @@ -28,25 +31,33 @@ func main() { logger.Fatal().Msg(err.Error()) } + level, err := zerolog.ParseLevel(strings.ToLower(strings.TrimSpace(cfg.LogLevel))) + if err != nil { + logger.Warn().Str("log_level", cfg.LogLevel).Msg("Invalid LOG_LEVEL, falling back to info") + level = zerolog.InfoLevel + } + zerolog.SetGlobalLevel(level) + logger = logger.Level(level) + loc, _ := time.LoadLocation(cfg.Timezone) if loc != nil { time.Local = loc } - var credentials credential.Store - if len(cfg.Credentials) > 0 { - credentials = credential.NewStaticCredentialStore() - for _, cred := range cfg.Credentials { - credArr := strings.Split(cred, ":") - if len(credArr) != 2 { - logger.Fatal().Msgf("Invalid credential: %s", cred) - } - - credentials.Add(credArr[0], credArr[1]) - } + credentials, userFileStore, err := buildCredentialStore(cfg) + if err != nil { + logger.Fatal().Err(err).Msg("Failed to initialize credentials") } + adminStore := admin.NewBoltAdminStore(cfg.UserStorePath) dnsResolver := &resolver.DNSResolver{} + trafficTracker := traffic.NewTracker() + + // Load persisted traffic totals + trafficStore := traffic.NewBoltStore(cfg.UserStorePath) + if err := trafficTracker.LoadPersistedTotals(trafficStore); err != nil { + logger.Warn().Err(err).Msg("Failed to load persisted traffic totals") + } httpConfig := httpproxy.Config{ Credentials: credentials, @@ -55,6 +66,7 @@ func main() { ClientConnTimeout: cfg.ClientTimeout, Dial: net.Dial, Resolver: dnsResolver, + Tracker: trafficTracker, } httpServer := httpproxy.New(&httpConfig) @@ -64,6 +76,7 @@ func main() { DestConnTimeout: cfg.DestTimeout, ClientConnTimeout: cfg.ClientTimeout, Resolver: dnsResolver, + Tracker: trafficTracker, } if cfg.TorEnabled { @@ -82,14 +95,14 @@ func main() { }() } - if len(cfg.Credentials) > 0 { + if credentials != nil { authenticator := &socks5.UserPassAuthenticator{ Credentials: credentials, } socks5Config.Authentication = append(socks5Config.Authentication, authenticator) } - sock5Server := socks5.New(&socks5Config) + socks5Server := socks5.New(&socks5Config) go func() { logger.Info().Msgf("Starting HTTP proxy server on %s://%s", cfg.Network, cfg.ADDRHttp) @@ -109,10 +122,51 @@ func main() { go func() { logger.Info().Msgf("Starting SOCKS5 server on %s://%s", cfg.Network, cfg.ADDR) - if err := sock5Server.ListenAndServe(cfg.Network, cfg.ADDR); err != nil { + if err := socks5Server.ListenAndServe(cfg.Network, cfg.ADDR); err != nil { + logger.Fatal().Msg(err.Error()) + } + }() + + adminServer := admin.New(&admin.Config{ + Credentials: credentials, + UserStore: userFileStore, + AdminStore: adminStore, + TrafficStore: trafficStore, + Tracker: trafficTracker, + CookieSecure: cfg.AdminCookieSecure, + MaxLoginAttempts: cfg.AdminMaxLoginAttempts, + LoginWindow: cfg.AdminLoginWindow, + LockoutDuration: cfg.AdminLockoutDuration, + AllowedOrigins: cfg.AdminAllowedOrigins, + Logger: &logger, + }) + + go func() { + logger.Info().Msgf("Starting admin server on %s", cfg.ADDRAdmin) + + server := &http.Server{ + Addr: cfg.ADDRAdmin, + Handler: adminServer.Handler(), + ReadTimeout: 15 * time.Second, + WriteTimeout: 15 * time.Second, + IdleTimeout: 60 * time.Second, + } + + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { logger.Fatal().Msg(err.Error()) } }() select {} } + +func buildCredentialStore(cfg *config.Config) (*credential.StaticCredentialStore, credential.PersistentStore, error) { + userStore := credential.NewBoltStore(cfg.UserStorePath) + + credentials := credential.NewStaticCredentialStore() + if err := credential.LoadInto(userStore, credentials); err != nil { + return nil, userStore, fmt.Errorf("load persisted proxy users: %w", err) + } + + return credentials, userStore, nil +} diff --git a/nanoproxy_test.go b/nanoproxy_test.go new file mode 100644 index 0000000..d741728 --- /dev/null +++ b/nanoproxy_test.go @@ -0,0 +1,38 @@ +package main + +import ( + "path/filepath" + "testing" + + "github.com/ryanbekhen/nanoproxy/pkg/config" + "github.com/ryanbekhen/nanoproxy/pkg/credential" +) + +func TestBuildCredentialStore_LoadsFromDatabase(t *testing.T) { + t.Parallel() + + storePath := filepath.Join(t.TempDir(), "data.db") + userStore := credential.NewBoltStore(storePath) + + persisted := credential.NewStaticCredentialStore() + persisted.Add("db-user", "password") + if err := userStore.Save(persisted.Snapshot()); err != nil { + t.Fatalf("save persisted users: %v", err) + } + + cfg := &config.Config{ + UserStorePath: storePath, + } + + credentials, _, err := buildCredentialStore(cfg) + if err != nil { + t.Fatalf("buildCredentialStore returned error: %v", err) + } + if credentials == nil { + t.Fatal("expected non-nil credentials") + } + + if !credentials.Valid("db-user", "password") { + t.Fatal("expected db-user to be valid from database") + } +} diff --git a/pkg/admin/admin_store.go b/pkg/admin/admin_store.go new file mode 100644 index 0000000..18368dd --- /dev/null +++ b/pkg/admin/admin_store.go @@ -0,0 +1,93 @@ +package admin + +import ( + "errors" + "os" + + "go.etcd.io/bbolt" +) + +var adminBucket = []byte("admin") + +// AdminCredentialStore persists the admin login account used by the admin console. +type AdminCredentialStore interface { + Load() (username string, passwordHash string, found bool, err error) + Save(username string, passwordHash string) error +} + +type BoltAdminStore struct { + path string +} + +func NewBoltAdminStore(path string) *BoltAdminStore { + return &BoltAdminStore{path: path} +} + +func (b *BoltAdminStore) Load() (string, string, bool, error) { + if b == nil || b.path == "" { + return "", "", false, nil + } + + if _, err := os.Stat(b.path); err != nil { + if errors.Is(err, os.ErrNotExist) { + return "", "", false, nil + } + return "", "", false, err + } + + db, err := bbolt.Open(b.path, 0o600, nil) + if err != nil { + return "", "", false, err + } + defer db.Close() + + var username, passwordHash string + err = db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(adminBucket) + if bucket == nil { + return nil + } + + username = string(bucket.Get([]byte("username"))) + passwordHash = string(bucket.Get([]byte("password_hash"))) + return nil + }) + if err != nil { + return "", "", false, err + } + + if username == "" || passwordHash == "" { + return "", "", false, nil + } + + return username, passwordHash, true, nil +} + +func (b *BoltAdminStore) Save(username string, passwordHash string) error { + if b == nil || b.path == "" { + return nil + } + + db, err := bbolt.Open(b.path, 0o600, nil) + if err != nil { + return err + } + defer db.Close() + + return db.Update(func(tx *bbolt.Tx) error { + _ = tx.DeleteBucket(adminBucket) + + bucket, err := tx.CreateBucket(adminBucket) + if err != nil { + return err + } + + if err := bucket.Put([]byte("username"), []byte(username)); err != nil { + return err + } + if err := bucket.Put([]byte("password_hash"), []byte(passwordHash)); err != nil { + return err + } + return nil + }) +} diff --git a/pkg/admin/admin_store_test.go b/pkg/admin/admin_store_test.go new file mode 100644 index 0000000..6a75b13 --- /dev/null +++ b/pkg/admin/admin_store_test.go @@ -0,0 +1,200 @@ +package admin + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/bcrypt" +) + +func TestBoltAdminStore_SaveAndLoad(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "admin.db") + store := NewBoltAdminStore(dbPath) + + hash, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.DefaultCost) + require.NoError(t, err) + + // Initially, no admin should exist + username, passwordHash, found, err := store.Load() + assert.NoError(t, err) + assert.False(t, found) + assert.Equal(t, "", username) + assert.Equal(t, "", passwordHash) + + // Save admin credentials + err = store.Save("admin", string(hash)) + assert.NoError(t, err) + + // Load should now return the saved credentials + username, passwordHash, found, err = store.Load() + assert.NoError(t, err) + assert.True(t, found) + assert.Equal(t, "admin", username) + assert.Equal(t, string(hash), passwordHash) +} + +func TestBoltAdminStore_SaveOverwrites(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "admin.db") + store := NewBoltAdminStore(dbPath) + + hash1, err := bcrypt.GenerateFromPassword([]byte("secret1"), bcrypt.DefaultCost) + require.NoError(t, err) + hash2, err := bcrypt.GenerateFromPassword([]byte("secret2"), bcrypt.DefaultCost) + require.NoError(t, err) + + // Save first admin + err = store.Save("admin", string(hash1)) + assert.NoError(t, err) + + // Save second admin (should overwrite) + err = store.Save("newadmin", string(hash2)) + assert.NoError(t, err) + + // Load should return the new admin + username, passwordHash, found, err := store.Load() + assert.NoError(t, err) + assert.True(t, found) + assert.Equal(t, "newadmin", username) + assert.Equal(t, string(hash2), passwordHash) + assert.NotEqual(t, string(hash1), passwordHash) +} + +func TestBoltAdminStore_NilStore(t *testing.T) { + t.Parallel() + + var store *BoltAdminStore + + // Load from nil should return empty + username, passwordHash, found, err := store.Load() + assert.NoError(t, err) + assert.False(t, found) + assert.Equal(t, "", username) + assert.Equal(t, "", passwordHash) + + // Save to nil should not error + err = store.Save("admin", "hash") + assert.NoError(t, err) +} + +func TestBoltAdminStore_EmptyPath(t *testing.T) { + t.Parallel() + + store := NewBoltAdminStore("") + + // Load from empty path should return empty + _, _, found, err := store.Load() + assert.NoError(t, err) + assert.False(t, found) + + // Save to empty path should not error + err = store.Save("admin", "hash") + assert.NoError(t, err) +} + +func TestBoltAdminStore_PartialCredentials(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "admin.db") + store := NewBoltAdminStore(dbPath) + + // Save with empty password hash + err := store.Save("admin", "") + assert.NoError(t, err) + + // Should not be found (both username and hash must exist) + _, _, found, err := store.Load() + assert.NoError(t, err) + assert.False(t, found) +} + +func TestBoltAdminStore_SaveFailsIfDirMissing(t *testing.T) { + t.Parallel() + + dbDir := filepath.Join(t.TempDir(), "subdir", "nested") + dbPath := filepath.Join(dbDir, "admin.db") + store := NewBoltAdminStore(dbPath) + + // Directory should not exist yet + _, err := os.Stat(dbDir) + assert.True(t, os.IsNotExist(err)) + + // Save should fail when directory is not pre-created. + hash, _ := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.DefaultCost) + err = store.Save("admin", string(hash)) + assert.Error(t, err) +} + +func TestBoltAdminStore_PersistsAcrossInstances(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "admin.db") + + hash, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.DefaultCost) + require.NoError(t, err) + + // First instance: save credentials + store1 := NewBoltAdminStore(dbPath) + err = store1.Save("admin", string(hash)) + assert.NoError(t, err) + + // Second instance: load credentials + store2 := NewBoltAdminStore(dbPath) + username, passwordHash, found, err := store2.Load() + assert.NoError(t, err) + assert.True(t, found) + assert.Equal(t, "admin", username) + assert.Equal(t, string(hash), passwordHash) +} + +func TestBoltAdminStore_InvalidDatabaseFile(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "not_a_db.txt") + + // Create a non-database file + err := os.WriteFile(dbPath, []byte("not a valid boltdb"), 0o644) + require.NoError(t, err) + + store := NewBoltAdminStore(dbPath) + + // Loading from invalid DB should error + _, _, _, err = store.Load() + assert.Error(t, err) + + // Saving should also error (can't overwrite with valid DB) + err = store.Save("admin", "hash") + assert.Error(t, err) +} + +func TestBoltAdminStore_PermissionDenied(t *testing.T) { + if os.Geteuid() == 0 { + t.Skip("skipping permission test when running as root") + } + + t.Parallel() + + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "admin.db") + + // Create DB first + hash, _ := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.DefaultCost) + store := NewBoltAdminStore(dbPath) + err := store.Save("admin", string(hash)) + require.NoError(t, err) + + // Remove read permission from directory + err = os.Chmod(tmpDir, 0o000) + require.NoError(t, err) + defer os.Chmod(tmpDir, 0o755) + + // Load should fail due to permission error + _, _, _, err = store.Load() + assert.Error(t, err) +} diff --git a/pkg/admin/server.go b/pkg/admin/server.go new file mode 100644 index 0000000..1345681 --- /dev/null +++ b/pkg/admin/server.go @@ -0,0 +1,1072 @@ +package admin + +import ( + "crypto/rand" + "crypto/subtle" + "embed" + "encoding/base64" + "fmt" + "html/template" + "io" + "net" + "net/http" + "net/url" + "os" + "path" + "regexp" + "sort" + "strings" + "sync" + "time" + + "github.com/rs/zerolog" + "github.com/ryanbekhen/nanoproxy/pkg/credential" + "github.com/ryanbekhen/nanoproxy/pkg/traffic" + "golang.org/x/crypto/bcrypt" +) + +//go:embed templates/*.gohtml +var templatesFS embed.FS + +const ( + sessionCookieName = "nanoproxy_admin_session" + sessionTTL = 12 * time.Hour + offlineInactivityWindow = 10 * time.Minute + statusActive = "Active" + statusOffline = "Offline" + minUsernameLength = 3 + maxUsernameLength = 64 + defaultMaxLoginAttempts = 5 + defaultLoginWindow = 5 * time.Minute + defaultLockoutDuration = 10 * time.Minute + minAdminPasswordLength = 8 +) + +var usernamePattern = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`) + +type Config struct { + Credentials *credential.StaticCredentialStore + UserStore credential.PersistentStore + AdminStore AdminCredentialStore + TrafficStore traffic.Store + Tracker *traffic.Tracker + CookieSecure bool + MaxLoginAttempts int + LoginWindow time.Duration + LockoutDuration time.Duration + AllowedOrigins []string + Logger *zerolog.Logger +} + +type Server struct { + config *Config + tmpl *template.Template + sessions map[string]session + logins map[string]loginAttempt + admin adminCredential + mu sync.Mutex +} + +type adminCredential struct { + Username string + PasswordHash string +} + +type session struct { + ExpiresAt time.Time + CSRFToken string +} + +type loginAttempt struct { + Count int + FirstFailed time.Time + LockedUntil time.Time +} + +type usersViewData struct { + Error string + Success string + GeneratedUsername string + GeneratedPassword string + CSRFToken string + ProxyUsers []proxyUserView + TotalUsers int +} + +type setupViewData struct { + Error string +} + +type proxyUserView struct { + Username string + ActiveClients int + ClientIP string + UploadRate string + DownloadRate string + UploadTotal string + DownloadTotal string + Status string + StartedAgo string +} + +func New(conf *Config) *Server { + if conf.Logger == nil { + logger := zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339}).With().Timestamp().Logger() + conf.Logger = &logger + } + + if conf.MaxLoginAttempts <= 0 { + conf.MaxLoginAttempts = defaultMaxLoginAttempts + } + + if conf.LoginWindow <= 0 { + conf.LoginWindow = defaultLoginWindow + } + + if conf.LockoutDuration <= 0 { + conf.LockoutDuration = defaultLockoutDuration + } + + if conf.Credentials == nil { + conf.Credentials = credential.NewStaticCredentialStore() + } + + conf.AllowedOrigins = normalizeAllowedOrigins(conf.AllowedOrigins) + + tmpl := template.Must(template.ParseFS(templatesFS, "templates/*.gohtml")) + + server := &Server{ + config: conf, + tmpl: tmpl, + sessions: make(map[string]session), + logins: make(map[string]loginAttempt), + } + + if conf.AdminStore != nil { + username, passwordHash, found, err := conf.AdminStore.Load() + if err != nil { + conf.Logger.Warn().Err(err).Msg("failed to load admin credentials from store") + } else if found { + server.admin = adminCredential{Username: username, PasswordHash: passwordHash} + } + } + + return server +} + +func (s *Server) Handler() http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/", s.handleRoot) + mux.HandleFunc("/admin", s.handleIndex) + mux.HandleFunc("/admin/setup", s.handleSetup) + mux.HandleFunc("/admin/login", s.handleLogin) + mux.HandleFunc("/admin/logout", s.handleLogout) + mux.HandleFunc("/admin/users", s.handleUsers) + mux.HandleFunc("/admin/users/rows", s.handleUserRows) + mux.HandleFunc("/admin/users/", s.handleUserByName) + return s.withSecurityHeaders(mux) +} + +func (s *Server) withSecurityHeaders(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("Referrer-Policy", "no-referrer") + w.Header().Set("Cache-Control", "no-store") + next.ServeHTTP(w, r) + }) +} + +func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + + http.Redirect(w, r, "/admin", http.StatusSeeOther) +} + +func (s *Server) handleIndex(w http.ResponseWriter, r *http.Request) { + if !s.hasConfiguredAdminCredentials() { + http.Redirect(w, r, "/admin/setup", http.StatusSeeOther) + return + } + + if s.isAuthenticated(r) { + http.Redirect(w, r, "/admin/users", http.StatusSeeOther) + return + } + http.Redirect(w, r, "/admin/login", http.StatusSeeOther) +} + +func (s *Server) handleSetup(w http.ResponseWriter, r *http.Request) { + if s.hasConfiguredAdminCredentials() { + if s.isAuthenticated(r) { + http.Redirect(w, r, "/admin/users", http.StatusSeeOther) + return + } + http.Redirect(w, r, "/admin/login", http.StatusSeeOther) + return + } + + switch r.Method { + case http.MethodGet: + s.renderTemplate(w, "setup.gohtml", setupViewData{}, http.StatusOK) + case http.MethodPost: + // Check again if admin already configured (prevent double setup) + if s.hasConfiguredAdminCredentials() { + http.Redirect(w, r, "/admin/login", http.StatusSeeOther) + return + } + + if err := s.verifyOrigin(r); err != nil { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "invalid form", http.StatusBadRequest) + return + } + + username := strings.TrimSpace(r.FormValue("username")) + password := r.FormValue("password") + confirmPassword := r.FormValue("confirm_password") + + if err := validateUsername(username); err != nil { + s.renderTemplate(w, "setup.gohtml", setupViewData{Error: err.Error()}, http.StatusBadRequest) + return + } + + if len(password) < minAdminPasswordLength { + s.renderTemplate(w, "setup.gohtml", setupViewData{Error: fmt.Sprintf("password must be at least %d characters", minAdminPasswordLength)}, http.StatusBadRequest) + return + } + + if subtle.ConstantTimeCompare([]byte(password), []byte(confirmPassword)) != 1 { + s.renderTemplate(w, "setup.gohtml", setupViewData{Error: "password confirmation does not match"}, http.StatusBadRequest) + return + } + + if err := s.bootstrapAdminCredentials(username, password); err != nil { + if strings.Contains(err.Error(), "already configured") { + http.Redirect(w, r, "/admin/login", http.StatusSeeOther) + return + } + s.renderTemplate(w, "setup.gohtml", setupViewData{Error: "failed to create admin credentials"}, http.StatusInternalServerError) + return + } + + if err := s.createSession(w); err != nil { + http.Error(w, "failed to create session", http.StatusInternalServerError) + return + } + + http.Redirect(w, r, "/admin/users", http.StatusSeeOther) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + if !s.hasConfiguredAdminCredentials() { + http.Redirect(w, r, "/admin/setup", http.StatusSeeOther) + return + } + + if s.isAuthenticated(r) { + http.Redirect(w, r, "/admin/users", http.StatusSeeOther) + return + } + s.renderTemplate(w, "login.gohtml", map[string]any{"Error": ""}, http.StatusOK) + case http.MethodPost: + if !s.hasConfiguredAdminCredentials() { + http.Redirect(w, r, "/admin/setup", http.StatusSeeOther) + return + } + + if err := s.verifyOrigin(r); err != nil { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "invalid form", http.StatusBadRequest) + return + } + + clientIP := extractClientIP(r.RemoteAddr) + if s.isLocked(clientIP) { + s.renderTemplate(w, "login.gohtml", map[string]any{"Error": "Too many failed attempts. Try again later."}, http.StatusTooManyRequests) + return + } + + username := strings.TrimSpace(r.FormValue("username")) + password := r.FormValue("password") + if !s.validateAdminCredentials(username, password) { + s.recordFailedLogin(clientIP) + s.renderTemplate(w, "login.gohtml", map[string]any{"Error": "Invalid admin credentials"}, http.StatusUnauthorized) + return + } + s.clearFailedLogins(clientIP) + + if err := s.createSession(w); err != nil { + http.Error(w, "failed to create session", http.StatusInternalServerError) + return + } + + http.Redirect(w, r, "/admin/users", http.StatusSeeOther) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + if err := s.verifyCSRF(r); err != nil { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + + cookie, err := r.Cookie(sessionCookieName) + if err == nil { + s.mu.Lock() + delete(s.sessions, cookie.Value) + s.mu.Unlock() + } + + // #nosec G124 -- Secure is configurable for local HTTP admin use; HttpOnly and SameSite are enforced. + http.SetCookie(w, &http.Cookie{ + Name: sessionCookieName, + Value: "", + Path: "/", + HttpOnly: true, + Secure: s.config.CookieSecure, + SameSite: http.SameSiteStrictMode, + Expires: time.Unix(0, 0), + MaxAge: -1, + }) + + http.Redirect(w, r, "/admin/login", http.StatusSeeOther) +} + +func (s *Server) handleUsers(w http.ResponseWriter, r *http.Request) { + if !s.isAuthenticated(r) { + s.redirectToLogin(w, r) + return + } + + switch r.Method { + case http.MethodGet: + csrfToken, err := s.currentCSRFToken(r) + if err != nil { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + s.renderUsers(w, usersViewData{CSRFToken: csrfToken}, http.StatusOK) + case http.MethodPost: + if err := s.verifyCSRF(r); err != nil { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + rotatedCSRFToken, err := s.rotateCSRFToken(r) + if err != nil { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "invalid form", http.StatusBadRequest) + return + } + + username := strings.TrimSpace(r.FormValue("username")) + if err := validateUsername(username); err != nil { + s.renderUsers(w, usersViewData{Error: err.Error(), CSRFToken: rotatedCSRFToken}, http.StatusBadRequest) + return + } + + if s.config.Credentials.Exists(username) { + s.renderUsers(w, usersViewData{Error: "username already exists", CSRFToken: rotatedCSRFToken}, http.StatusConflict) + return + } + + password, err := generatePassword(20) + if err != nil { + s.renderUsers(w, usersViewData{Error: "failed to generate password", CSRFToken: rotatedCSRFToken}, http.StatusInternalServerError) + return + } + + s.config.Credentials.Add(username, password) + if err := s.persistUsers(); err != nil { + s.config.Credentials.Delete(username) + s.renderUsers(w, usersViewData{Error: "failed to persist users", CSRFToken: rotatedCSRFToken}, http.StatusInternalServerError) + return + } + + s.renderUsers(w, usersViewData{ + Success: "Proxy user created successfully.", + GeneratedUsername: username, + GeneratedPassword: password, + CSRFToken: rotatedCSRFToken, + }, http.StatusOK) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleUserByName(w http.ResponseWriter, r *http.Request) { + if !s.isAuthenticated(r) { + s.redirectToLogin(w, r) + return + } + + relativePath := strings.TrimPrefix(r.URL.Path, "/admin/users/") + cleanPath := path.Clean("/" + relativePath) + segments := strings.Split(strings.TrimPrefix(cleanPath, "/"), "/") + if len(segments) == 0 || segments[0] == "" { + http.Error(w, "username is required", http.StatusBadRequest) + return + } + + username := segments[0] + if username == "" { + http.Error(w, "username is required", http.StatusBadRequest) + return + } + + if r.Method == http.MethodPost && len(segments) == 2 && segments[1] == "reset-password" { + if err := s.verifyCSRF(r); err != nil { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + rotatedCSRFToken, err := s.rotateCSRFToken(r) + if err != nil { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + + password, err := generatePassword(20) + if err != nil { + s.renderUsers(w, usersViewData{Error: "failed to generate password", CSRFToken: rotatedCSRFToken}, http.StatusInternalServerError) + return + } + + previousHash, existed := s.config.Credentials.GetHashed(username) + if !existed { + s.renderUsers(w, usersViewData{Error: "user not found", CSRFToken: rotatedCSRFToken}, http.StatusNotFound) + return + } + + s.config.Credentials.Add(username, password) + if err := s.persistUsers(); err != nil { + s.config.Credentials.SetHashed(username, previousHash) + s.renderUsers(w, usersViewData{Error: "failed to persist users", CSRFToken: rotatedCSRFToken}, http.StatusInternalServerError) + return + } + + s.renderUsers(w, usersViewData{ + Success: "Password reset successfully.", + GeneratedUsername: username, + GeneratedPassword: password, + CSRFToken: rotatedCSRFToken, + }, http.StatusOK) + return + } + + if r.Method == http.MethodPost && len(segments) == 2 && segments[1] == "reset-stats" { + if err := s.verifyCSRF(r); err != nil { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + rotatedCSRFToken, err := s.rotateCSRFToken(r) + if err != nil { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + + if !s.config.Credentials.Exists(username) { + s.renderUsers(w, usersViewData{Error: "user not found", CSRFToken: rotatedCSRFToken}, http.StatusNotFound) + return + } + + s.config.Tracker.ResetUserStats(username) + if err := s.persistTraffic(); err != nil { + s.renderUsers(w, usersViewData{Error: "failed to reset stats", CSRFToken: rotatedCSRFToken}, http.StatusInternalServerError) + return + } + + s.renderUsers(w, usersViewData{ + Success: "Traffic statistics reset successfully.", + CSRFToken: rotatedCSRFToken, + }, http.StatusOK) + return + } + + if r.Method != http.MethodDelete { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + if err := s.verifyCSRF(r); err != nil { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + rotatedCSRFToken, err := s.rotateCSRFToken(r) + if err != nil { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + + passwordHash, existed := s.config.Credentials.GetHashed(username) + s.config.Credentials.Delete(username) + if err := s.persistUsers(); err != nil { + if existed { + s.config.Credentials.SetHashed(username, passwordHash) + } + s.renderUsers(w, usersViewData{Error: "failed to persist users", CSRFToken: rotatedCSRFToken}, http.StatusInternalServerError) + return + } + + s.renderUsers(w, usersViewData{Success: "Proxy user deleted successfully.", CSRFToken: rotatedCSRFToken}, http.StatusOK) +} + +func (s *Server) renderUsers(w http.ResponseWriter, data usersViewData, status int) { + data.ProxyUsers = s.proxyUsersWithTraffic() + data.TotalUsers = len(data.ProxyUsers) + s.renderTemplate(w, "users.gohtml", data, status) +} + +func (s *Server) handleUserRows(w http.ResponseWriter, r *http.Request) { + if !s.isAuthenticated(r) { + s.redirectToLogin(w, r) + return + } + + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + rows := s.proxyUsersWithTraffic() + s.renderTemplate(w, "user_rows.gohtml", usersViewData{ProxyUsers: rows, TotalUsers: len(rows)}, http.StatusOK) +} + +func (s *Server) proxyUsersWithTraffic() []proxyUserView { + usernames := s.config.Credentials.ListUsers() + now := time.Now() + rows := make([]proxyUserView, 0, len(usernames)) + byUser := make(map[string]*proxyUserView, len(usernames)) + startedByUser := make(map[string]time.Time, len(usernames)) + ipSetByUser := make(map[string]map[string]struct{}, len(usernames)) + totalByUser := make(map[string]traffic.UserTotals) + + for _, username := range usernames { + row := proxyUserView{ + Username: username, + ActiveClients: 0, + ClientIP: "-", + UploadRate: "0 B/s", + DownloadRate: "0 B/s", + UploadTotal: "0 B", + DownloadTotal: "0 B", + Status: statusOffline, + StartedAgo: "-", + } + rows = append(rows, row) + byUser[username] = &rows[len(rows)-1] + ipSetByUser[username] = make(map[string]struct{}) + } + + if s.config.Tracker == nil { + return rows + } + totalByUser = s.config.Tracker.TotalsByUser() + + for _, item := range s.config.Tracker.Snapshot() { + row, ok := byUser[item.Username] + if !ok { + continue + } + + row.ActiveClients++ + row.Status = statusActive + + if item.ClientIP != "" { + ipSetByUser[item.Username][item.ClientIP] = struct{}{} + } + + if startedByUser[item.Username].IsZero() || item.StartedAt.Before(startedByUser[item.Username]) { + startedByUser[item.Username] = item.StartedAt + } + } + + for i := range rows { + totals := totalByUser[rows[i].Username] + + if rows[i].ActiveClients == 0 { + if !totals.LastSeenAt.IsZero() { + if now.Sub(totals.LastSeenAt) <= offlineInactivityWindow { + rows[i].Status = statusActive + } + if totals.LastClientIP != "" { + rows[i].ClientIP = totals.LastClientIP + } + rows[i].StartedAgo = formatStartedAgo(totals.LastSeenAt) + } + continue + } + + ips := make([]string, 0, len(ipSetByUser[rows[i].Username])) + for ip := range ipSetByUser[rows[i].Username] { + ips = append(ips, ip) + } + sort.Strings(ips) + rows[i].ClientIP = strings.Join(ips, ", ") + rows[i].StartedAgo = formatStartedAgo(startedByUser[rows[i].Username]) + } + + for i := range rows { + totals := totalByUser[rows[i].Username] + rows[i].UploadRate = formatByteRate(totals.UploadBPS) + rows[i].DownloadRate = formatByteRate(totals.DownloadBPS) + rows[i].UploadTotal = formatBytes(totals.UploadBytes) + rows[i].DownloadTotal = formatBytes(totals.DownloadBytes) + } + + return rows +} + +func formatByteRate(n uint64) string { + return fmt.Sprintf("%s/s", formatBytes(n)) +} + +func formatBytes(n uint64) string { + const unit = 1024 + if n < unit { + return fmt.Sprintf("%d B", n) + } + div, exp := uint64(unit), 0 + for value := n / unit; value >= unit && exp < 5; value /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.2f %cB", float64(n)/float64(div), "KMGTPE"[exp]) +} + +func formatStartedAgo(startedAt time.Time) string { + if startedAt.IsZero() { + return "unknown" + } + d := time.Since(startedAt) + if d < time.Minute { + return "just now" + } + if d < time.Hour { + return fmt.Sprintf("%d minutes ago", int(d.Minutes())) + } + if d < 24*time.Hour { + return fmt.Sprintf("%d hours ago", int(d.Hours())) + } + return fmt.Sprintf("%d days ago", int(d.Hours()/24)) +} + +func (s *Server) renderTemplate(w http.ResponseWriter, name string, data any, status int) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(status) + if err := s.tmpl.ExecuteTemplate(w, name, data); err != nil { + s.config.Logger.Error().Err(err).Msg("failed to render template") + } +} + +func (s *Server) redirectToLogin(w http.ResponseWriter, r *http.Request) { + if strings.EqualFold(r.Header.Get("HX-Request"), "true") { + w.Header().Set("HX-Redirect", "/admin/login") + w.WriteHeader(http.StatusUnauthorized) + return + } + + http.Redirect(w, r, "/admin/login", http.StatusSeeOther) +} + +func (s *Server) persistUsers() error { + if s.config.UserStore == nil { + return nil + } + + return s.config.UserStore.Save(s.config.Credentials.Snapshot()) +} + +func (s *Server) persistTraffic() error { + if s.config.TrafficStore == nil || s.config.Tracker == nil { + return nil + } + + return s.config.TrafficStore.SaveTraffic(s.config.Tracker.TotalsByUser()) +} + +func (s *Server) validateAdminCredentials(username, password string) bool { + storedUsername, storedPasswordHash, found := s.currentStoredAdminCredentials() + if !found { + return false + } + + if subtle.ConstantTimeCompare([]byte(username), []byte(storedUsername)) != 1 { + return false + } + + return bcrypt.CompareHashAndPassword([]byte(storedPasswordHash), []byte(password)) == nil +} + +func (s *Server) createSession(w http.ResponseWriter) error { + token, err := newRandomToken(32) + if err != nil { + return err + } + + csrfToken, err := newRandomToken(32) + if err != nil { + return err + } + + expiresAt := time.Now().Add(sessionTTL) + s.mu.Lock() + s.sessions[token] = session{ExpiresAt: expiresAt, CSRFToken: csrfToken} + s.mu.Unlock() + + // #nosec G124 -- Secure is configurable for local HTTP admin use; HttpOnly and SameSite are enforced. + http.SetCookie(w, &http.Cookie{ + Name: sessionCookieName, + Value: token, + Path: "/", + HttpOnly: true, + Secure: s.config.CookieSecure, + SameSite: http.SameSiteStrictMode, + Expires: expiresAt, + }) + + return nil +} + +func (s *Server) hasConfiguredAdminCredentials() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.hasConfiguredAdminCredentialsLocked() +} + +func (s *Server) hasConfiguredAdminCredentialsLocked() bool { + return s.admin.Username != "" && s.admin.PasswordHash != "" +} + +func (s *Server) currentStoredAdminCredentials() (string, string, bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.admin.Username == "" || s.admin.PasswordHash == "" { + return "", "", false + } + + return s.admin.Username, s.admin.PasswordHash, true +} + +func (s *Server) bootstrapAdminCredentials(username, password string) error { + if s.config.AdminStore == nil { + return httpError("admin credential store is not configured") + } + + passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + if s.hasConfiguredAdminCredentialsLocked() { + return httpError("admin credentials already configured") + } + + if err := s.config.AdminStore.Save(username, string(passwordHash)); err != nil { + return err + } + + s.admin = adminCredential{Username: username, PasswordHash: string(passwordHash)} + return nil +} + +func (s *Server) isAuthenticated(r *http.Request) bool { + _, ok := s.currentSession(r) + return ok +} + +func (s *Server) currentSession(r *http.Request) (session, bool) { + cookie, err := r.Cookie(sessionCookieName) + if err != nil { + return session{}, false + } + + now := time.Now() + s.mu.Lock() + defer s.mu.Unlock() + + sess, ok := s.sessions[cookie.Value] + if !ok { + return session{}, false + } + if now.After(sess.ExpiresAt) { + delete(s.sessions, cookie.Value) + return session{}, false + } + + return sess, true +} + +func (s *Server) currentCSRFToken(r *http.Request) (string, error) { + sess, ok := s.currentSession(r) + if !ok { + return "", httpError("invalid session") + } + + return sess.CSRFToken, nil +} + +func (s *Server) verifyCSRF(r *http.Request) error { + if err := s.verifyOrigin(r); err != nil { + return err + } + + sess, ok := s.currentSession(r) + if !ok { + return httpError("invalid session") + } + + provided := r.Header.Get("X-CSRF-Token") + if provided == "" { + if err := r.ParseForm(); err == nil { + provided = r.FormValue("_csrf") + } + } + + if provided == "" { + return httpError("csrf token required") + } + + if subtle.ConstantTimeCompare([]byte(provided), []byte(sess.CSRFToken)) != 1 { + return httpError("invalid csrf token") + } + + return nil +} + +func (s *Server) rotateCSRFToken(r *http.Request) (string, error) { + cookie, err := r.Cookie(sessionCookieName) + if err != nil { + return "", httpError("invalid session") + } + + newToken, err := newRandomToken(32) + if err != nil { + return "", err + } + + s.mu.Lock() + defer s.mu.Unlock() + + sess, ok := s.sessions[cookie.Value] + if !ok { + return "", httpError("invalid session") + } + + sess.CSRFToken = newToken + s.sessions[cookie.Value] = sess + + return newToken, nil +} + +func (s *Server) verifyOrigin(r *http.Request) error { + if len(s.config.AllowedOrigins) == 0 { + return nil + } + + originHeader := strings.TrimSpace(r.Header.Get("Origin")) + if originHeader != "" { + origin, err := normalizeOrigin(originHeader) + if err != nil { + return httpError("invalid origin") + } + if s.isAllowedOrigin(origin) { + return nil + } + return httpError("origin not allowed") + } + + refererHeader := strings.TrimSpace(r.Header.Get("Referer")) + if refererHeader != "" { + refererURL, err := url.Parse(refererHeader) + if err != nil || refererURL.Scheme == "" || refererURL.Host == "" { + return httpError("invalid referer") + } + refererOrigin := refererURL.Scheme + "://" + refererURL.Host + if s.isAllowedOrigin(strings.ToLower(refererOrigin)) { + return nil + } + return httpError("referer not allowed") + } + + return httpError("origin or referer required") +} + +func (s *Server) isAllowedOrigin(origin string) bool { + for _, allowed := range s.config.AllowedOrigins { + if subtle.ConstantTimeCompare([]byte(origin), []byte(allowed)) == 1 { + return true + } + } + + return false +} + +func normalizeAllowedOrigins(origins []string) []string { + if len(origins) == 0 { + return nil + } + + normalized := make([]string, 0, len(origins)) + seen := make(map[string]struct{}, len(origins)) + for _, raw := range origins { + origin, err := normalizeOrigin(raw) + if err != nil { + continue + } + if _, ok := seen[origin]; ok { + continue + } + seen[origin] = struct{}{} + normalized = append(normalized, origin) + } + + return normalized +} + +func normalizeOrigin(raw string) (string, error) { + parsed, err := url.Parse(strings.TrimSpace(raw)) + if err != nil { + return "", err + } + if parsed.Scheme == "" || parsed.Host == "" { + return "", httpError("invalid origin") + } + + return strings.ToLower(parsed.Scheme + "://" + parsed.Host), nil +} + +func (s *Server) isLocked(clientIP string) bool { + s.mu.Lock() + defer s.mu.Unlock() + + attempt, ok := s.logins[clientIP] + if !ok { + return false + } + + if attempt.LockedUntil.IsZero() { + return false + } + + if time.Now().After(attempt.LockedUntil) { + delete(s.logins, clientIP) + return false + } + + return true +} + +func (s *Server) recordFailedLogin(clientIP string) { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + attempt := s.logins[clientIP] + + if attempt.FirstFailed.IsZero() || now.Sub(attempt.FirstFailed) > s.config.LoginWindow { + attempt.FirstFailed = now + attempt.Count = 0 + attempt.LockedUntil = time.Time{} + } + + attempt.Count++ + if attempt.Count >= s.config.MaxLoginAttempts { + attempt.LockedUntil = now.Add(s.config.LockoutDuration) + } + + s.logins[clientIP] = attempt +} + +func (s *Server) clearFailedLogins(clientIP string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.logins, clientIP) +} + +func extractClientIP(remoteAddr string) string { + host, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return remoteAddr + } + return host +} + +func newRandomToken(byteLength int) (string, error) { + buf := make([]byte, byteLength) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +func generatePassword(length int) (string, error) { + const charset = "ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz23456789" + if length <= 0 { + return "", httpError("password length must be greater than zero") + } + + bytes := make([]byte, length) + if _, err := io.ReadFull(rand.Reader, bytes); err != nil { + return "", err + } + + var builder strings.Builder + builder.Grow(length) + for _, value := range bytes { + builder.WriteByte(charset[int(value)%len(charset)]) + } + + return builder.String(), nil +} + +func validateUsername(username string) error { + if username == "" { + return httpError("username is required") + } + if len(username) < minUsernameLength || len(username) > maxUsernameLength { + return httpError("username must be between 3 and 64 characters") + } + if strings.Contains(username, ":") { + return httpError("username cannot contain ':'") + } + if !usernamePattern.MatchString(username) { + return httpError("username may only contain letters, numbers, dots, underscores, and hyphens") + } + return nil +} + +type httpError string + +func (e httpError) Error() string { + return string(e) +} diff --git a/pkg/admin/server_test.go b/pkg/admin/server_test.go new file mode 100644 index 0000000..6ba7fca --- /dev/null +++ b/pkg/admin/server_test.go @@ -0,0 +1,1265 @@ +package admin + +import ( + "html" + "io" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "net/url" + "path/filepath" + "regexp" + "strings" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/ryanbekhen/nanoproxy/pkg/credential" + "github.com/ryanbekhen/nanoproxy/pkg/traffic" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/bcrypt" +) + +func newSeededAdminStore(t *testing.T, username, password string) AdminCredentialStore { + t.Helper() + + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + require.NoError(t, err) + + store := NewBoltAdminStore(filepath.Join(t.TempDir(), "admin.db")) + require.NoError(t, store.Save(username, string(hash))) + + return store +} + +func TestServer_LoginAndUserManagement(t *testing.T) { + logger := zerolog.New(io.Discard) + credentials := credential.NewStaticCredentialStore() + userStore := credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")) + + s := New(&Config{ + Credentials: credentials, + UserStore: userStore, + AdminStore: newSeededAdminStore(t, "admin", "secret"), + Logger: &logger, + }) + + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + jar, err := cookiejar.New(nil) + assert.NoError(t, err) + + client := &http.Client{Jar: jar} + + resp, err := client.Get(ts.URL + "/") + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "/admin/login", resp.Request.URL.Path) + _ = resp.Body.Close() + + resp, err = client.Get(ts.URL + "/admin/users") + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "/admin/login", resp.Request.URL.Path) + _ = resp.Body.Close() + + resp, err = client.PostForm(ts.URL+"/admin/login", url.Values{ + "username": {"admin"}, + "password": {"wrong"}, + }) + assert.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + _ = resp.Body.Close() + + resp, err = client.PostForm(ts.URL+"/admin/login", url.Values{ + "username": {"admin"}, + "password": {"secret"}, + }) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "/admin/users", resp.Request.URL.Path) + _ = resp.Body.Close() + + resp, err = client.Get(ts.URL + "/admin/users") + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + _ = resp.Body.Close() + + csrfToken := extractCSRFToken(t, string(body)) + + resp, err = client.PostForm(ts.URL+"/admin/users", url.Values{"username": {"proxyuser"}, "_csrf": {csrfToken}}) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + body, err = io.ReadAll(resp.Body) + assert.NoError(t, err) + _ = resp.Body.Close() + csrfToken = extractCSRFToken(t, string(body)) + + createdPassword := extractGeneratedPassword(t, string(body)) + assert.NotEmpty(t, createdPassword) + + assert.True(t, credentials.Valid("proxyuser", createdPassword)) + + restartedStore := credential.NewStaticCredentialStore() + assert.NoError(t, credential.LoadInto(userStore, restartedStore)) + assert.True(t, restartedStore.Valid("proxyuser", createdPassword)) + + resp, err = client.PostForm(ts.URL+"/admin/users", url.Values{ + "username": {"proxyuser"}, + "_csrf": {csrfToken}, + }) + assert.NoError(t, err) + assert.Equal(t, http.StatusConflict, resp.StatusCode) + body, err = io.ReadAll(resp.Body) + assert.NoError(t, err) + _ = resp.Body.Close() + csrfToken = extractCSRFToken(t, string(body)) + + resp, err = client.PostForm(ts.URL+"/admin/users/proxyuser/reset-password", url.Values{"_csrf": {csrfToken}}) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + body, err = io.ReadAll(resp.Body) + assert.NoError(t, err) + _ = resp.Body.Close() + csrfToken = extractCSRFToken(t, string(body)) + + resetPassword := extractGeneratedPassword(t, string(body)) + assert.NotEmpty(t, resetPassword) + assert.NotEqual(t, createdPassword, resetPassword) + assert.False(t, credentials.Valid("proxyuser", createdPassword)) + assert.True(t, credentials.Valid("proxyuser", resetPassword)) + assert.NoError(t, credential.LoadInto(userStore, restartedStore)) + assert.True(t, restartedStore.Valid("proxyuser", resetPassword)) + + req, err := http.NewRequest(http.MethodDelete, ts.URL+"/admin/users/proxyuser", nil) + assert.NoError(t, err) + req.Header.Set("X-CSRF-Token", csrfToken) + resp, err = client.Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + _ = resp.Body.Close() + + assert.False(t, credentials.Valid("proxyuser", resetPassword)) + assert.NoError(t, credential.LoadInto(userStore, restartedStore)) + assert.False(t, restartedStore.Valid("proxyuser", resetPassword)) +} + +func TestValidateUsername(t *testing.T) { + assert.Error(t, validateUsername("")) + assert.Error(t, validateUsername("ab")) + assert.Error(t, validateUsername("foo:bar")) + assert.Error(t, validateUsername("bad user")) + assert.Error(t, validateUsername("bad/user")) + assert.Error(t, validateUsername(strings.Repeat("a", 65))) + assert.NoError(t, validateUsername("foo")) + assert.NoError(t, validateUsername("user.name-01")) +} + +func TestServer_LoginRequiresConfiguredCredentials(t *testing.T) { + logger := zerolog.New(io.Discard) + credentials := credential.NewStaticCredentialStore() + + s := New(&Config{ + Credentials: credentials, + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: newSeededAdminStore(t, "admin", "secret"), + Logger: &logger, + }) + + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + resp, err := http.Post(ts.URL+"/admin/login", "application/x-www-form-urlencoded", strings.NewReader("username=bad&password=bad")) + assert.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_SetupFlow_CreatesAdminAndPersistsAcrossRestart(t *testing.T) { + logger := zerolog.New(io.Discard) + credentials := credential.NewStaticCredentialStore() + dbPath := filepath.Join(t.TempDir(), "data.db") + + s := New(&Config{ + Credentials: credentials, + UserStore: credential.NewBoltStore(dbPath), + AdminStore: NewBoltAdminStore(dbPath), + Logger: &logger, + }) + + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + jar, err := cookiejar.New(nil) + require.NoError(t, err) + client := &http.Client{Jar: jar} + + resp, err := client.Get(ts.URL + "/admin/login") + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "/admin/setup", resp.Request.URL.Path) + _ = resp.Body.Close() + + resp, err = client.PostForm(ts.URL+"/admin/setup", url.Values{ + "username": {"admin"}, + "password": {"super-secret"}, + "confirm_password": {"super-secret"}, + }) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "/admin/users", resp.Request.URL.Path) + _ = resp.Body.Close() + + restarted := New(&Config{ + Credentials: credential.NewStaticCredentialStore(), + UserStore: credential.NewBoltStore(dbPath), + AdminStore: NewBoltAdminStore(dbPath), + Logger: &logger, + }) + + tsRestarted := httptest.NewServer(restarted.Handler()) + defer tsRestarted.Close() + + noFollow := &http.Client{CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }} + + resp, err = noFollow.Post(tsRestarted.URL+"/admin/login", "application/x-www-form-urlencoded", strings.NewReader("username=admin&password=super-secret")) + require.NoError(t, err) + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) + assert.Equal(t, "/admin/users", resp.Header.Get("Location")) + _ = resp.Body.Close() +} + +func TestServer_SetupFlow_RejectsInvalidInput(t *testing.T) { + logger := zerolog.New(io.Discard) + s := New(&Config{ + Credentials: credential.NewStaticCredentialStore(), + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: NewBoltAdminStore(filepath.Join(t.TempDir(), "admin.db")), + Logger: &logger, + }) + + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + resp, err := http.PostForm(ts.URL+"/admin/setup", url.Values{ + "username": {"ad"}, + "password": {"short"}, + "confirm_password": {"short"}, + }) + require.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_SetupFlow_PasswordMismatch(t *testing.T) { + logger := zerolog.New(io.Discard) + s := New(&Config{ + Credentials: credential.NewStaticCredentialStore(), + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: NewBoltAdminStore(filepath.Join(t.TempDir(), "admin.db")), + Logger: &logger, + }) + + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + resp, err := http.PostForm(ts.URL+"/admin/setup", url.Values{ + "username": {"admin"}, + "password": {"password123"}, + "confirm_password": {"password456"}, + }) + require.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_SetupFlow_InvalidUsername(t *testing.T) { + logger := zerolog.New(io.Discard) + s := New(&Config{ + Credentials: credential.NewStaticCredentialStore(), + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: NewBoltAdminStore(filepath.Join(t.TempDir(), "admin.db")), + Logger: &logger, + }) + + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + testCases := []string{ + "a", // too short + "admin:user", // colon not allowed + "admin user", // space not allowed + "admin/user", // slash not allowed + strings.Repeat("a", 65), // too long + } + + for _, username := range testCases { + resp, err := http.PostForm(ts.URL+"/admin/setup", url.Values{ + "username": {username}, + "password": {"validpass123"}, + "confirm_password": {"validpass123"}, + }) + require.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode, "username %q should be rejected", username) + _ = resp.Body.Close() + } +} + +func TestServer_SetupFlow_ShortPassword(t *testing.T) { + logger := zerolog.New(io.Discard) + s := New(&Config{ + Credentials: credential.NewStaticCredentialStore(), + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: NewBoltAdminStore(filepath.Join(t.TempDir(), "admin.db")), + Logger: &logger, + }) + + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + resp, err := http.PostForm(ts.URL+"/admin/setup", url.Values{ + "username": {"admin"}, + "password": {"1234567"}, + "confirm_password": {"1234567"}, + }) + require.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_SetupFlow_PreventDoubleSetup(t *testing.T) { + logger := zerolog.New(io.Discard) + credentials := credential.NewStaticCredentialStore() + dbPath := filepath.Join(t.TempDir(), "data.db") + + s := New(&Config{ + Credentials: credentials, + UserStore: credential.NewBoltStore(dbPath), + AdminStore: NewBoltAdminStore(dbPath), + Logger: &logger, + }) + + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + jar, err := cookiejar.New(nil) + require.NoError(t, err) + client := &http.Client{Jar: jar} + + // First setup + resp, err := client.PostForm(ts.URL+"/admin/setup", url.Values{ + "username": {"admin"}, + "password": {"password123"}, + "confirm_password": {"password123"}, + }) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + _ = resp.Body.Close() + + // Second setup attempt should redirect to users (already configured and authenticated) + // Use non-following client to catch the redirect + noFollow := &http.Client{ + Jar: jar, + CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + resp, err = noFollow.PostForm(ts.URL+"/admin/setup", url.Values{ + "username": {"attacker"}, + "password": {"password123"}, + "confirm_password": {"password123"}, + }) + require.NoError(t, err) + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) + assert.Equal(t, "/admin/users", resp.Header.Get("Location")) + _ = resp.Body.Close() +} + +func TestServer_SetupFlow_GetShowsForm(t *testing.T) { + logger := zerolog.New(io.Discard) + s := New(&Config{ + Credentials: credential.NewStaticCredentialStore(), + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: NewBoltAdminStore(filepath.Join(t.TempDir(), "admin.db")), + Logger: &logger, + }) + + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/admin/setup") + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + _ = resp.Body.Close() + + // Check form fields are present + bodyStr := string(body) + assert.Contains(t, bodyStr, "username") + assert.Contains(t, bodyStr, "password") + assert.Contains(t, bodyStr, "confirm_password") +} + +func TestServer_SetupFlow_RedirectWhenAlreadyConfigured(t *testing.T) { + logger := zerolog.New(io.Discard) + s := New(&Config{ + Credentials: credential.NewStaticCredentialStore(), + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: newSeededAdminStore(t, "admin", "secret"), + Logger: &logger, + }) + + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + noFollow := &http.Client{CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }} + + // GET /admin/setup when admin exists should redirect to login + resp, err := noFollow.Get(ts.URL + "/admin/setup") + require.NoError(t, err) + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) + assert.Equal(t, "/admin/login", resp.Header.Get("Location")) + _ = resp.Body.Close() +} + +func TestServer_LoginRedirectsToSetupWhenNoAdmin(t *testing.T) { + logger := zerolog.New(io.Discard) + s := New(&Config{ + Credentials: credential.NewStaticCredentialStore(), + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: NewBoltAdminStore(filepath.Join(t.TempDir(), "admin.db")), + Logger: &logger, + }) + + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + noFollow := &http.Client{CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }} + + resp, err := noFollow.Get(ts.URL + "/admin/login") + require.NoError(t, err) + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) + assert.Equal(t, "/admin/setup", resp.Header.Get("Location")) + _ = resp.Body.Close() +} + +func TestServer_IndexRedirectsToSetupWhenNoAdmin(t *testing.T) { + logger := zerolog.New(io.Discard) + s := New(&Config{ + Credentials: credential.NewStaticCredentialStore(), + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: NewBoltAdminStore(filepath.Join(t.TempDir(), "admin.db")), + Logger: &logger, + }) + + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + noFollow := &http.Client{CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }} + + resp, err := noFollow.Get(ts.URL + "/admin") + require.NoError(t, err) + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) + assert.Equal(t, "/admin/setup", resp.Header.Get("Location")) + _ = resp.Body.Close() +} + +func TestGeneratePassword(t *testing.T) { + password, err := generatePassword(20) + assert.NoError(t, err) + assert.Len(t, password, 20) + + _, err = generatePassword(0) + assert.Error(t, err) +} + +func TestServer_CreateUserRejectsInvalidUsername(t *testing.T) { + logger := zerolog.New(io.Discard) + credentials := credential.NewStaticCredentialStore() + + s := New(&Config{ + Credentials: credentials, + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: newSeededAdminStore(t, "admin", "secret"), + Logger: &logger, + }) + + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + jar, err := cookiejar.New(nil) + assert.NoError(t, err) + client := &http.Client{Jar: jar} + + resp, err := client.PostForm(ts.URL+"/admin/login", url.Values{ + "username": {"admin"}, + "password": {"secret"}, + }) + assert.NoError(t, err) + _ = resp.Body.Close() + + resp, err = client.Get(ts.URL + "/admin/users") + assert.NoError(t, err) + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + _ = resp.Body.Close() + + csrfToken := extractCSRFToken(t, string(body)) + + resp, err = client.PostForm(ts.URL+"/admin/users", url.Values{"username": {"bad user"}, "_csrf": {csrfToken}}) + assert.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + body, err = io.ReadAll(resp.Body) + assert.NoError(t, err) + _ = resp.Body.Close() + + assert.Contains(t, string(body), "username may only contain letters, numbers, dots, underscores, and hyphens") + assert.False(t, credentials.Exists("bad user")) +} + +func TestServer_StateChangingRoutesRequireCSRF(t *testing.T) { + logger := zerolog.New(io.Discard) + credentials := credential.NewStaticCredentialStore() + + s := New(&Config{ + Credentials: credentials, + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: newSeededAdminStore(t, "admin", "secret"), + Logger: &logger, + }) + + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + jar, err := cookiejar.New(nil) + assert.NoError(t, err) + client := &http.Client{Jar: jar} + + resp, err := client.PostForm(ts.URL+"/admin/login", url.Values{"username": {"admin"}, "password": {"secret"}}) + assert.NoError(t, err) + _ = resp.Body.Close() + + resp, err = client.PostForm(ts.URL+"/admin/users", url.Values{"username": {"proxyuser"}}) + assert.NoError(t, err) + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_CSRFTokenRotationRejectsOldToken(t *testing.T) { + logger := zerolog.New(io.Discard) + credentials := credential.NewStaticCredentialStore() + + s := New(&Config{ + Credentials: credentials, + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: newSeededAdminStore(t, "admin", "secret"), + Logger: &logger, + }) + + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + jar, err := cookiejar.New(nil) + assert.NoError(t, err) + client := &http.Client{Jar: jar} + + resp, err := client.PostForm(ts.URL+"/admin/login", url.Values{"username": {"admin"}, "password": {"secret"}}) + assert.NoError(t, err) + _ = resp.Body.Close() + + resp, err = client.Get(ts.URL + "/admin/users") + assert.NoError(t, err) + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + _ = resp.Body.Close() + + firstToken := extractCSRFToken(t, string(body)) + + resp, err = client.PostForm(ts.URL+"/admin/users", url.Values{"username": {"userone"}, "_csrf": {firstToken}}) + assert.NoError(t, err) + body, err = io.ReadAll(resp.Body) + assert.NoError(t, err) + _ = resp.Body.Close() + + rotatedToken := extractCSRFToken(t, string(body)) + assert.NotEqual(t, firstToken, rotatedToken) + + resp, err = client.PostForm(ts.URL+"/admin/users", url.Values{"username": {"usertwo"}, "_csrf": {firstToken}}) + assert.NoError(t, err) + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_OriginPolicy(t *testing.T) { + logger := zerolog.New(io.Discard) + credentials := credential.NewStaticCredentialStore() + + s := New(&Config{ + Credentials: credentials, + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: newSeededAdminStore(t, "admin", "secret"), + AllowedOrigins: []string{"http://allowed.local"}, + Logger: &logger, + }) + + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + resp, err := http.Post(ts.URL+"/admin/login", "application/x-www-form-urlencoded", strings.NewReader("username=admin&password=secret")) + assert.NoError(t, err) + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + _ = resp.Body.Close() + + req, err := http.NewRequest(http.MethodPost, ts.URL+"/admin/login", strings.NewReader("username=admin&password=secret")) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Origin", "http://allowed.local") + + client := &http.Client{} + resp, err = client.Do(req) + assert.NoError(t, err) + assert.NotEqual(t, http.StatusForbidden, resp.StatusCode) + _ = resp.Body.Close() +} + +// loginHelper logs the given admin in and returns a client with a valid session +// plus the current CSRF token extracted from /admin/users. +func loginHelper(t *testing.T, baseURL string) (*http.Client, string) { + t.Helper() + + jar, err := cookiejar.New(nil) + require.NoError(t, err) + client := &http.Client{Jar: jar} + + resp, err := client.PostForm(baseURL+"/admin/login", url.Values{ + "username": {"admin"}, + "password": {"secret"}, + }) + require.NoError(t, err) + _ = resp.Body.Close() + + resp, err = client.Get(baseURL + "/admin/users") + require.NoError(t, err) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + _ = resp.Body.Close() + + return client, extractCSRFToken(t, string(body)) +} + +// newAdminServer creates a Server with an in-memory credential store and a temp-dir BoltDB store. +func newAdminServer(t *testing.T) (*Server, *httptest.Server) { + t.Helper() + logger := zerolog.New(io.Discard) + creds := credential.NewStaticCredentialStore() + s := New(&Config{ + Credentials: creds, + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: newSeededAdminStore(t, "admin", "secret"), + Logger: &logger, + }) + ts := httptest.NewServer(s.Handler()) + t.Cleanup(ts.Close) + return s, ts +} + +func TestServer_NilLogger(t *testing.T) { + creds := credential.NewStaticCredentialStore() + s := New(&Config{ + Credentials: creds, + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: newSeededAdminStore(t, "admin", "secret"), + // Logger intentionally nil — default logger must be created + }) + assert.NotNil(t, s) + assert.NotNil(t, s.config.Logger) +} + +func TestServer_Root_NotFound(t *testing.T) { + _, ts := newAdminServer(t) + + resp, err := http.Get(ts.URL + "/does-not-exist") + assert.NoError(t, err) + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_Root_Redirect(t *testing.T) { + _, ts := newAdminServer(t) + + noFollow := &http.Client{CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }} + + resp, err := noFollow.Get(ts.URL + "/") + assert.NoError(t, err) + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) + assert.Equal(t, "/admin", resp.Header.Get("Location")) + _ = resp.Body.Close() +} + +func TestServer_Index_Unauthenticated(t *testing.T) { + _, ts := newAdminServer(t) + + jar, _ := cookiejar.New(nil) + client := &http.Client{Jar: jar} + + resp, err := client.Get(ts.URL + "/admin") + assert.NoError(t, err) + assert.Equal(t, "/admin/login", resp.Request.URL.Path) + _ = resp.Body.Close() +} + +func TestServer_Index_Authenticated(t *testing.T) { + _, ts := newAdminServer(t) + client, _ := loginHelper(t, ts.URL) + + noFollow := &http.Client{ + Jar: client.Jar, + CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }, + } + resp, err := noFollow.Get(ts.URL + "/admin") + assert.NoError(t, err) + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) + assert.Equal(t, "/admin/users", resp.Header.Get("Location")) + _ = resp.Body.Close() +} + +func TestServer_LoginPage_GetShowsForm(t *testing.T) { + _, ts := newAdminServer(t) + + resp, err := http.Get(ts.URL + "/admin/login") + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_LoginPage_GetRedirectsWhenAuthenticated(t *testing.T) { + _, ts := newAdminServer(t) + client, _ := loginHelper(t, ts.URL) + + noFollow := &http.Client{ + Jar: client.Jar, + CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }, + } + resp, err := noFollow.Get(ts.URL + "/admin/login") + assert.NoError(t, err) + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) + assert.Equal(t, "/admin/users", resp.Header.Get("Location")) + _ = resp.Body.Close() +} + +func TestServer_Login_MethodNotAllowed(t *testing.T) { + _, ts := newAdminServer(t) + + req, err := http.NewRequest(http.MethodPut, ts.URL+"/admin/login", nil) + assert.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_Logout_MethodNotAllowed(t *testing.T) { + _, ts := newAdminServer(t) + + resp, err := http.Get(ts.URL + "/admin/logout") + assert.NoError(t, err) + assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_Logout_InvalidCSRF(t *testing.T) { + _, ts := newAdminServer(t) + + resp, err := http.Post(ts.URL+"/admin/logout", "application/x-www-form-urlencoded", + strings.NewReader("_csrf=bad-token")) + assert.NoError(t, err) + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_Logout_Valid(t *testing.T) { + _, ts := newAdminServer(t) + client, csrfToken := loginHelper(t, ts.URL) + + // POST logout with valid CSRF token + resp, err := client.PostForm(ts.URL+"/admin/logout", url.Values{"_csrf": {csrfToken}}) + assert.NoError(t, err) + assert.Equal(t, "/admin/login", resp.Request.URL.Path) + _ = resp.Body.Close() + + // After logout, /admin/users must redirect to login + resp, err = client.Get(ts.URL + "/admin/users") + assert.NoError(t, err) + assert.Equal(t, "/admin/login", resp.Request.URL.Path) + _ = resp.Body.Close() +} + +func TestServer_RateLimiting(t *testing.T) { + logger := zerolog.New(io.Discard) + creds := credential.NewStaticCredentialStore() + s := New(&Config{ + Credentials: creds, + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: newSeededAdminStore(t, "admin", "secret"), + MaxLoginAttempts: 3, + LoginWindow: time.Minute, + LockoutDuration: 50 * time.Millisecond, + Logger: &logger, + }) + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + // Exhaust the login attempts + for i := 0; i < 3; i++ { + resp, err := http.Post(ts.URL+"/admin/login", "application/x-www-form-urlencoded", + strings.NewReader("username=admin&password=wrong")) + assert.NoError(t, err) + _ = resp.Body.Close() + } + + // Next attempt must be rejected with 429 + resp, err := http.Post(ts.URL+"/admin/login", "application/x-www-form-urlencoded", + strings.NewReader("username=admin&password=secret")) + assert.NoError(t, err) + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + _ = resp.Body.Close() + + // Wait for lockout to expire + time.Sleep(100 * time.Millisecond) + + // Now a valid login should succeed again + jar, _ := cookiejar.New(nil) + client := &http.Client{Jar: jar} + resp, err = client.PostForm(ts.URL+"/admin/login", url.Values{ + "username": {"admin"}, + "password": {"secret"}, + }) + assert.NoError(t, err) + assert.Equal(t, "/admin/users", resp.Request.URL.Path) + _ = resp.Body.Close() +} + +func TestServer_Users_MethodNotAllowed(t *testing.T) { + _, ts := newAdminServer(t) + client, _ := loginHelper(t, ts.URL) + + req, err := http.NewRequest(http.MethodPut, ts.URL+"/admin/users", nil) + assert.NoError(t, err) + resp, err := client.Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_UserByName_EmptyUsername(t *testing.T) { + _, ts := newAdminServer(t) + client, _ := loginHelper(t, ts.URL) + + resp, err := client.Get(ts.URL + "/admin/users/") + assert.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_UserByName_MethodNotAllowed(t *testing.T) { + _, ts := newAdminServer(t) + client, _ := loginHelper(t, ts.URL) + + resp, err := client.Get(ts.URL + "/admin/users/someuser") + assert.NoError(t, err) + assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_UserByName_UnauthenticatedRedirectsToLogin(t *testing.T) { + _, ts := newAdminServer(t) + + jar, _ := cookiejar.New(nil) + client := &http.Client{Jar: jar} + + resp, err := client.Get(ts.URL + "/admin/users/anyuser") + assert.NoError(t, err) + assert.Equal(t, "/admin/login", resp.Request.URL.Path) + _ = resp.Body.Close() +} + +func TestServer_UserRows_UnauthenticatedHTMXRedirectsToLogin(t *testing.T) { + _, ts := newAdminServer(t) + + req, err := http.NewRequest(http.MethodGet, ts.URL+"/admin/users/rows", nil) + require.NoError(t, err) + req.Header.Set("HX-Request", "true") + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.Equal(t, "/admin/login", resp.Header.Get("HX-Redirect")) +} + +func TestServer_ResetPassword_UserNotFound(t *testing.T) { + _, ts := newAdminServer(t) + client, csrfToken := loginHelper(t, ts.URL) + + resp, err := client.PostForm(ts.URL+"/admin/users/nonexistent/reset-password", + url.Values{"_csrf": {csrfToken}}) + assert.NoError(t, err) + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_ProxyUsersWithTraffic_RecentSessionStillActive(t *testing.T) { + logger := zerolog.New(io.Discard) + creds := credential.NewStaticCredentialStore() + creds.Add("alice", "password") + + tracker := traffic.NewTracker() + sess := tracker.Start("alice", "10.0.0.2") + sess.AddUpload(256) + sess.AddDownload(512) + sess.Close() + + s := New(&Config{ + Credentials: creds, + Tracker: tracker, + AdminStore: newSeededAdminStore(t, "admin", "secret"), + Logger: &logger, + }) + + rows := s.proxyUsersWithTraffic() + require.Len(t, rows, 1) + assert.Equal(t, "alice", rows[0].Username) + assert.Equal(t, "Active", rows[0].Status) + assert.Equal(t, 0, rows[0].ActiveClients) + assert.Equal(t, "10.0.0.2", rows[0].ClientIP) + assert.NotEqual(t, "0 B", rows[0].DownloadTotal) + assert.NotEqual(t, "0 B", rows[0].UploadTotal) +} + +func TestServer_NilUserStore_CreateUser(t *testing.T) { + logger := zerolog.New(io.Discard) + creds := credential.NewStaticCredentialStore() + s := New(&Config{ + Credentials: creds, + UserStore: nil, // nil — persistUsers must short-circuit + AdminStore: newSeededAdminStore(t, "admin", "secret"), + Logger: &logger, + }) + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + client, csrf := loginHelper(t, ts.URL) + + resp, err := client.PostForm(ts.URL+"/admin/users", + url.Values{"username": {"newuser"}, "_csrf": {csrf}}) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + _ = resp.Body.Close() + + assert.True(t, creds.Exists("newuser")) +} + +func TestServer_SessionExpired(t *testing.T) { + s, ts := newAdminServer(t) + + jar, _ := cookiejar.New(nil) + client := &http.Client{Jar: jar} + + resp, err := client.PostForm(ts.URL+"/admin/login", url.Values{ + "username": {"admin"}, + "password": {"secret"}, + }) + require.NoError(t, err) + _ = resp.Body.Close() + + // Expire every session manually + s.mu.Lock() + for k, v := range s.sessions { + v.ExpiresAt = time.Now().Add(-time.Hour) + s.sessions[k] = v + } + s.mu.Unlock() + + // After expiry, /admin/users must redirect to login + resp, err = client.Get(ts.URL + "/admin/users") + assert.NoError(t, err) + assert.Equal(t, "/admin/login", resp.Request.URL.Path) + _ = resp.Body.Close() +} + +func TestServer_SessionInvalidToken(t *testing.T) { + _, ts := newAdminServer(t) + + noFollow := &http.Client{CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }} + + req, _ := http.NewRequest(http.MethodGet, ts.URL+"/admin/users", nil) + req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "token-that-does-not-exist"}) + resp, err := noFollow.Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestNormalizeOrigin_NoScheme(t *testing.T) { + _, err := normalizeOrigin("no-scheme") + assert.Error(t, err) +} + +func TestNormalizeOrigin_NoHost(t *testing.T) { + _, err := normalizeOrigin("http://") + assert.Error(t, err) +} + +func TestNormalizeAllowedOrigins_InvalidEntry(t *testing.T) { + result := normalizeAllowedOrigins([]string{"bad-origin", "http://valid.com"}) + assert.Equal(t, []string{"http://valid.com"}, result) +} + +func TestNormalizeAllowedOrigins_DuplicateEntry(t *testing.T) { + result := normalizeAllowedOrigins([]string{"http://x.com", "http://x.com", "http://y.com"}) + assert.Equal(t, []string{"http://x.com", "http://y.com"}, result) +} + +func TestExtractClientIP_NoPort(t *testing.T) { + ip := extractClientIP("192.168.1.1") + assert.Equal(t, "192.168.1.1", ip) +} + +func TestServer_OriginPolicy_InvalidOriginHeader(t *testing.T) { + logger := zerolog.New(io.Discard) + creds := credential.NewStaticCredentialStore() + s := New(&Config{ + Credentials: creds, + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: newSeededAdminStore(t, "admin", "secret"), + AllowedOrigins: []string{"http://allowed.local"}, + Logger: &logger, + }) + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + req, _ := http.NewRequest(http.MethodPost, ts.URL+"/admin/login", + strings.NewReader("username=admin&password=secret")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Origin", "://bad-origin") // malformed + resp, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_OriginPolicy_WrongOrigin(t *testing.T) { + logger := zerolog.New(io.Discard) + creds := credential.NewStaticCredentialStore() + s := New(&Config{ + Credentials: creds, + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: newSeededAdminStore(t, "admin", "secret"), + AllowedOrigins: []string{"http://allowed.local"}, + Logger: &logger, + }) + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + req, _ := http.NewRequest(http.MethodPost, ts.URL+"/admin/login", + strings.NewReader("username=admin&password=secret")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Origin", "http://evil.com") + resp, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_OriginPolicy_RefererAllowed(t *testing.T) { + logger := zerolog.New(io.Discard) + creds := credential.NewStaticCredentialStore() + s := New(&Config{ + Credentials: creds, + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: newSeededAdminStore(t, "admin", "secret"), + AllowedOrigins: []string{"http://allowed.local"}, + Logger: &logger, + }) + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + req, _ := http.NewRequest(http.MethodPost, ts.URL+"/admin/login", + strings.NewReader("username=admin&password=wrong")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Referer", "http://allowed.local/admin/login") + resp, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + // wrong password → 401, but origin was accepted (not 403) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_OriginPolicy_RefererNotAllowed(t *testing.T) { + logger := zerolog.New(io.Discard) + creds := credential.NewStaticCredentialStore() + s := New(&Config{ + Credentials: creds, + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: newSeededAdminStore(t, "admin", "secret"), + AllowedOrigins: []string{"http://allowed.local"}, + Logger: &logger, + }) + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + req, _ := http.NewRequest(http.MethodPost, ts.URL+"/admin/login", + strings.NewReader("username=admin&password=secret")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Referer", "http://evil.com/path") + resp, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestServer_OriginPolicy_BadReferer(t *testing.T) { + logger := zerolog.New(io.Discard) + creds := credential.NewStaticCredentialStore() + s := New(&Config{ + Credentials: creds, + UserStore: credential.NewBoltStore(filepath.Join(t.TempDir(), "data.db")), + AdminStore: newSeededAdminStore(t, "admin", "secret"), + AllowedOrigins: []string{"http://allowed.local"}, + Logger: &logger, + }) + ts := httptest.NewServer(s.Handler()) + defer ts.Close() + + req, _ := http.NewRequest(http.MethodPost, ts.URL+"/admin/login", + strings.NewReader("username=admin&password=secret")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Referer", "not-a-url-no-scheme") + resp, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + _ = resp.Body.Close() +} + +type stubTrafficStore struct { + saved map[string]traffic.UserTotals + err error +} + +func (s *stubTrafficStore) LoadTraffic() (map[string]traffic.UserTotals, error) { + return nil, nil +} + +func (s *stubTrafficStore) SaveTraffic(totals map[string]traffic.UserTotals) error { + s.saved = totals + return s.err +} + +func (s *stubTrafficStore) ResetUserTraffic(username string) error { + return nil +} + +func TestFormatBytesAndRate(t *testing.T) { + assert.Equal(t, "0 B", formatBytes(0)) + assert.Equal(t, "1023 B", formatBytes(1023)) + assert.Equal(t, "1.00 KB", formatBytes(1024)) + assert.Equal(t, "1.00 MB", formatBytes(1024*1024)) + assert.Equal(t, "1.00 GB", formatBytes(1024*1024*1024)) + assert.Equal(t, "1.00 KB/s", formatByteRate(1024)) +} + +func TestFormatStartedAgoVariants(t *testing.T) { + assert.Equal(t, "unknown", formatStartedAgo(time.Time{})) + assert.Equal(t, "just now", formatStartedAgo(time.Now().Add(-30*time.Second))) + assert.Equal(t, "5 minutes ago", formatStartedAgo(time.Now().Add(-5*time.Minute))) + assert.Equal(t, "2 hours ago", formatStartedAgo(time.Now().Add(-2*time.Hour))) + assert.Equal(t, "3 days ago", formatStartedAgo(time.Now().Add(-72*time.Hour))) +} + +func TestServer_PersistTraffic(t *testing.T) { + tracker := traffic.NewTracker() + store := &stubTrafficStore{} + + s := New(&Config{ + Credentials: credential.NewStaticCredentialStore(), + AdminStore: newSeededAdminStore(t, "admin", "secret"), + TrafficStore: store, + Tracker: tracker, + }) + + // No active sessions means empty totals map should still be persisted. + err := s.persistTraffic() + assert.NoError(t, err) + assert.NotNil(t, store.saved) +} + +func TestServer_PersistTraffic_NoStoreOrTracker(t *testing.T) { + s := New(&Config{ + Credentials: credential.NewStaticCredentialStore(), + AdminStore: newSeededAdminStore(t, "admin", "secret"), + }) + + assert.NoError(t, s.persistTraffic()) +} + +func TestServer_PersistTraffic_SaveError(t *testing.T) { + tracker := traffic.NewTracker() + store := &stubTrafficStore{err: assert.AnError} + + s := New(&Config{ + Credentials: credential.NewStaticCredentialStore(), + AdminStore: newSeededAdminStore(t, "admin", "secret"), + TrafficStore: store, + Tracker: tracker, + }) + + err := s.persistTraffic() + assert.ErrorIs(t, err, assert.AnError) +} + +func extractGeneratedPassword(t *testing.T, content string) string { + t.Helper() + + re := regexp.MustCompile(`]*>([^<]+)`) + matches := re.FindStringSubmatch(content) + if len(matches) != 2 { + t.Fatalf("generated password not found in response: %s", content) + } + + return html.UnescapeString(matches[1]) +} + +func extractCSRFToken(t *testing.T, content string) string { + t.Helper() + + re := regexp.MustCompile(``) + matches := re.FindStringSubmatch(content) + if len(matches) != 2 { + t.Fatalf("csrf token not found in response: %s", content) + } + + return matches[1] +} diff --git a/pkg/admin/templates/login.gohtml b/pkg/admin/templates/login.gohtml new file mode 100644 index 0000000..29ebf24 --- /dev/null +++ b/pkg/admin/templates/login.gohtml @@ -0,0 +1,35 @@ + + + + + + NanoProxy Admin Login + + + +
+

NanoProxy

+

Admin Console

+ {{if .Error}} +
{{.Error}}
+ {{end}} +
+
+ + +
+
+ + +
+ +
+
+ + + diff --git a/pkg/admin/templates/setup.gohtml b/pkg/admin/templates/setup.gohtml new file mode 100644 index 0000000..737b212 --- /dev/null +++ b/pkg/admin/templates/setup.gohtml @@ -0,0 +1,42 @@ + + + + + + NanoProxy Admin Setup + + + +
+

NanoProxy

+

Initial Admin Setup

+

Create your first admin account to access the console.

+ {{if .Error}} +
{{.Error}}
+ {{end}} +
+
+ + +
+
+ + +
+
+ + +
+ +
+
+ + + diff --git a/pkg/admin/templates/toast.gohtml b/pkg/admin/templates/toast.gohtml new file mode 100644 index 0000000..50220be --- /dev/null +++ b/pkg/admin/templates/toast.gohtml @@ -0,0 +1,14 @@ +{{if .Error}} +
+

Action failed

+

{{.Error}}

+
+{{end}} +{{if and .Success (not .GeneratedPassword)}} +
+

Done

+

{{.Success}}

+
+{{end}} diff --git a/pkg/admin/templates/traffic_rows.gohtml b/pkg/admin/templates/traffic_rows.gohtml new file mode 100644 index 0000000..e69de29 diff --git a/pkg/admin/templates/user_rows.gohtml b/pkg/admin/templates/user_rows.gohtml new file mode 100644 index 0000000..14b5b0f --- /dev/null +++ b/pkg/admin/templates/user_rows.gohtml @@ -0,0 +1,84 @@ +{{range .ProxyUsers}} + + +
+ {{.Username}} + {{if ne .ClientIP "-"}} + {{.ClientIP}} + {{end}} + + {{.Status}} + +
+ + +
+ ↓ {{.DownloadRate}} + {{.DownloadTotal}} +
+ + +
+ ↑ {{.UploadRate}} + {{.UploadTotal}} +
+ + +
+ + + + + + +
+ + +{{else}} + + +

No proxy users yet

+ + +{{end}} diff --git a/pkg/admin/templates/users.gohtml b/pkg/admin/templates/users.gohtml new file mode 100644 index 0000000..7f8d92e --- /dev/null +++ b/pkg/admin/templates/users.gohtml @@ -0,0 +1,328 @@ + + + + + + + NanoProxy Admin + + + + + + +
+ {{template "toast.gohtml" .}} +
+ +
+
+
+
+

NanoProxy

+

Admin Console

+

Manage proxy users and access credentials.

+
+
+ +
+ + +
+
+
+
+ + {{if .GeneratedPassword}} +
+
+
+

{{if .Success}}{{.Success}}{{else}}Generated credentials ready.{{end}}

+

Store this password now. It is only shown once.

+
+ +
+
+
+

Username

+ {{.GeneratedUsername}} +
+
+

Password

+ {{.GeneratedPassword}} +
+
+
+ {{end}} + +
+
+
+

Proxy users

+ + Total: + {{.TotalUsers}} + + Live +
+ +
+ +
+ + + + + + + + + + + {{template "user_rows.gohtml" .}} + +
User + Download + Upload + + Actions +
+
+ +
+
+ + + + + + diff --git a/pkg/config/config.go b/pkg/config/config.go index d7239a4..efff6c6 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -3,13 +3,20 @@ package config import "time" type Config struct { - Timezone string `env:"TZ" envDefault:"Local"` - Network string `env:"NETWORK" envDefault:"tcp"` - ADDR string `env:"ADDR" envDefault:":1080"` - ADDRHttp string `env:"ADDR_HTTP" envDefault:":8080"` - Credentials []string `env:"CREDENTIALS" envSeparator:","` - ClientTimeout time.Duration `env:"CLIENT_TIMEOUT" envDefault:"15s"` - DestTimeout time.Duration `env:"DEST_TIMEOUT" envDefault:"15s"` - TorEnabled bool `env:"TOR_ENABLED" envDefault:"false"` - TorIdentityInterval time.Duration `env:"TOR_IDENTITY_INTERVAL" envDefault:"10m"` + Timezone string `env:"TZ" envDefault:"Local"` + LogLevel string `env:"LOG_LEVEL" envDefault:"info"` + Network string `env:"NETWORK" envDefault:"tcp"` + ADDR string `env:"ADDR" envDefault:":1080"` + ADDRHttp string `env:"ADDR_HTTP" envDefault:":8080"` + ADDRAdmin string `env:"ADDR_ADMIN" envDefault:":9090"` + UserStorePath string `env:"USER_STORE_PATH" envDefault:"nanoproxy-data.db"` + AdminCookieSecure bool `env:"ADMIN_COOKIE_SECURE" envDefault:"false"` + AdminMaxLoginAttempts int `env:"ADMIN_MAX_LOGIN_ATTEMPTS" envDefault:"5"` + AdminLoginWindow time.Duration `env:"ADMIN_LOGIN_WINDOW" envDefault:"5m"` + AdminLockoutDuration time.Duration `env:"ADMIN_LOCKOUT_DURATION" envDefault:"10m"` + AdminAllowedOrigins []string `env:"ADMIN_ALLOWED_ORIGINS" envSeparator:","` + ClientTimeout time.Duration `env:"CLIENT_TIMEOUT" envDefault:"15s"` + DestTimeout time.Duration `env:"DEST_TIMEOUT" envDefault:"15s"` + TorEnabled bool `env:"TOR_ENABLED" envDefault:"false"` + TorIdentityInterval time.Duration `env:"TOR_IDENTITY_INTERVAL" envDefault:"10m"` } diff --git a/pkg/credential/bolt_store.go b/pkg/credential/bolt_store.go new file mode 100644 index 0000000..0d8a3d8 --- /dev/null +++ b/pkg/credential/bolt_store.go @@ -0,0 +1,100 @@ +package credential + +import ( + "errors" + "os" + "path/filepath" + + "go.etcd.io/bbolt" +) + +var usersBucket = []byte("users") + +type BoltStore struct { + path string +} + +func NewBoltStore(path string) *BoltStore { + return &BoltStore{path: path} +} + +func (b *BoltStore) Path() string { + if b == nil { + return "" + } + + return b.path +} + +func (b *BoltStore) Load() (map[string]string, error) { + if b == nil || b.path == "" { + return map[string]string{}, nil + } + + if _, err := os.Stat(b.path); err != nil { + if errors.Is(err, os.ErrNotExist) { + return map[string]string{}, nil + } + return nil, err + } + + db, err := bbolt.Open(b.path, 0o600, nil) + if err != nil { + return nil, err + } + defer db.Close() + + snapshot := map[string]string{} + err = db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(usersBucket) + if bucket == nil { + return nil + } + + return bucket.ForEach(func(k, v []byte) error { + snapshot[string(k)] = string(v) + return nil + }) + }) + if err != nil { + return nil, err + } + + return snapshot, nil +} + +func (b *BoltStore) Save(snapshot map[string]string) error { + if b == nil || b.path == "" { + return nil + } + + dir := filepath.Dir(b.path) + if dir != "." { + if err := os.MkdirAll(dir, 0o750); err != nil { + return err + } + } + + db, err := bbolt.Open(b.path, 0o600, nil) + if err != nil { + return err + } + defer db.Close() + + return db.Update(func(tx *bbolt.Tx) error { + _ = tx.DeleteBucket(usersBucket) + + bucket, err := tx.CreateBucket(usersBucket) + if err != nil { + return err + } + + for username, passwordHash := range snapshot { + if err := bucket.Put([]byte(username), []byte(passwordHash)); err != nil { + return err + } + } + + return nil + }) +} diff --git a/pkg/credential/bolt_store_test.go b/pkg/credential/bolt_store_test.go new file mode 100644 index 0000000..ffadc2a --- /dev/null +++ b/pkg/credential/bolt_store_test.go @@ -0,0 +1,36 @@ +package credential + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBoltStore_SaveAndLoad(t *testing.T) { + t.Parallel() + + path := filepath.Join(t.TempDir(), "data.db") + store := NewBoltStore(path) + + snapshot := map[string]string{ + "alice": "$2a$10$JGKWBfX0VTqflV6kNfSLweBzA6YxQ8fFiQvCg2Vf1uNhM6o6z8brS", + "bob": "$2a$10$wRPS8Qnmfjzb2n4h2ZVqKegc7MypvJ.p3nQoIc0K2fWzEo.5hF7R2", + } + + require.NoError(t, store.Save(snapshot)) + + restored, err := store.Load() + require.NoError(t, err) + assert.Equal(t, snapshot, restored) +} + +func TestBoltStore_Load_FileNotExist(t *testing.T) { + t.Parallel() + + store := NewBoltStore(filepath.Join(t.TempDir(), "missing.db")) + restored, err := store.Load() + require.NoError(t, err) + assert.Empty(t, restored) +} diff --git a/pkg/credential/credentials.go b/pkg/credential/credentials.go index fab2a9c..9d57824 100644 --- a/pkg/credential/credentials.go +++ b/pkg/credential/credentials.go @@ -1,6 +1,9 @@ package credential import ( + "sort" + "sync" + "golang.org/x/crypto/bcrypt" ) @@ -9,8 +12,42 @@ type Store interface { Valid(user, password string) bool } +type CombinedStore struct { + stores []Store +} + +func NewCombinedStore(stores ...Store) *CombinedStore { + filtered := make([]Store, 0, len(stores)) + for _, store := range stores { + if store != nil { + filtered = append(filtered, store) + } + } + + return &CombinedStore{stores: filtered} +} + +func (s *CombinedStore) Add(user, password string) { + if len(s.stores) == 0 { + return + } + + s.stores[0].Add(user, password) +} + +func (s *CombinedStore) Valid(user, password string) bool { + for _, store := range s.stores { + if store.Valid(user, password) { + return true + } + } + + return false +} + type StaticCredentialStore struct { store map[string]string + mu sync.RWMutex } func NewStaticCredentialStore() *StaticCredentialStore { @@ -19,23 +56,25 @@ func NewStaticCredentialStore() *StaticCredentialStore { } } -func (s StaticCredentialStore) Add(user, password string) { - if _, err := bcrypt.Cost([]byte(password)); err == nil { - // The credential is already a bcrypt hash, keep it as-is. - s.store[user] = password - return - } - - hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) +func (s *StaticCredentialStore) Add(user, password string) { + hash, err := normalizePassword(password) if err != nil { - // Fail closed if hashing fails. return } - s.store[user] = string(hash) + s.SetHashed(user, hash) +} + +func (s *StaticCredentialStore) SetHashed(user, passwordHash string) { + s.mu.Lock() + defer s.mu.Unlock() + s.store[user] = passwordHash } -func (s StaticCredentialStore) Valid(user, password string) bool { +func (s *StaticCredentialStore) Valid(user, password string) bool { + s.mu.RLock() + defer s.mu.RUnlock() + pass, ok := s.store[user] if !ok { return false @@ -44,3 +83,87 @@ func (s StaticCredentialStore) Valid(user, password string) bool { err := bcrypt.CompareHashAndPassword([]byte(pass), []byte(password)) return err == nil } + +func (s *StaticCredentialStore) Delete(user string) bool { + s.mu.Lock() + defer s.mu.Unlock() + + if _, ok := s.store[user]; !ok { + return false + } + + delete(s.store, user) + return true +} + +func (s *StaticCredentialStore) ListUsers() []string { + s.mu.RLock() + defer s.mu.RUnlock() + + users := make([]string, 0, len(s.store)) + for user := range s.store { + users = append(users, user) + } + + sort.Strings(users) + return users +} + +func (s *StaticCredentialStore) Count() int { + s.mu.RLock() + defer s.mu.RUnlock() + + return len(s.store) +} + +func (s *StaticCredentialStore) Exists(user string) bool { + s.mu.RLock() + defer s.mu.RUnlock() + + _, ok := s.store[user] + return ok +} + +func (s *StaticCredentialStore) GetHashed(user string) (string, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + pass, ok := s.store[user] + return pass, ok +} + +func (s *StaticCredentialStore) Snapshot() map[string]string { + s.mu.RLock() + defer s.mu.RUnlock() + + clone := make(map[string]string, len(s.store)) + for user, pass := range s.store { + clone[user] = pass + } + + return clone +} + +func (s *StaticCredentialStore) Replace(snapshot map[string]string) { + s.mu.Lock() + defer s.mu.Unlock() + + s.store = make(map[string]string, len(snapshot)) + for user, pass := range snapshot { + s.store[user] = pass + } +} + +func normalizePassword(password string) (string, error) { + if _, err := bcrypt.Cost([]byte(password)); err == nil { + // The credential is already a bcrypt hash, keep it as-is. + return password, nil + } + + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return "", err + } + + return string(hash), nil +} diff --git a/pkg/credential/credentials_test.go b/pkg/credential/credentials_test.go index 76ba7cb..70985da 100644 --- a/pkg/credential/credentials_test.go +++ b/pkg/credential/credentials_test.go @@ -34,3 +34,24 @@ func Test_CredentialStore_AddBcryptHash_ThenValid(t *testing.T) { assert.True(t, s.Valid("foo", "bar")) assert.False(t, s.Valid("foo", "baz")) } + +func Test_CredentialStore_Delete(t *testing.T) { + s := NewStaticCredentialStore() + s.Add("foo", "bar") + + assert.True(t, s.Delete("foo")) + assert.False(t, s.Valid("foo", "bar")) + assert.False(t, s.Delete("foo")) +} + +func Test_CredentialStore_ListUsers(t *testing.T) { + s := NewStaticCredentialStore() + s.Add("charlie", "secret") + s.Add("alice", "secret") + s.Add("bob", "secret") + + assert.Equal(t, []string{"alice", "bob", "charlie"}, s.ListUsers()) + assert.Equal(t, 3, s.Count()) + assert.True(t, s.Exists("alice")) + assert.False(t, s.Exists("nobody")) +} diff --git a/pkg/credential/persistent_store.go b/pkg/credential/persistent_store.go new file mode 100644 index 0000000..9fa448e --- /dev/null +++ b/pkg/credential/persistent_store.go @@ -0,0 +1,21 @@ +package credential + +// PersistentStore provides durable storage for proxy user credential snapshots. +type PersistentStore interface { + Load() (map[string]string, error) + Save(snapshot map[string]string) error +} + +func LoadInto(persistentStore PersistentStore, store *StaticCredentialStore) error { + if store == nil || persistentStore == nil { + return nil + } + + snapshot, err := persistentStore.Load() + if err != nil { + return err + } + + store.Replace(snapshot) + return nil +} diff --git a/pkg/credential/persistent_store_test.go b/pkg/credential/persistent_store_test.go new file mode 100644 index 0000000..cb534bd --- /dev/null +++ b/pkg/credential/persistent_store_test.go @@ -0,0 +1,39 @@ +package credential + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLoadInto_NilStore(t *testing.T) { + t.Parallel() + + err := LoadInto(nil, nil) + assert.NoError(t, err) +} + +func TestLoadInto_NilPersistentStore(t *testing.T) { + t.Parallel() + + target := NewStaticCredentialStore() + err := LoadInto(nil, target) + assert.NoError(t, err) +} + +func TestLoadInto_FromBoltStore(t *testing.T) { + t.Parallel() + + boltStore := NewBoltStore(t.TempDir() + "/data.db") + seed := NewStaticCredentialStore() + seed.Add("alice", "alice-pass") + seed.Add("bob", "bob-pass") + + assert.NoError(t, boltStore.Save(seed.Snapshot())) + + target := NewStaticCredentialStore() + err := LoadInto(boltStore, target) + assert.NoError(t, err) + assert.True(t, target.Valid("alice", "alice-pass")) + assert.True(t, target.Valid("bob", "bob-pass")) +} diff --git a/pkg/httpproxy/httpproxy.go b/pkg/httpproxy/httpproxy.go index db4f6ca..a2dbf1b 100644 --- a/pkg/httpproxy/httpproxy.go +++ b/pkg/httpproxy/httpproxy.go @@ -1,18 +1,24 @@ package httpproxy import ( + "bufio" + "crypto/tls" "encoding/base64" + "errors" "fmt" "io" "net" "net/http" + "net/url" "os" + "strconv" "strings" "time" "github.com/rs/zerolog" "github.com/ryanbekhen/nanoproxy/pkg/credential" "github.com/ryanbekhen/nanoproxy/pkg/resolver" + "github.com/ryanbekhen/nanoproxy/pkg/traffic" ) var hopHeaders = []string{ @@ -34,12 +40,19 @@ type Config struct { ClientConnTimeout time.Duration Dial func(network, addr string) (net.Conn, error) Resolver resolver.Resolver + Tracker *traffic.Tracker } type Server struct { config *Config } +var ( + ErrMissingProxyAuthorization = errors.New("missing proxy authorization header") + ErrInvalidProxyAuthorization = errors.New("invalid proxy authorization header") + ErrInvalidProxyCredentials = errors.New("invalid credentials") +) + func New(conf *Config) *Server { if conf.Logger == nil { logger := zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339}).With().Timestamp().Logger() @@ -73,53 +86,67 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (s *Server) authenticateRequest(r *http.Request) bool { +func (s *Server) authenticateRequest(r *http.Request) (string, error) { if s.config.Credentials == nil { - return true + return "anonymous", nil } authHeader := r.Header.Get("Proxy-Authorization") if authHeader == "" { - return false + return "", ErrMissingProxyAuthorization } if strings.HasPrefix(authHeader, "Basic ") { encodedCreds := strings.TrimPrefix(authHeader, "Basic ") decoded, err := base64.StdEncoding.DecodeString(encodedCreds) if err != nil { - return false + return "", fmt.Errorf("%w: %v", ErrInvalidProxyAuthorization, err) } parts := strings.SplitN(string(decoded), ":", 2) if len(parts) != 2 { - return false + return "", ErrInvalidProxyAuthorization } username, password := parts[0], parts[1] - return s.config.Credentials.Valid(username, password) + if s.config.Credentials.Valid(username, password) { + return username, nil + } + return "", ErrInvalidProxyCredentials } - return false + return "", ErrInvalidProxyAuthorization } func (s *Server) handleConnect(w http.ResponseWriter, r *http.Request) { - if !s.authenticateRequest(r) { - s.config.Logger.Error(). - Str("client_addr", r.RemoteAddr). - Msg("Unauthorized CONNECT request") + requestLogger := s.requestLogger(r) + username, err := s.authenticateRequest(r) + if err != nil { + requestLogger.Error(). + Err(err). + Msg("proxy authentication failed") w.Header().Set("Proxy-Authenticate", "Basic realm=\"Restricted area\"") http.Error(w, "Proxy authentication required or unauthorized", http.StatusProxyAuthRequired) return } + requestLogger = requestLogger.With().Str("username", username).Str("dest_addr", r.Host).Logger() + if s.config.Credentials != nil { + requestLogger.Debug().Msg("proxy authentication succeeded") + } else { + requestLogger.Debug().Msg("connect request accepted without authentication") + } + session := s.startSession(username, r.RemoteAddr) + defer session.Close() + startTime := time.Now() + requestLogger.Debug().Msg("dialing connect target") serverConn, err := s.config.Dial("tcp", r.Host) latency := time.Since(startTime).Milliseconds() if err != nil { - s.config.Logger.Error(). - Str("client_addr", r.RemoteAddr). - Str("dest_addr", r.Host). + requestLogger.Error(). Str("latency", fmt.Sprintf("%dms", latency)). - Msg("CONNECT failed") + Err(err). + Msg("connect failed") http.Error(w, "Service unavailable", http.StatusServiceUnavailable) return } @@ -127,9 +154,9 @@ func (s *Server) handleConnect(w http.ResponseWriter, r *http.Request) { clientConn, _, err := w.(http.Hijacker).Hijack() if err != nil { - s.config.Logger.Error(). - Str("client_addr", r.RemoteAddr). - Msg("Failed to hijack client connection") + requestLogger.Error(). + Err(err). + Msg("failed to hijack client connection") http.Error(w, "Service unavailable", http.StatusServiceUnavailable) return } @@ -137,84 +164,112 @@ func (s *Server) handleConnect(w http.ResponseWriter, r *http.Request) { _, _ = clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) - s.config.Logger.Info(). - Str("client_addr", r.RemoteAddr). - Str("dest_addr", r.Host). - Str("latency", fmt.Sprintf("%dms", latency)). - Msg("CONNECT request completed") - + uploadCh := make(chan struct{}, 1) go func() { - _, _ = io.Copy(serverConn, clientConn) + n, _ := io.Copy(serverConn, clientConn) + session.AddUpload(n) + uploadCh <- struct{}{} }() - _, _ = io.Copy(clientConn, serverConn) + + n, _ := io.Copy(clientConn, serverConn) + session.AddDownload(n) + <-uploadCh + + requestLogger.Info(). + Str("latency", time.Since(startTime).Round(time.Millisecond).String()). + Uint64("upload_bytes", session.UploadBytes()). + Uint64("download_bytes", session.DownloadBytes()). + Msg("connect completed") } func (s *Server) handleHTTP(w http.ResponseWriter, r *http.Request) { - if !s.authenticateRequest(r) { - s.config.Logger.Error(). - Str("client_addr", r.RemoteAddr). - Msg("Unauthorized HTTP request") + requestLogger := s.requestLogger(r) + username, err := s.authenticateRequest(r) + if err != nil { + requestLogger.Error(). + Err(err). + Msg("proxy authentication failed") w.Header().Set("Proxy-Authenticate", "Basic realm=\"Restricted area\"") http.Error(w, "Proxy authentication required or unauthorized", http.StatusProxyAuthRequired) return } + requestLogger = requestLogger.With().Str("username", username).Logger() + if s.config.Credentials != nil { + requestLogger.Debug().Msg("proxy authentication succeeded") + } else { + requestLogger.Debug().Msg("http request accepted without authentication") + } + session := s.startSession(username, r.RemoteAddr) + defer session.Close() startTime := time.Now() - clientIP := r.RemoteAddr - if !strings.HasPrefix(r.URL.Scheme, "http") { - s.config.Logger.Error(). - Str("client_addr", clientIP). + targetURL, err := normalizeProxyTargetURL(r) + if err != nil { + requestLogger.Error(). Str("dest_addr", r.URL.String()). - Msg("Invalid URL scheme") - http.Error(w, "Invalid URL scheme", http.StatusBadRequest) + Err(err). + Msg("invalid proxy target url") + http.Error(w, "Invalid target URL", http.StatusBadRequest) return } + requestLogger = requestLogger.With().Str("dest_addr", targetURL.String()).Logger() - proxyReq, err := http.NewRequest(r.Method, r.URL.String(), r.Body) + proxyReqBody := &countingReadCloser{ + ReadCloser: r.Body, + onRead: func(n int64) { + session.AddUpload(n) + }, + } + + resolvedAddr, err := resolveProxyTargetAddr(targetURL, s.config.Resolver) if err != nil { latency := time.Since(startTime).Milliseconds() - s.config.Logger.Error(). - Str("client_addr", clientIP). - Str("dest_addr", r.URL.String()). + requestLogger.Error(). Str("latency", fmt.Sprintf("%dms", latency)). - Msg("Failed to create request - Internal Server Error") - http.Error(w, "Internal server error while creating request", http.StatusInternalServerError) + Err(err). + Msg("failed to resolve target host") + http.Error(w, "Bad gateway: failed to resolve target host", http.StatusBadGateway) return } + requestLogger.Debug().Str("resolved_addr", resolvedAddr).Msg("resolved proxy target") - for key, values := range r.Header { - if isHopHeader(key) { - continue - } - - for _, value := range values { - proxyReq.Header.Add(key, value) - } + serverConn, err := dialProxyTarget(targetURL, resolvedAddr, s.config.Dial, s.config.ClientConnTimeout) + if err != nil { + latency := time.Since(startTime).Milliseconds() + requestLogger.Error(). + Str("latency", fmt.Sprintf("%dms", latency)). + Err(err). + Msg("failed to connect to target") + http.Error(w, "Bad gateway: failed to send request", http.StatusBadGateway) + return } + defer serverConn.Close() - client := &http.Client{ - Timeout: s.config.ClientConnTimeout, + proxyReq := buildOutboundProxyRequest(r, targetURL, proxyReqBody) + requestLogger.Debug().Msg("forwarding proxy request") + if err := proxyReq.Write(serverConn); err != nil { + latency := time.Since(startTime).Milliseconds() + requestLogger.Error(). + Str("latency", fmt.Sprintf("%dms", latency)). + Err(err). + Msg("failed to send request") + http.Error(w, "Bad gateway: failed to send request", http.StatusBadGateway) + return } - resp, err := client.Do(proxyReq) + + resp, err := http.ReadResponse(bufio.NewReader(serverConn), proxyReq) latency := time.Since(startTime).Milliseconds() if err != nil { - s.config.Logger.Error(). - Str("client_addr", clientIP). - Str("dest_addr", r.URL.String()). + requestLogger.Error(). Str("latency", fmt.Sprintf("%dms", latency)). - Msg("Failed to send request - Bad Gateway") - http.Error(w, "Bad gateway: failed to send request", http.StatusBadGateway) + Err(err). + Msg("failed to read response") + http.Error(w, "Bad gateway: failed to read response", http.StatusBadGateway) return } defer resp.Body.Close() - s.config.Logger.Info(). - Str("client_addr", clientIP). - Str("dest_addr", r.URL.String()). - Str("latency", fmt.Sprintf("%dms", latency)). - Msg("HTTP request successfully proxied") - for _, key := range hopHeaders { resp.Header.Del(key) } @@ -226,7 +281,193 @@ func (s *Server) handleHTTP(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) + n, _ := io.Copy(w, resp.Body) + session.AddDownload(n) + + requestLogger.Info(). + Int("status_code", resp.StatusCode). + Str("latency", time.Since(startTime).Round(time.Millisecond).String()). + Uint64("upload_bytes", session.UploadBytes()). + Uint64("download_bytes", session.DownloadBytes()). + Msg("request completed") +} + +func (s *Server) startSession(username, remoteAddr string) *traffic.Session { + if s.config.Tracker == nil { + return nil + } + return s.config.Tracker.Start(username, extractClientIP(remoteAddr)) +} + +func extractClientIP(remoteAddr string) string { + host, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return remoteAddr + } + return host +} + +func (s *Server) requestLogger(r *http.Request) zerolog.Logger { + logger := s.config.Logger.With().Str("protocol", "http") + if r != nil { + logger = logger.Str("http_method", r.Method) + if r.RemoteAddr != "" { + logger = logger.Str("client_addr", r.RemoteAddr) + } + } + return logger.Logger() +} + +func resolveProxyTargetAddr(targetURL *url.URL, res resolver.Resolver) (string, error) { + hostname := targetURL.Hostname() + port := targetURL.Port() + if port == "" { + if targetURL.Scheme == "https" { + port = "443" + } else { + port = "80" + } + } + + var ipStr string + // If hostname is already a valid IP address, use it directly without DNS. + if ip := net.ParseIP(hostname); ip != nil { + ipStr = ip.String() + } else { + resolved, err := res.Resolve(hostname) + if err != nil { + return "", fmt.Errorf("resolve %q: %w", hostname, err) + } + ipStr = resolved.String() + } + + return net.JoinHostPort(ipStr, port), nil +} + +func normalizeProxyTargetURL(r *http.Request) (*url.URL, error) { + if r == nil || r.URL == nil { + return nil, fmt.Errorf("missing target url") + } + + rawURL := r.URL + scheme := strings.ToLower(strings.TrimSpace(rawURL.Scheme)) + if scheme != "http" && scheme != "https" { + return nil, fmt.Errorf("unsupported url scheme: %q", rawURL.Scheme) + } + + if rawURL.User != nil { + return nil, fmt.Errorf("userinfo is not allowed") + } + + host := strings.TrimSpace(rawURL.Host) + if host == "" { + host = strings.TrimSpace(r.Host) + } + if host == "" { + return nil, fmt.Errorf("missing url host") + } + + hostname := host + if strings.Contains(host, ":") { + h, port, err := net.SplitHostPort(host) + if err != nil { + return nil, fmt.Errorf("invalid host:port") + } + if port != "" { + p, err := strconv.Atoi(port) + if err != nil || p < 1 || p > 65535 { + return nil, fmt.Errorf("invalid port") + } + } + hostname = h + } + + if hostname == "" { + return nil, fmt.Errorf("missing hostname") + } + + normalized := &url.URL{ + Scheme: scheme, + Host: host, + Path: rawURL.EscapedPath(), + RawPath: rawURL.RawPath, + RawQuery: rawURL.RawQuery, + } + + if normalized.Path == "" { + normalized.Path = "/" + } + + return normalized, nil +} + +func buildOutboundProxyRequest(r *http.Request, targetURL *url.URL, body io.ReadCloser) *http.Request { + proxyReq := r.Clone(r.Context()) + proxyReq.URL = &url.URL{ + Path: targetURL.Path, + RawPath: targetURL.RawPath, + RawQuery: targetURL.RawQuery, + } + proxyReq.Host = targetURL.Host + proxyReq.RequestURI = "" + proxyReq.Body = body + proxyReq.Close = true + proxyReq.Header = make(http.Header, len(r.Header)) + + for key, values := range r.Header { + if isHopHeader(key) { + continue + } + + for _, value := range values { + proxyReq.Header.Add(key, value) + } + } + + return proxyReq +} + +func dialProxyTarget(targetURL *url.URL, resolvedAddr string, dial func(network, addr string) (net.Conn, error), timeout time.Duration) (net.Conn, error) { + conn, err := dial("tcp", resolvedAddr) + if err != nil { + return nil, err + } + + if timeout > 0 { + _ = conn.SetDeadline(time.Now().Add(timeout)) + } + + if targetURL.Scheme != "https" { + return conn, nil + } + + tlsConn := tls.Client(conn, &tls.Config{ + ServerName: targetURL.Hostname(), + MinVersion: tls.VersionTLS12, + }) + if err := tlsConn.Handshake(); err != nil { + _ = conn.Close() + return nil, err + } + + if timeout > 0 { + _ = tlsConn.SetDeadline(time.Now().Add(timeout)) + } + + return tlsConn, nil +} + +type countingReadCloser struct { + io.ReadCloser + onRead func(n int64) +} + +func (c *countingReadCloser) Read(p []byte) (int, error) { + n, err := c.ReadCloser.Read(p) + if n > 0 && c.onRead != nil { + c.onRead(int64(n)) + } + return n, err } func isHopHeader(header string) bool { diff --git a/pkg/httpproxy/httpproxy_test.go b/pkg/httpproxy/httpproxy_test.go index baed0e9..3fbe765 100644 --- a/pkg/httpproxy/httpproxy_test.go +++ b/pkg/httpproxy/httpproxy_test.go @@ -2,6 +2,7 @@ package httpproxy import ( "bufio" + "bytes" "encoding/base64" "encoding/json" "errors" @@ -10,10 +11,12 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "time" "github.com/rs/zerolog" + "github.com/ryanbekhen/nanoproxy/pkg/traffic" "github.com/stretchr/testify/assert" ) @@ -36,6 +39,12 @@ func (m *MockResolver) Resolve(host string) (net.IP, error) { return nil, errors.New("host not found") } +type resolverFunc func(host string) (net.IP, error) + +func (f resolverFunc) Resolve(host string) (net.IP, error) { + return f(host) +} + type MockNetConn struct{} func (m *MockNetConn) Read(b []byte) (n int, err error) { @@ -56,6 +65,17 @@ func (m *MockNetConn) SetDeadline(t time.Time) error { return nil } func (m *MockNetConn) SetReadDeadline(t time.Time) error { return nil } func (m *MockNetConn) SetWriteDeadline(t time.Time) error { return nil } +type writeFailConn struct{} + +func (c *writeFailConn) Read(_ []byte) (int, error) { return 0, io.EOF } +func (c *writeFailConn) Write(_ []byte) (int, error) { return 0, errors.New("write failed") } +func (c *writeFailConn) Close() error { return nil } +func (c *writeFailConn) LocalAddr() net.Addr { return nil } +func (c *writeFailConn) RemoteAddr() net.Addr { return nil } +func (c *writeFailConn) SetDeadline(_ time.Time) error { return nil } +func (c *writeFailConn) SetReadDeadline(_ time.Time) error { return nil } +func (c *writeFailConn) SetWriteDeadline(_ time.Time) error { return nil } + type MockHijacker struct { *httptest.ResponseRecorder } @@ -66,6 +86,43 @@ func (m *MockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { return mockConn, buf, nil } +func parseJSONLogLine(t *testing.T, buf *bytes.Buffer) map[string]interface{} { + t.Helper() + + entries := parseJSONLogLines(t, buf) + if len(entries) == 0 { + t.Fatal("expected log output") + } + + return entries[len(entries)-1] +} + +func parseJSONLogLines(t *testing.T, buf *bytes.Buffer) []map[string]interface{} { + t.Helper() + + content := strings.TrimSpace(buf.String()) + if content == "" { + return nil + } + + lines := strings.Split(content, "\n") + entries := make([]map[string]interface{}, 0, len(lines)) + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + entry := map[string]interface{}{} + if err := json.Unmarshal([]byte(line), &entry); err != nil { + t.Fatalf("failed to parse log entry: %v", err) + } + entries = append(entries, entry) + } + + return entries +} + func TestServer_ServeHTTP(t *testing.T) { logger := zerolog.New(io.Discard) mockCredentials := &MockCredentialStore{} @@ -117,7 +174,7 @@ func TestServer_ServeHTTP(t *testing.T) { server.ServeHTTP(rr, req) assert.Equal(t, http.StatusBadGateway, rr.Code) - assert.Contains(t, rr.Body.String(), "Bad gateway: failed to send request") + assert.Contains(t, rr.Body.String(), "Bad gateway: failed to resolve target host") }) } @@ -168,6 +225,158 @@ func TestServer_HandleCONNECT(t *testing.T) { }) } +func TestServer_HandleCONNECT_LogsStructuredAuthFailure(t *testing.T) { + var logBuf bytes.Buffer + logger := zerolog.New(&logBuf) + server := New(&Config{ + Credentials: &MockCredentialStore{}, + Logger: &logger, + }) + + req := httptest.NewRequest(http.MethodConnect, "http://example.com", nil) + req.RemoteAddr = "202.65.229.173:50059" + rr := httptest.NewRecorder() + + server.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusProxyAuthRequired, rr.Code) + entry := parseJSONLogLine(t, &logBuf) + assert.Equal(t, "proxy authentication failed", entry["message"]) + assert.Equal(t, "http", entry["protocol"]) + assert.Equal(t, http.MethodConnect, entry["http_method"]) + assert.Equal(t, "202.65.229.173:50059", entry["client_addr"]) + assert.Equal(t, ErrMissingProxyAuthorization.Error(), entry["error"]) + assert.Equal(t, "error", entry["level"]) +} + +func TestServer_HandleHTTP_LogsStructuredAuthFailure(t *testing.T) { + var logBuf bytes.Buffer + logger := zerolog.New(&logBuf) + server := New(&Config{ + Credentials: &MockCredentialStore{}, + Logger: &logger, + }) + + req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + req.RemoteAddr = "202.65.229.173:51655" + rr := httptest.NewRecorder() + + server.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusProxyAuthRequired, rr.Code) + entry := parseJSONLogLine(t, &logBuf) + assert.Equal(t, "proxy authentication failed", entry["message"]) + assert.Equal(t, "http", entry["protocol"]) + assert.Equal(t, http.MethodGet, entry["http_method"]) + assert.Equal(t, "202.65.229.173:51655", entry["client_addr"]) + assert.Equal(t, ErrMissingProxyAuthorization.Error(), entry["error"]) + assert.Equal(t, "error", entry["level"]) +} + +func TestServer_HandleCONNECT_LogsStructuredDialFailure(t *testing.T) { + var logBuf bytes.Buffer + logger := zerolog.New(&logBuf) + server := New(&Config{ + Credentials: &MockCredentialStore{}, + Logger: &logger, + Dial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("dial failed") + }, + }) + + req := httptest.NewRequest(http.MethodConnect, "http://example.com", nil) + req.Host = "example.com:443" + req.RemoteAddr = "202.65.229.173:50059" + req.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("user:password"))) + rr := httptest.NewRecorder() + + server.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusServiceUnavailable, rr.Code) + entry := parseJSONLogLine(t, &logBuf) + assert.Equal(t, "connect failed", entry["message"]) + assert.Equal(t, "http", entry["protocol"]) + assert.Equal(t, http.MethodConnect, entry["http_method"]) + assert.Equal(t, "202.65.229.173:50059", entry["client_addr"]) + assert.Equal(t, "example.com:443", entry["dest_addr"]) + assert.Equal(t, "dial failed", entry["error"]) + assert.NotEmpty(t, entry["latency"]) + assert.Equal(t, "error", entry["level"]) +} + +func TestServer_HandleHTTP_LogsStructuredSuccessAtInfo(t *testing.T) { + targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer targetServer.Close() + + var logBuf bytes.Buffer + logger := zerolog.New(&logBuf).Level(zerolog.InfoLevel) + server := New(&Config{ + Logger: &logger, + Dial: net.Dial, + Tracker: traffic.NewTracker(), + ClientConnTimeout: 2 * time.Second, + }) + + req := httptest.NewRequest(http.MethodGet, targetServer.URL, nil) + req.RemoteAddr = "202.65.229.173:51655" + rr := httptest.NewRecorder() + + server.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + entry := parseJSONLogLine(t, &logBuf) + assert.Equal(t, "request completed", entry["message"]) + assert.Equal(t, "info", entry["level"]) + assert.Equal(t, "http", entry["protocol"]) + assert.Equal(t, http.MethodGet, entry["http_method"]) + assert.Equal(t, "anonymous", entry["username"]) + assert.Equal(t, float64(http.StatusOK), entry["status_code"]) + assert.NotEmpty(t, entry["latency"]) + assert.Equal(t, float64(0), entry["upload_bytes"]) + assert.Equal(t, float64(2), entry["download_bytes"]) +} + +func TestServer_HandleHTTP_LogsDebugResolutionDetails(t *testing.T) { + targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer targetServer.Close() + + targetURL, err := url.Parse(targetServer.URL) + assert.NoError(t, err) + fakeTargetURL := "http://debug-target.test:" + targetURL.Port() + "/" + + var logBuf bytes.Buffer + logger := zerolog.New(&logBuf).Level(zerolog.DebugLevel) + server := New(&Config{ + Logger: &logger, + Dial: net.Dial, + Tracker: traffic.NewTracker(), + Resolver: resolverFunc(func(host string) (net.IP, error) { + return net.ParseIP(targetURL.Hostname()), nil + }), + ClientConnTimeout: 2 * time.Second, + }) + + req := httptest.NewRequest(http.MethodGet, fakeTargetURL, nil) + req.RemoteAddr = "202.65.229.173:51655" + rr := httptest.NewRecorder() + + server.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusNoContent, rr.Code) + entries := parseJSONLogLines(t, &logBuf) + assert.GreaterOrEqual(t, len(entries), 3) + assert.Equal(t, "http request accepted without authentication", entries[0]["message"]) + assert.Equal(t, "debug", entries[0]["level"]) + assert.Equal(t, "resolved proxy target", entries[1]["message"]) + assert.Equal(t, "debug", entries[1]["level"]) + assert.NotEmpty(t, entries[1]["resolved_addr"]) +} + func TestProxy_ForwardRequests(t *testing.T) { targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Empty(t, r.Header.Get("Connection")) @@ -236,7 +445,7 @@ func TestServer_HandleHTTP_WithProxyRequest(t *testing.T) { t.Run("Forward HTTP request successfully", func(t *testing.T) { clientReq, err := http.NewRequest(http.MethodGet, targetURL+"/anything", nil) if err != nil { - t.Fatalf("Gagal membuat request: %v", err) + t.Fatalf("failed to create request: %v", err) } clientReq.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("user:password"))) @@ -255,7 +464,7 @@ func TestServer_HandleHTTP_WithProxyRequest(t *testing.T) { resp, err := proxyClient.Do(clientReq) assert.NoError(t, err) if err != nil { - t.Fatalf("[ERROR] Proxy client mengalami error: %v", err) + t.Fatalf("proxy client returned an error: %v", err) } assert.Equal(t, http.StatusOK, resp.StatusCode) @@ -287,13 +496,13 @@ func TestServer_HandleHTTP_InvalidURLScheme(t *testing.T) { }) t.Run("Invalid URL Scheme", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "ftp://example.com", nil) // Skema tidak valid (ftp) + req := httptest.NewRequest(http.MethodGet, "ftp://example.com", nil) // Invalid scheme (ftp) rr := httptest.NewRecorder() server.ServeHTTP(rr, req) - assert.Equal(t, http.StatusBadRequest, rr.Code) // Memastikan statusnya Bad Request - assert.Contains(t, rr.Body.String(), "Invalid URL scheme") // Memastikan pesan error sesuai + assert.Equal(t, http.StatusBadRequest, rr.Code) // Verify the response status is Bad Request. + assert.Contains(t, rr.Body.String(), "Invalid target URL") // Verify the expected error message. }) } @@ -306,7 +515,7 @@ func TestServer_HandleHTTP_ClientDoError(t *testing.T) { }) t.Run("Failed to resolve DNS", func(t *testing.T) { - // Membuat permintaan HTTP Proxy + // Create a proxied HTTP request. proxyReq := httptest.NewRequest(http.MethodGet, "http://unreachablehost", nil) proxyReq.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("user:password"))) @@ -315,6 +524,206 @@ func TestServer_HandleHTTP_ClientDoError(t *testing.T) { server.ServeHTTP(rr, proxyReq) assert.Equal(t, http.StatusBadGateway, rr.Code) - assert.Contains(t, rr.Body.String(), "Bad gateway: failed to send request") + assert.Contains(t, rr.Body.String(), "Bad gateway: failed to resolve target host") + }) +} + +func TestNormalizeProxyTargetURL(t *testing.T) { + t.Run("Rejects nil request", func(t *testing.T) { + _, err := normalizeProxyTargetURL(nil) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing target url") + }) + + t.Run("Rejects missing host", func(t *testing.T) { + req := &http.Request{URL: &url.URL{Scheme: "http", Path: "/"}} + + _, err := normalizeProxyTargetURL(req) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing url host") + }) + + t.Run("Uses request host when URL host is empty", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://placeholder", nil) + req.URL.Host = "" + req.URL.Path = "/status" + req.URL.RawQuery = "check=1" + req.Host = "example.com:8080" + + targetURL, err := normalizeProxyTargetURL(req) + + assert.NoError(t, err) + assert.Equal(t, "http", targetURL.Scheme) + assert.Equal(t, "example.com:8080", targetURL.Host) + assert.Equal(t, "/status", targetURL.Path) + assert.Equal(t, "check=1", targetURL.RawQuery) + }) + + t.Run("Rejects userinfo", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://user:pass@example.com/private", nil) + + _, err := normalizeProxyTargetURL(req) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "userinfo") + }) +} + +func TestDialProxyTarget(t *testing.T) { + t.Run("Returns plain TCP conn for HTTP", func(t *testing.T) { + fakeConn := &MockNetConn{} + + conn, err := dialProxyTarget(&url.URL{Scheme: "http", Host: "example.com"}, "127.0.0.1:80", func(network, addr string) (net.Conn, error) { + assert.Equal(t, "tcp", network) + assert.Equal(t, "127.0.0.1:80", addr) + return fakeConn, nil + }, time.Second) + + assert.NoError(t, err) + assert.Same(t, fakeConn, conn) + }) + + t.Run("Returns handshake error for HTTPS", func(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer serverConn.Close() + + go func() { + // Trigger TLS handshake failure on client side. + _, _ = serverConn.Write([]byte("not-tls")) + _ = serverConn.Close() + }() + + _, err := dialProxyTarget(&url.URL{Scheme: "https", Host: "example.com"}, "127.0.0.1:443", func(network, addr string) (net.Conn, error) { + return clientConn, nil + }, time.Second) + + assert.Error(t, err) + }) +} + +func TestResolveProxyTargetAddr(t *testing.T) { + t.Run("Uses resolver result with default HTTPS port", func(t *testing.T) { + targetURL := &url.URL{Scheme: "https", Host: "validhost.com"} + + addr, err := resolveProxyTargetAddr(targetURL, resolverFunc(func(host string) (net.IP, error) { + assert.Equal(t, "validhost.com", host) + return net.ParseIP("203.0.113.10"), nil + })) + + assert.NoError(t, err) + assert.Equal(t, "203.0.113.10:443", addr) }) + + t.Run("Keeps literal IP addresses without DNS lookup", func(t *testing.T) { + targetURL := &url.URL{Scheme: "http", Host: "127.0.0.1:9000"} + + addr, err := resolveProxyTargetAddr(targetURL, resolverFunc(func(host string) (net.IP, error) { + t.Fatalf("resolver should not be called for literal IPs") + return nil, nil + })) + + assert.NoError(t, err) + assert.Equal(t, "127.0.0.1:9000", addr) + }) +} + +func TestBuildOutboundProxyRequest(t *testing.T) { + requestBody := io.NopCloser(strings.NewReader("payload")) + incomingReq := httptest.NewRequest(http.MethodPost, "http://example.com/original?trace=1", strings.NewReader("ignored")) + incomingReq.Header.Set("Connection", "keep-alive") + incomingReq.Header.Set("Proxy-Authorization", "Basic dXNlcjpwYXNz") + incomingReq.Header.Set("X-Test-Header", "ok") + targetURL := &url.URL{Host: "example.com:8080", Path: "/rewritten", RawQuery: "trace=1"} + + proxyReq := buildOutboundProxyRequest(incomingReq, targetURL, requestBody) + + assert.Equal(t, "", proxyReq.RequestURI) + assert.Equal(t, "example.com:8080", proxyReq.Host) + assert.Equal(t, "/rewritten?trace=1", proxyReq.URL.RequestURI()) + assert.Equal(t, "ok", proxyReq.Header.Get("X-Test-Header")) + assert.Empty(t, proxyReq.Header.Get("Connection")) + assert.Empty(t, proxyReq.Header.Get("Proxy-Authorization")) +} + +func TestServer_HandleHTTP_ReadResponseError(t *testing.T) { + logger := zerolog.New(io.Discard) + + server := New(&Config{ + Logger: &logger, + ClientConnTimeout: 2 * time.Second, + Resolver: resolverFunc(func(host string) (net.IP, error) { + assert.Equal(t, "example.com", host) + return net.ParseIP("127.0.0.1"), nil + }), + Dial: func(network, addr string) (net.Conn, error) { + clientConn, serverConn := net.Pipe() + go func() { + defer serverConn.Close() + _, _ = io.Copy(io.Discard, serverConn) + }() + return clientConn, nil + }, + }) + + proxyReq := httptest.NewRequest(http.MethodGet, "http://example.com/health", nil) + rr := httptest.NewRecorder() + + server.ServeHTTP(rr, proxyReq) + + assert.Equal(t, http.StatusBadGateway, rr.Code) + assert.Contains(t, rr.Body.String(), "Bad gateway: failed to read response") +} + +func TestServer_HandleHTTP_DialTargetError(t *testing.T) { + logger := zerolog.New(io.Discard) + + server := New(&Config{ + Logger: &logger, + ClientConnTimeout: 2 * time.Second, + Resolver: resolverFunc(func(host string) (net.IP, error) { + assert.Equal(t, "example.com", host) + return net.ParseIP("127.0.0.1"), nil + }), + Dial: func(network, addr string) (net.Conn, error) { + assert.Equal(t, "tcp", network) + assert.Equal(t, "127.0.0.1:80", addr) + return nil, errors.New("dial failed") + }, + }) + + proxyReq := httptest.NewRequest(http.MethodGet, "http://example.com/health", nil) + rr := httptest.NewRecorder() + + server.ServeHTTP(rr, proxyReq) + + assert.Equal(t, http.StatusBadGateway, rr.Code) + assert.Contains(t, rr.Body.String(), "Bad gateway: failed to send request") +} + +func TestServer_HandleHTTP_WriteRequestError(t *testing.T) { + logger := zerolog.New(io.Discard) + + server := New(&Config{ + Logger: &logger, + ClientConnTimeout: 2 * time.Second, + Resolver: resolverFunc(func(host string) (net.IP, error) { + assert.Equal(t, "example.com", host) + return net.ParseIP("127.0.0.1"), nil + }), + Dial: func(network, addr string) (net.Conn, error) { + assert.Equal(t, "tcp", network) + assert.Equal(t, "127.0.0.1:80", addr) + return &writeFailConn{}, nil + }, + }) + + proxyReq := httptest.NewRequest(http.MethodGet, "http://example.com/health", nil) + rr := httptest.NewRecorder() + + server.ServeHTTP(rr, proxyReq) + + assert.Equal(t, http.StatusBadGateway, rr.Code) + assert.Contains(t, rr.Body.String(), "Bad gateway: failed to send request") } diff --git a/pkg/socks5/request.go b/pkg/socks5/request.go index 5f5235c..1e3e1f7 100644 --- a/pkg/socks5/request.go +++ b/pkg/socks5/request.go @@ -53,6 +53,17 @@ type Request struct { Latency time.Duration } +var socks5DomainLengthOctets = func() [256]byte { + var lookup [256]byte + for i := uint8(0); ; i++ { + lookup[i] = i + if i == 255 { + break + } + } + return lookup +}() + func NewRequest(bufferConn io.Reader) (*Request, error) { var header [3]byte if _, err := io.ReadFull(bufferConn, header[:]); err != nil { @@ -142,7 +153,11 @@ func sendReply(conn io.Writer, reply uint8, addr *AddrSpec) error { case addr.FQDN != "": addrType = AddressTypeDomain.Uint8() - addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...) + if len(addr.FQDN) > 255 { + return fmt.Errorf("fqdn too long for socks5 domain field: %d", len(addr.FQDN)) + } + fqdnLen := socks5DomainLengthOctets[len(addr.FQDN)] + addrBody = append([]byte{fqdnLen}, addr.FQDN...) if addr.Port < 0 || addr.Port > 65535 { return fmt.Errorf("port value out of range uint16: %d", addr.Port) } @@ -191,3 +206,14 @@ func relay(dst io.Writer, src io.Reader, errCh chan error) { } errCh <- err } + +func relayWithCount(dst io.Writer, src io.Reader, errCh chan error, onBytes func(int64)) { + n, err := io.Copy(dst, src) + if onBytes != nil && n > 0 { + onBytes(n) + } + if tcpConn, ok := dst.(*net.TCPConn); ok { + _ = tcpConn.CloseWrite() + } + errCh <- err +} diff --git a/pkg/socks5/request_test.go b/pkg/socks5/request_test.go index 5603d07..8cbf09a 100644 --- a/pkg/socks5/request_test.go +++ b/pkg/socks5/request_test.go @@ -111,7 +111,7 @@ func Test_NewRequest(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, req) - err = s.handleRequest(req, resp) + err = s.testHandleRequest(req, resp) assert.NoError(t, err) // verify the response diff --git a/pkg/socks5/socks5.go b/pkg/socks5/socks5.go index a5f5f05..db11768 100644 --- a/pkg/socks5/socks5.go +++ b/pkg/socks5/socks5.go @@ -4,14 +4,17 @@ import ( "bufio" "errors" "fmt" + "io" "net" "os" "strings" + "syscall" "time" "github.com/rs/zerolog" "github.com/ryanbekhen/nanoproxy/pkg/credential" "github.com/ryanbekhen/nanoproxy/pkg/resolver" + "github.com/ryanbekhen/nanoproxy/pkg/traffic" ) type Config struct { @@ -24,6 +27,7 @@ type Config struct { AfterRequest func(req *Request, conn net.Conn) Resolver resolver.Resolver Rewriter AddressRewriter + Tracker *traffic.Tracker } type Server struct { @@ -103,66 +107,90 @@ func (s *Server) serve(l net.Listener) error { } func (s *Server) handleConnection(conn net.Conn) { + connLogger := s.connectionLogger(conn) defer func(conn net.Conn) { if err := conn.Close(); err != nil { - s.config.Logger.Error().Err(err).Msg("failed to close connection") + connLogger.Error().Err(err).Msg("failed to close connection") } }(conn) connectionBuffer := bufio.NewReader(conn) // Set a deadline for the connection if err := conn.SetDeadline(time.Now().Add(s.config.ClientConnTimeout)); err != nil { - s.config.Logger.Err(err).Msg("failed to set connection deadline") + connLogger.Error().Err(err).Msg("failed to set connection deadline") return } // Read the version byte version := []byte{0} if _, err := connectionBuffer.Read(version); err != nil { - s.config.Logger.Err(err).Msg("failed to read version byte") + if shouldLogRequestError(err) { + connLogger.Error().Err(err).Msg("failed to read version byte") + } return } // Ensure we are compatible if version[0] != Version { - s.config.Logger.Error().Msg("unsupported version") + connLogger.Error().Uint8("version", version[0]).Msg("unsupported version") return } // Authenticate authContext, err := s.authenticate(conn, connectionBuffer) if err != nil { - s.config.Logger.Err(err).Msg("SOCKS5 authentication failed") + if shouldLogRequestError(err) { + connLogger.Error().Err(err).Msg("proxy authentication failed") + } return } + username := usernameFromAuthContext(authContext) + connLogger = connLogger.With().Str("username", username).Logger() + if s.config.Credentials != nil { + connLogger.Debug().Msg("proxy authentication succeeded") + } else { + connLogger.Debug().Msg("connection accepted without authentication") + } request, err := NewRequest(connectionBuffer) if err != nil { if errors.Is(err, ErrUnrecognizedAddrType) { if err := sendReply(conn, StatusAddressNotSupported.Uint8(), nil); err != nil { - s.config.Logger.Err(err).Msg("failed to send reply") + if shouldLogRequestError(err) { + connLogger.Error().Err(err).Msg("failed to send reply") + } return } } - s.config.Logger.Err(err).Msg("failed to create request") + if shouldLogRequestError(err) { + connLogger.Error().Err(err).Msg("failed to create request") + } return } request.AuthContext = authContext + requestLogger := connLogger.With(). + Str("command", request.Command.String()). + Str("dest_addr", request.DestAddr.String()). + Logger() + requestLogger.Debug().Msg("request received") + trafficSession := s.startTrafficSession(authContext, conn) + defer trafficSession.Close() if clientAddr, ok := conn.RemoteAddr().(*net.TCPAddr); ok { request.RemoteAddr = &AddrSpec{IP: clientAddr.IP, Port: clientAddr.Port} } - if err := s.handleRequest(request, conn); err != nil && - !strings.Contains(err.Error(), "i/o timeout") { - s.config.Logger.Err(err). + if err := s.handleRequest(request, conn, trafficSession, requestLogger); err != nil && + shouldLogRequestError(err) { + requestLogger.Error(). + Err(err). Msg("request failed") - } else { - s.config.Logger.Info(). - Str("client_addr", conn.RemoteAddr().String()). - Str("dest_addr", request.DestAddr.String()). - Str("latency", request.Latency.String()). - Msg("SOCKS5 request completed") + } else if err == nil { + requestLogger.Info(). + Str("latency", request.Latency.Round(time.Millisecond).String()). + Uint64("upload_bytes", trafficSession.UploadBytes()). + Uint64("download_bytes", trafficSession.DownloadBytes()). + Msg("request completed") } if s.config.AfterRequest != nil { @@ -188,17 +216,18 @@ func (s *Server) authenticate(conn net.Conn, bufConn *bufio.Reader) (*Context, e return nil, noAcceptable(conn) } -func (s *Server) handleRequest(req *Request, conn net.Conn) error { +func (s *Server) handleRequest(req *Request, conn net.Conn, trafficSession *traffic.Session, requestLogger zerolog.Logger) error { dest := req.DestAddr if dest.FQDN != "" { addr, err := s.config.Resolver.Resolve(dest.FQDN) if err != nil { if err := sendReply(conn, StatusHostUnreachable.Uint8(), nil); err != nil { - return ErrFailedToSendReply + return fmt.Errorf("%w: %w", ErrFailedToSendReply, err) } return fmt.Errorf("failed to resolve destination: %w", err) } dest.IP = addr + requestLogger.Debug().Str("resolved_ip", addr.String()).Msg("resolved destination address") } req.realAddr = req.DestAddr @@ -208,7 +237,7 @@ func (s *Server) handleRequest(req *Request, conn net.Conn) error { switch req.Command { case CommandConnect: - return s.handleConnect(conn, req) + return s.handleConnect(conn, req, trafficSession, requestLogger) // TODO: Implement these //case CommandBind: // return s.handleBind(conn, req) @@ -216,13 +245,13 @@ func (s *Server) handleRequest(req *Request, conn net.Conn) error { // return s.handleAssociate(conn, req) default: if err := sendReply(conn, StatusCommandNotSupported.Uint8(), nil); err != nil { - return ErrFailedToSendReply + return fmt.Errorf("%w: %w", ErrFailedToSendReply, err) } return fmt.Errorf("unsupported command: %d", req.Command) } } -func (s *Server) handleConnect(conn net.Conn, req *Request) error { +func (s *Server) handleConnect(conn net.Conn, req *Request, trafficSession *traffic.Session, requestLogger zerolog.Logger) error { dial := s.config.Dial if dial == nil { dial = func(network, addr string) (net.Conn, error) { @@ -231,6 +260,7 @@ func (s *Server) handleConnect(conn net.Conn, req *Request) error { } processStartTimestamp := time.Now() + requestLogger.Debug().Msg("dialing destination") dest, err := dial("tcp", req.realAddr.Address()) req.Latency = time.Since(processStartTimestamp) @@ -248,7 +278,7 @@ func (s *Server) handleConnect(conn net.Conn, req *Request) error { } if err := sendReply(conn, resp.Uint8(), nil); err != nil { - return ErrFailedToSendReply + return fmt.Errorf("%w: %w", ErrFailedToSendReply, err) } return errors.New(msg) @@ -260,12 +290,12 @@ func (s *Server) handleConnect(conn net.Conn, req *Request) error { local := dest.LocalAddr().(*net.TCPAddr) bind := AddrSpec{IP: local.IP, Port: local.Port} if err := sendReply(conn, StatusRequestGranted.Uint8(), &bind); err != nil { - return ErrFailedToSendReply + return fmt.Errorf("%w: %w", ErrFailedToSendReply, err) } errChan := make(chan error, 2) - go relay(dest, req.BufferConn, errChan) - go relay(conn, dest, errChan) + go relayWithCount(dest, req.BufferConn, errChan, trafficSession.AddUpload) + go relayWithCount(conn, dest, errChan, trafficSession.AddDownload) for i := 0; i < 2; i++ { if err := <-errChan; err != nil { @@ -275,3 +305,75 @@ func (s *Server) handleConnect(conn net.Conn, req *Request) error { return nil } + +func (s *Server) startTrafficSession(authContext *Context, conn net.Conn) *traffic.Session { + if s.config.Tracker == nil { + return nil + } + username := "anonymous" + if authContext != nil && authContext.Payload != nil { + if v := authContext.Payload["Username"]; v != "" { + username = v + } + } + return s.config.Tracker.Start(username, extractClientIP(conn.RemoteAddr().String())) +} + +func extractClientIP(remoteAddr string) string { + host, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return remoteAddr + } + return host +} + +func (s *Server) connectionLogger(conn net.Conn) zerolog.Logger { + logger := s.config.Logger.With().Str("protocol", "socks5") + if clientAddr := remoteAddrString(conn); clientAddr != "" { + logger = logger.Str("client_addr", clientAddr) + } + return logger.Logger() +} + +func remoteAddrString(conn net.Conn) string { + if conn == nil || conn.RemoteAddr() == nil { + return "" + } + return conn.RemoteAddr().String() +} + +func usernameFromAuthContext(authContext *Context) string { + if authContext != nil && authContext.Payload != nil { + if username := authContext.Payload["Username"]; username != "" { + return username + } + } + return "anonymous" +} + +func shouldLogRequestError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { + return false + } + if errors.Is(err, syscall.ECONNRESET) || + errors.Is(err, syscall.EPIPE) || + errors.Is(err, syscall.ETIMEDOUT) { + return false + } + + msg := strings.ToLower(err.Error()) + if strings.Contains(msg, "i/o timeout") || + strings.Contains(msg, "broken pipe") || + strings.Contains(msg, "connection reset by peer") || + strings.Contains(msg, "use of closed network connection") || + strings.Contains(msg, "splice:") || + strings.Contains(msg, "readfrom tcp") || + strings.Contains(msg, "writeto tcp") { + return false + } + + return true +} diff --git a/pkg/socks5/socks5_test.go b/pkg/socks5/socks5_test.go index 906ba28..cb2349b 100644 --- a/pkg/socks5/socks5_test.go +++ b/pkg/socks5/socks5_test.go @@ -3,6 +3,7 @@ package socks5 import ( "bytes" "encoding/binary" + "encoding/json" "errors" "fmt" "io" @@ -10,11 +11,83 @@ import ( "testing" "time" + "github.com/rs/zerolog" "github.com/ryanbekhen/nanoproxy/pkg/credential" "github.com/ryanbekhen/nanoproxy/pkg/resolver" + "github.com/ryanbekhen/nanoproxy/pkg/traffic" "github.com/stretchr/testify/assert" ) +type resolverFunc func(host string) (net.IP, error) + +func (f resolverFunc) Resolve(host string) (net.IP, error) { + return f(host) +} + +func parseJSONLogLines(t *testing.T, buf *bytes.Buffer) []map[string]interface{} { + t.Helper() + + content := bytes.TrimSpace(buf.Bytes()) + if len(content) == 0 { + return nil + } + + lines := bytes.Split(content, []byte("\n")) + entries := make([]map[string]interface{}, 0, len(lines)) + for _, line := range lines { + line = bytes.TrimSpace(line) + if len(line) == 0 { + continue + } + entry := map[string]interface{}{} + if err := json.Unmarshal(line, &entry); err != nil { + t.Fatalf("failed to parse log entry: %v", err) + } + entries = append(entries, entry) + } + + return entries +} + +func parseLastJSONLogLine(t *testing.T, buf *bytes.Buffer) map[string]interface{} { + t.Helper() + + entries := parseJSONLogLines(t, buf) + if len(entries) == 0 { + t.Fatal("expected log output") + } + + return entries[len(entries)-1] +} + +// testHandleConnect wraps handleConnect for tests that use the old signature +func (s *Server) testHandleConnect(conn net.Conn, req *Request) error { + tracker := traffic.NewTracker() + session := tracker.Start("test", "127.0.0.1") + defer session.Close() + logger := zerolog.New(io.Discard) + reqLogger := logger.With(). + Str("protocol", "socks5"). + Str("command", req.Command.String()). + Str("dest_addr", req.DestAddr.String()). + Logger() + return s.handleConnect(conn, req, session, reqLogger) +} + +// testHandleRequest wraps handleRequest for tests that use the old signature +func (s *Server) testHandleRequest(req *Request, conn net.Conn) error { + tracker := traffic.NewTracker() + session := tracker.Start("test", "127.0.0.1") + defer session.Close() + logger := zerolog.New(io.Discard) + reqLogger := logger.With(). + Str("protocol", "socks5"). + Str("command", req.Command.String()). + Str("dest_addr", req.DestAddr.String()). + Logger() + return s.handleRequest(req, conn, session, reqLogger) +} + func TestNew(t *testing.T) { conf := &Config{ Authentication: []Authenticator{&NoAuthAuthenticator{}}, @@ -30,6 +103,24 @@ func TestNew(t *testing.T) { assert.IsType(t, &NoAuthAuthenticator{}, server.authentication[NoAuth]) } +func TestNew_WithCredentials(t *testing.T) { + creds := credential.NewStaticCredentialStore() + creds.Add("user", "pass") + conf := &Config{Credentials: creds} + server := New(conf) + assert.NotNil(t, server) + _, ok := server.authentication[UserPassAuth] + assert.True(t, ok) +} + +func TestNew_DefaultNoAuth(t *testing.T) { + conf := &Config{} // no Authentication, no Credentials → NoAuthAuthenticator + server := New(conf) + assert.NotNil(t, server) + _, ok := server.authentication[NoAuth] + assert.True(t, ok) +} + func TestListenAndServe(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") assert.NoError(t, err) @@ -159,6 +250,289 @@ func TestListenAndServe_InvalidCredentials(t *testing.T) { assert.Equal(t, expected, out) } +func TestHandleConnection_LogsStructuredAuthFailure(t *testing.T) { + var logBuf bytes.Buffer + logger := zerolog.New(&logBuf) + credentials := credential.NewStaticCredentialStore() + credentials.Add("foo", "bar") + + server := New(&Config{ + Authentication: []Authenticator{&UserPassAuthenticator{Credentials: credentials}}, + Logger: &logger, + }) + + serverConn, clientConn := net.Pipe() + defer clientConn.Close() + + done := make(chan struct{}) + go func() { + defer close(done) + server.handleConnection(serverConn) + }() + + request := []byte{ + Version, + 1, UserPassAuth.Uint8(), + UserAuthVersion, + 3, 'b', 'a', 'd', + 4, 'p', 'a', 's', 's', + } + _, err := clientConn.Write(request) + assert.NoError(t, err) + + response := make([]byte, 4) + _, err = io.ReadFull(clientConn, response) + assert.NoError(t, err) + assert.Equal(t, []byte{Version, UserPassAuth.Uint8(), UserAuthVersion, AuthFailure.Uint8()}, response) + + clientConn.Close() + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timed out waiting for connection handler") + } + + entry := parseLastJSONLogLine(t, &logBuf) + assert.Equal(t, "proxy authentication failed", entry["message"]) + assert.Equal(t, "socks5", entry["protocol"]) + assert.Equal(t, "invalid credentials", entry["error"]) + assert.Equal(t, "error", entry["level"]) + assert.NotEmpty(t, entry["client_addr"]) +} + +func TestHandleConnection_LogsClientAddrForUnsupportedVersion(t *testing.T) { + var logBuf bytes.Buffer + logger := zerolog.New(&logBuf) + server := New(&Config{Logger: &logger}) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + defer listener.Close() + + done := make(chan struct{}) + go func() { + defer close(done) + conn, acceptErr := listener.Accept() + assert.NoError(t, acceptErr) + if acceptErr == nil { + server.handleConnection(conn) + } + }() + + clientConn, err := net.Dial("tcp", listener.Addr().String()) + assert.NoError(t, err) + _, err = clientConn.Write([]byte{0x04}) + assert.NoError(t, err) + _ = clientConn.Close() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timed out waiting for connection handler") + } + + entry := parseLastJSONLogLine(t, &logBuf) + assert.Equal(t, "unsupported version", entry["message"]) + assert.Equal(t, "socks5", entry["protocol"]) + assert.Equal(t, float64(4), entry["version"]) + clientAddr, ok := entry["client_addr"].(string) + assert.True(t, ok) + assert.NotEmpty(t, clientAddr) + assert.Contains(t, clientAddr, "127.0.0.1:") +} + +func TestHandleConnection_LogsRequestFailureContext(t *testing.T) { + var logBuf bytes.Buffer + logger := zerolog.New(&logBuf) + server := New(&Config{ + Authentication: []Authenticator{&NoAuthAuthenticator{}}, + Logger: &logger, + }) + + serverConn, clientConn := net.Pipe() + defer func() { _ = clientConn.Close() }() + + done := make(chan struct{}) + go func() { + defer close(done) + server.handleConnection(serverConn) + }() + + request := []byte{ + Version, + 1, NoAuth.Uint8(), + Version, + 9, // unsupported command + 0, + AddressTypeIPv4.Uint8(), + 127, 0, 0, 1, + 0, 80, + } + _, err := clientConn.Write(request) + assert.NoError(t, err) + + response := make([]byte, 12) + _, err = io.ReadFull(clientConn, response) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timed out waiting for connection handler") + } + + entry := parseLastJSONLogLine(t, &logBuf) + assert.Equal(t, "request failed", entry["message"]) + assert.Equal(t, "socks5", entry["protocol"]) + assert.Equal(t, "unknown", entry["command"]) + assert.Equal(t, "127.0.0.1:80", entry["dest_addr"]) + assert.Equal(t, "unsupported command: 9", entry["error"]) + assert.Equal(t, "error", entry["level"]) + clientAddr, ok := entry["client_addr"].(string) + assert.True(t, ok) + assert.NotEmpty(t, clientAddr) +} + +func TestHandleConnection_LogsSuccessfulRequestAtInfo(t *testing.T) { + backend, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + defer backend.Close() + + go func() { + conn, acceptErr := backend.Accept() + assert.NoError(t, acceptErr) + if acceptErr != nil { + return + } + defer conn.Close() + + buf := make([]byte, 4) + _, _ = io.ReadFull(conn, buf) + _, _ = conn.Write([]byte("pong")) + }() + + var logBuf bytes.Buffer + logger := zerolog.New(&logBuf).Level(zerolog.InfoLevel) + server := New(&Config{ + Authentication: []Authenticator{&NoAuthAuthenticator{}}, + Logger: &logger, + Tracker: traffic.NewTracker(), + }) + + serverConn, clientConn := net.Pipe() + defer clientConn.Close() + + done := make(chan struct{}) + go func() { + defer close(done) + server.handleConnection(serverConn) + }() + + backendAddr := backend.Addr().(*net.TCPAddr) + request := bytes.NewBuffer(nil) + request.Write([]byte{Version}) + request.Write([]byte{1, NoAuth.Uint8()}) + request.Write([]byte{Version, CommandConnect.Uint8(), 0, AddressTypeIPv4.Uint8()}) + request.Write([]byte{127, 0, 0, 1}) + port := []byte{0, 0} + binary.BigEndian.PutUint16(port, uint16(backendAddr.Port)) + request.Write(port) + request.Write([]byte("ping")) + + _, err = clientConn.Write(request.Bytes()) + assert.NoError(t, err) + + response := make([]byte, 16) + _, err = io.ReadFull(clientConn, response) + assert.NoError(t, err) + _ = clientConn.Close() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timed out waiting for connection handler") + } + + entry := parseLastJSONLogLine(t, &logBuf) + assert.Equal(t, "request completed", entry["message"]) + assert.Equal(t, "info", entry["level"]) + assert.Equal(t, "socks5", entry["protocol"]) + assert.Equal(t, "connect", entry["command"]) + assert.Equal(t, "anonymous", entry["username"]) + assert.Equal(t, fmt.Sprintf("127.0.0.1:%d", backendAddr.Port), entry["dest_addr"]) + assert.NotEmpty(t, entry["latency"]) + assert.Equal(t, float64(4), entry["upload_bytes"]) + assert.Equal(t, float64(4), entry["download_bytes"]) +} + +func TestHandleConnection_LogsSuccessfulRequestDetailsAtDebug(t *testing.T) { + backend, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + defer backend.Close() + + go func() { + conn, acceptErr := backend.Accept() + assert.NoError(t, acceptErr) + if acceptErr != nil { + return + } + defer conn.Close() + _, _ = io.Copy(io.Discard, conn) + }() + + var logBuf bytes.Buffer + logger := zerolog.New(&logBuf).Level(zerolog.DebugLevel) + server := New(&Config{ + Authentication: []Authenticator{&NoAuthAuthenticator{}}, + Logger: &logger, + Tracker: traffic.NewTracker(), + Resolver: resolverFunc(func(host string) (net.IP, error) { + return net.ParseIP("127.0.0.1"), nil + }), + }) + + serverConn, clientConn := net.Pipe() + defer clientConn.Close() + + done := make(chan struct{}) + go func() { + defer close(done) + server.handleConnection(serverConn) + }() + + backendAddr := backend.Addr().(*net.TCPAddr) + request := bytes.NewBuffer(nil) + request.Write([]byte{Version}) + request.Write([]byte{1, NoAuth.Uint8()}) + request.Write([]byte{Version, CommandConnect.Uint8(), 0, AddressTypeDomain.Uint8(), 17}) + request.WriteString("debug-target.test") + port := []byte{0, 0} + binary.BigEndian.PutUint16(port, uint16(backendAddr.Port)) + request.Write(port) + + _, err = clientConn.Write(request.Bytes()) + assert.NoError(t, err) + + response := make([]byte, 12) + _, err = io.ReadFull(clientConn, response) + assert.NoError(t, err) + _ = clientConn.Close() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timed out waiting for connection handler") + } + + entries := parseJSONLogLines(t, &logBuf) + assert.GreaterOrEqual(t, len(entries), 4) + assert.Equal(t, "connection accepted without authentication", entries[0]["message"]) + assert.Equal(t, "request received", entries[1]["message"]) + assert.Equal(t, "resolved destination address", entries[2]["message"]) + assert.Equal(t, "127.0.0.1", entries[2]["resolved_ip"]) + assert.Equal(t, "dialing destination", entries[3]["message"]) +} + func TestListenAndServe_InvalidAuthType(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") assert.NoError(t, err) @@ -287,7 +661,7 @@ func TestRequest_Unreachable(t *testing.T) { assert.NoError(t, err) req.realAddr = req.DestAddr - err = s.handleConnect(resp, req) + err = s.testHandleConnect(resp, req) assert.Error(t, err) out := resp.buf.Bytes() @@ -326,7 +700,7 @@ func TestRequest_Refused(t *testing.T) { assert.NoError(t, err) req.realAddr = req.DestAddr - err = s.handleConnect(resp, req) + err = s.testHandleConnect(resp, req) assert.Error(t, err) out := resp.buf.Bytes() @@ -369,7 +743,7 @@ func TestRequest_NetworkUnreachable(t *testing.T) { assert.NoError(t, err) req.realAddr = req.DestAddr - err = s.handleConnect(resp, req) + err = s.testHandleConnect(resp, req) assert.Error(t, err) out := resp.buf.Bytes() @@ -405,7 +779,7 @@ func TestRequest_CommandNotSupported(t *testing.T) { resp := &MockConn{} req, _ := NewRequest(buf) - err := s.handleRequest(req, resp) + err := s.testHandleRequest(req, resp) assert.Error(t, err) out := resp.buf.Bytes() @@ -420,3 +794,127 @@ func TestRequest_CommandNotSupported(t *testing.T) { assert.Equal(t, expected, out) } + +func TestShutdown_NilListener(t *testing.T) { + server := New(&Config{}) + err := server.Shutdown() + assert.NoError(t, err) +} + +func TestShutdown_WithListener(t *testing.T) { + server := New(&Config{}) + errCh := make(chan error, 1) + go func() { + errCh <- server.ListenAndServe("tcp", "127.0.0.1:0") + }() + // Give the goroutine time to start listening + time.Sleep(20 * time.Millisecond) + err := server.Shutdown() + assert.NoError(t, err) + select { + case err := <-errCh: + assert.Error(t, err) // serve returns after listener is closed + case <-time.After(time.Second): + t.Fatal("serve did not stop after Shutdown") + } +} + +func TestListenAndServe_InvalidAddress(t *testing.T) { + server := New(&Config{}) + err := server.ListenAndServe("tcp", "300.0.0.1:9999") // invalid IP + assert.Error(t, err) +} + +func TestHandleRequest_WithResolver(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + defer l.Close() + lAddr := l.Addr().(*net.TCPAddr) + + go func() { + conn, err2 := l.Accept() + if err2 != nil { + return + } + defer conn.Close() + buf := make([]byte, 4) + _, _ = io.ReadAtLeast(conn, buf, 4) + _, _ = conn.Write([]byte("pong")) + }() + + s := &Server{ + config: &Config{ + Resolver: &resolver.DNSResolver{}, + DestConnTimeout: 2 * time.Second, + }, + } + + req := &Request{ + Command: CommandConnect, + DestAddr: &AddrSpec{ + FQDN: "localhost", + Port: lAddr.Port, + }, + BufferConn: bytes.NewReader([]byte("ping")), + } + req.realAddr = req.DestAddr + + conn := &MockConn{} + _ = s.testHandleRequest(req, conn) +} + +func TestHandleRequest_ResolverError(t *testing.T) { + s := &Server{ + config: &Config{ + Resolver: &mockFailResolver{}, + }, + } + + conn := &MockConn{} + req := &Request{ + Command: CommandConnect, + DestAddr: &AddrSpec{FQDN: "bad.invalid", Port: 9999}, + } + + err := s.testHandleRequest(req, conn) + assert.Error(t, err) +} + +type mockFailResolver struct{} + +func (m *mockFailResolver) Resolve(_ string) (net.IP, error) { + return nil, errors.New("resolve failed") +} + +func TestHandleRequest_WithRewriter(t *testing.T) { + s := &Server{ + config: &Config{ + Rewriter: &mockRewriter{}, + }, + } + + buf := bytes.NewBuffer(nil) + buf.Write([]byte{ + Version, + CommandConnect.Uint8(), + 0, + AddressTypeIPv4.Uint8(), + 127, 0, 0, 1, + }) + port := []byte{0, 0} + binary.BigEndian.PutUint16(port, uint16(12345)) + buf.Write(port) + + conn := &MockConn{} + req, err := NewRequest(buf) + assert.NoError(t, err) + req.realAddr = req.DestAddr + // connection refused is expected — we just want to cover the Rewriter path + _ = s.testHandleRequest(req, conn) +} + +type mockRewriter struct{} + +func (m *mockRewriter) Rewrite(req *Request) *AddrSpec { + return req.DestAddr +} diff --git a/pkg/tor/identity_test.go b/pkg/tor/identity_test.go index 3475cdc..1d8f0a4 100644 --- a/pkg/tor/identity_test.go +++ b/pkg/tor/identity_test.go @@ -8,7 +8,7 @@ import ( "github.com/rs/zerolog" ) -// Mock Requester untuk menggantikan implementasi sebenarnya +// MockRequester replaces the real requester implementation in tests. type MockRequester struct { RequestNewTorIdentityFunc func(logger *zerolog.Logger) error } @@ -24,7 +24,7 @@ func TestWaitForTorBootstrap(t *testing.T) { t.Run("Successful bootstrap", func(t *testing.T) { mockRequester := &MockRequester{ RequestNewTorIdentityFunc: func(logger *zerolog.Logger) error { - return nil // selalu sukses + return nil // Always succeeds. }, } @@ -37,7 +37,7 @@ func TestWaitForTorBootstrap(t *testing.T) { t.Run("Timeout occurs", func(t *testing.T) { mockRequester := &MockRequester{ RequestNewTorIdentityFunc: func(logger *zerolog.Logger) error { - time.Sleep(3 * time.Second) // memicu timeout + time.Sleep(3 * time.Second) // Intentionally triggers a timeout. return nil }, } @@ -84,6 +84,6 @@ func TestSwitcherIdentity(t *testing.T) { }() time.Sleep(3 * time.Second) - // Tidak ada log error karena mockRequester selalu berhasil + // No error log is expected because mockRequester always succeeds. }) } diff --git a/pkg/traffic/bolt_store.go b/pkg/traffic/bolt_store.go new file mode 100644 index 0000000..075089e --- /dev/null +++ b/pkg/traffic/bolt_store.go @@ -0,0 +1,132 @@ +package traffic + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" + "time" + + "go.etcd.io/bbolt" +) + +var trafficBucket = []byte("traffic") + +type storedTraffic struct { + UploadBytes uint64 `json:"upload_bytes"` + DownloadBytes uint64 `json:"download_bytes"` + LastClientIP string `json:"last_client_ip"` + LastSeenAt time.Time `json:"last_seen_at"` +} + +type BoltStore struct { + path string +} + +func NewBoltStore(path string) *BoltStore { + return &BoltStore{path: path} +} + +func (b *BoltStore) LoadTraffic() (map[string]UserTotals, error) { + if b == nil || b.path == "" { + return map[string]UserTotals{}, nil + } + if _, err := os.Stat(b.path); err != nil { + if errors.Is(err, os.ErrNotExist) { + return map[string]UserTotals{}, nil + } + return nil, err + } + db, err := bbolt.Open(b.path, 0o600, nil) + if err != nil { + return nil, err + } + defer db.Close() + + out := map[string]UserTotals{} + err = db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(trafficBucket) + if bucket == nil { + return nil + } + return bucket.ForEach(func(k, v []byte) error { + var rec storedTraffic + if err := json.Unmarshal(v, &rec); err != nil { + return nil + } + out[string(k)] = UserTotals{ + UploadBytes: rec.UploadBytes, + DownloadBytes: rec.DownloadBytes, + LastClientIP: rec.LastClientIP, + LastSeenAt: rec.LastSeenAt, + } + return nil + }) + }) + return out, err +} + +func (b *BoltStore) SaveTraffic(totals map[string]UserTotals) error { + if b == nil || b.path == "" { + return nil + } + dir := filepath.Dir(b.path) + if dir != "." { + if err := os.MkdirAll(dir, 0o750); err != nil { + return err + } + } + db, err := bbolt.Open(b.path, 0o600, nil) + if err != nil { + return err + } + defer db.Close() + + return db.Update(func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(trafficBucket) + if err != nil { + return err + } + for username, t := range totals { + rec := storedTraffic{ + UploadBytes: t.UploadBytes, + DownloadBytes: t.DownloadBytes, + LastClientIP: t.LastClientIP, + LastSeenAt: t.LastSeenAt, + } + data, err := json.Marshal(rec) + if err != nil { + continue + } + if err := bucket.Put([]byte(username), data); err != nil { + return err + } + } + return nil + }) +} + +func (b *BoltStore) ResetUserTraffic(username string) error { + if b == nil || b.path == "" { + return nil + } + if _, err := os.Stat(b.path); err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return err + } + db, err := bbolt.Open(b.path, 0o600, nil) + if err != nil { + return err + } + defer db.Close() + + return db.Update(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(trafficBucket) + if bucket == nil { + return nil + } + return bucket.Delete([]byte(username)) + }) +} diff --git a/pkg/traffic/bolt_store_test.go b/pkg/traffic/bolt_store_test.go new file mode 100644 index 0000000..d41f08f --- /dev/null +++ b/pkg/traffic/bolt_store_test.go @@ -0,0 +1,64 @@ +package traffic + +import ( + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBoltStore_SaveAndLoadTraffic(t *testing.T) { + t.Parallel() + + path := filepath.Join(t.TempDir(), "traffic.db") + store := NewBoltStore(path) + + totals := map[string]UserTotals{ + "alice": { + UploadBytes: 1024, + DownloadBytes: 2048, + LastClientIP: "10.0.0.2", + LastSeenAt: time.Now(), + }, + "bob": { + UploadBytes: 512, + DownloadBytes: 1024, + LastClientIP: "10.0.0.3", + LastSeenAt: time.Now(), + }, + } + + require.NoError(t, store.SaveTraffic(totals)) + + loaded, err := store.LoadTraffic() + require.NoError(t, err) + assert.Len(t, loaded, 2) + assert.Equal(t, uint64(1024), loaded["alice"].UploadBytes) + assert.Equal(t, uint64(2048), loaded["alice"].DownloadBytes) + assert.Equal(t, "10.0.0.2", loaded["alice"].LastClientIP) +} + +func TestBoltStore_ResetUserTraffic(t *testing.T) { + t.Parallel() + + path := filepath.Join(t.TempDir(), "traffic.db") + store := NewBoltStore(path) + + totals := map[string]UserTotals{ + "alice": { + UploadBytes: 1024, + DownloadBytes: 2048, + LastClientIP: "10.0.0.2", + LastSeenAt: time.Now(), + }, + } + + require.NoError(t, store.SaveTraffic(totals)) + require.NoError(t, store.ResetUserTraffic("alice")) + + loaded, err := store.LoadTraffic() + require.NoError(t, err) + assert.Empty(t, loaded) +} diff --git a/pkg/traffic/store.go b/pkg/traffic/store.go new file mode 100644 index 0000000..c42d25d --- /dev/null +++ b/pkg/traffic/store.go @@ -0,0 +1,8 @@ +package traffic + +// Store persists per-user traffic totals across restarts. +type Store interface { + LoadTraffic() (map[string]UserTotals, error) + SaveTraffic(totals map[string]UserTotals) error + ResetUserTraffic(username string) error +} diff --git a/pkg/traffic/tracker.go b/pkg/traffic/tracker.go new file mode 100644 index 0000000..4def62b --- /dev/null +++ b/pkg/traffic/tracker.go @@ -0,0 +1,316 @@ +package traffic + +import ( + "sort" + "strconv" + "sync" + "sync/atomic" + "time" +) + +type Snapshot struct { + ID string + Username string + ClientIP string + UploadBytes uint64 + DownloadBytes uint64 + UploadBPS uint64 + DownloadBPS uint64 + StartedAt time.Time +} + +type Tracker struct { + mu sync.Mutex + sessions map[string]*sessionState + totals map[string]UserTotals + lastRate map[string]UserTotals + lastPoll time.Time + nextID atomic.Uint64 +} + +type UserTotals struct { + UploadBytes uint64 + DownloadBytes uint64 + UploadBPS uint64 + DownloadBPS uint64 + LastSeenAt time.Time + LastClientIP string +} + +type sessionState struct { + username string + clientIP string + started time.Time + + uploadBytes atomic.Uint64 + downloadBytes atomic.Uint64 + + lastSampleAt time.Time + lastUploadSample uint64 + lastDownloadSample uint64 + lastSeenUnix atomic.Int64 +} + +type Session struct { + tracker *Tracker + id string + once sync.Once +} + +func (s *Session) UploadBytes() uint64 { + if s == nil || s.tracker == nil { + return 0 + } + s.tracker.mu.Lock() + state := s.tracker.sessions[s.id] + s.tracker.mu.Unlock() + if state == nil { + return 0 + } + return state.uploadBytes.Load() +} + +func (s *Session) DownloadBytes() uint64 { + if s == nil || s.tracker == nil { + return 0 + } + s.tracker.mu.Lock() + state := s.tracker.sessions[s.id] + s.tracker.mu.Unlock() + if state == nil { + return 0 + } + return state.downloadBytes.Load() +} + +func NewTracker() *Tracker { + return &Tracker{ + sessions: make(map[string]*sessionState), + totals: make(map[string]UserTotals), + lastRate: make(map[string]UserTotals), + } +} + +func (t *Tracker) LoadPersistedTotals(store Store) error { + if t == nil || store == nil { + return nil + } + persisted, err := store.LoadTraffic() + if err != nil { + return err + } + t.mu.Lock() + defer t.mu.Unlock() + for username, totals := range persisted { + t.totals[username] = totals + } + return nil +} + +func (t *Tracker) ResetUserStats(username string) { + if t == nil { + return + } + t.mu.Lock() + defer t.mu.Unlock() + t.totals[username] = UserTotals{} +} + +func (t *Tracker) Start(username, clientIP string) *Session { + if t == nil { + return nil + } + if username == "" { + username = "anonymous" + } + if clientIP == "" { + clientIP = "unknown" + } + + id := strconv.FormatUint(t.nextID.Add(1), 10) + now := time.Now() + + t.mu.Lock() + state := &sessionState{ + username: username, + clientIP: clientIP, + started: now, + lastSampleAt: now, + } + state.lastSeenUnix.Store(now.UnixNano()) + t.sessions[id] = state + + totals := t.totals[username] + if totals.LastSeenAt.IsZero() || now.After(totals.LastSeenAt) { + totals.LastSeenAt = now + totals.LastClientIP = clientIP + } + t.totals[username] = totals + t.mu.Unlock() + + return &Session{tracker: t, id: id} +} + +func (s *Session) AddUpload(n int64) { + if s == nil || n <= 0 || s.tracker == nil { + return + } + s.tracker.mu.Lock() + state := s.tracker.sessions[s.id] + s.tracker.mu.Unlock() + if state == nil { + return + } + state.uploadBytes.Add(uint64(n)) + state.lastSeenUnix.Store(time.Now().UnixNano()) +} + +func (s *Session) AddDownload(n int64) { + if s == nil || n <= 0 || s.tracker == nil { + return + } + s.tracker.mu.Lock() + state := s.tracker.sessions[s.id] + s.tracker.mu.Unlock() + if state == nil { + return + } + state.downloadBytes.Add(uint64(n)) + state.lastSeenUnix.Store(time.Now().UnixNano()) +} + +func (s *Session) Close() { + if s == nil || s.tracker == nil { + return + } + s.once.Do(func() { + s.tracker.mu.Lock() + state := s.tracker.sessions[s.id] + if state != nil { + totals := s.tracker.totals[state.username] + totals.UploadBytes += state.uploadBytes.Load() + totals.DownloadBytes += state.downloadBytes.Load() + lastSeenUnix := state.lastSeenUnix.Load() + lastSeenAt := time.Unix(0, lastSeenUnix) + if lastSeenUnix <= 0 { + lastSeenAt = state.started + } + if totals.LastSeenAt.IsZero() || lastSeenAt.After(totals.LastSeenAt) { + totals.LastSeenAt = lastSeenAt + totals.LastClientIP = state.clientIP + } + s.tracker.totals[state.username] = totals + } + delete(s.tracker.sessions, s.id) + s.tracker.mu.Unlock() + }) +} + +func (t *Tracker) TotalsByUser() map[string]UserTotals { + if t == nil { + return nil + } + + t.mu.Lock() + defer t.mu.Unlock() + now := time.Now() + + out := make(map[string]UserTotals, len(t.totals)) + for username, totals := range t.totals { + out[username] = totals + } + + for _, s := range t.sessions { + totals := out[s.username] + totals.UploadBytes += s.uploadBytes.Load() + totals.DownloadBytes += s.downloadBytes.Load() + lastSeenUnix := s.lastSeenUnix.Load() + lastSeenAt := time.Unix(0, lastSeenUnix) + if lastSeenUnix <= 0 { + lastSeenAt = s.started + } + if totals.LastSeenAt.IsZero() || lastSeenAt.After(totals.LastSeenAt) { + totals.LastSeenAt = lastSeenAt + totals.LastClientIP = s.clientIP + } + out[s.username] = totals + } + + elapsed := now.Sub(t.lastPoll) + if elapsed <= 0 { + elapsed = time.Second + } + + nextLastRate := make(map[string]UserTotals, len(out)) + for username, totals := range out { + prev := t.lastRate[username] + + if totals.UploadBytes >= prev.UploadBytes { + totals.UploadBPS = uint64(float64(totals.UploadBytes-prev.UploadBytes) / elapsed.Seconds()) + } + if totals.DownloadBytes >= prev.DownloadBytes { + totals.DownloadBPS = uint64(float64(totals.DownloadBytes-prev.DownloadBytes) / elapsed.Seconds()) + } + + out[username] = totals + nextLastRate[username] = UserTotals{ + UploadBytes: totals.UploadBytes, + DownloadBytes: totals.DownloadBytes, + } + } + + t.lastRate = nextLastRate + t.lastPoll = now + + return out +} + +func (t *Tracker) Snapshot() []Snapshot { + if t == nil { + return nil + } + + now := time.Now() + t.mu.Lock() + out := make([]Snapshot, 0, len(t.sessions)) + for id, s := range t.sessions { + upload := s.uploadBytes.Load() + download := s.downloadBytes.Load() + + elapsed := now.Sub(s.lastSampleAt) + if elapsed <= 0 { + elapsed = time.Second + } + + uploadDelta := upload - s.lastUploadSample + downloadDelta := download - s.lastDownloadSample + + uploadBPS := uint64(float64(uploadDelta) / elapsed.Seconds()) + downloadBPS := uint64(float64(downloadDelta) / elapsed.Seconds()) + + s.lastSampleAt = now + s.lastUploadSample = upload + s.lastDownloadSample = download + + out = append(out, Snapshot{ + ID: id, + Username: s.username, + ClientIP: s.clientIP, + UploadBytes: upload, + DownloadBytes: download, + UploadBPS: uploadBPS, + DownloadBPS: downloadBPS, + StartedAt: s.started, + }) + } + t.mu.Unlock() + + sort.Slice(out, func(i, j int) bool { + if out[i].Username == out[j].Username { + return out[i].ClientIP < out[j].ClientIP + } + return out[i].Username < out[j].Username + }) + + return out +} diff --git a/pkg/traffic/tracker_test.go b/pkg/traffic/tracker_test.go new file mode 100644 index 0000000..b3f6e2e --- /dev/null +++ b/pkg/traffic/tracker_test.go @@ -0,0 +1,91 @@ +package traffic + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTracker_SessionLifecycleAndSnapshot(t *testing.T) { + tracker := NewTracker() + s := tracker.Start("alice", "10.0.0.2") + s.AddUpload(100) + s.AddDownload(200) + + first := tracker.Snapshot() + assert.Len(t, first, 1) + assert.Equal(t, "alice", first[0].Username) + assert.Equal(t, "10.0.0.2", first[0].ClientIP) + assert.Equal(t, uint64(100), first[0].UploadBytes) + assert.Equal(t, uint64(200), first[0].DownloadBytes) + + time.Sleep(20 * time.Millisecond) + s.AddUpload(50) + s.AddDownload(30) + second := tracker.Snapshot() + assert.Len(t, second, 1) + assert.Equal(t, uint64(150), second[0].UploadBytes) + assert.Equal(t, uint64(230), second[0].DownloadBytes) + assert.GreaterOrEqual(t, second[0].UploadBPS, uint64(1)) + assert.GreaterOrEqual(t, second[0].DownloadBPS, uint64(1)) + + s.Close() + assert.Empty(t, tracker.Snapshot()) +} + +func TestTracker_DefaultLabels(t *testing.T) { + tracker := NewTracker() + _ = tracker.Start("", "") + snaps := tracker.Snapshot() + assert.Len(t, snaps, 1) + assert.Equal(t, "anonymous", snaps[0].Username) + assert.Equal(t, "unknown", snaps[0].ClientIP) +} + +func TestTracker_TotalsByUser_PersistsAfterSessionClose(t *testing.T) { + tracker := NewTracker() + s1 := tracker.Start("alice", "10.0.0.2") + s1.AddUpload(120) + s1.AddDownload(300) + + totals := tracker.TotalsByUser() + assert.Equal(t, uint64(120), totals["alice"].UploadBytes) + assert.Equal(t, uint64(300), totals["alice"].DownloadBytes) + assert.False(t, totals["alice"].LastSeenAt.IsZero()) + assert.Equal(t, "10.0.0.2", totals["alice"].LastClientIP) + + s1.Close() + totals = tracker.TotalsByUser() + assert.Equal(t, uint64(120), totals["alice"].UploadBytes) + assert.Equal(t, uint64(300), totals["alice"].DownloadBytes) + assert.False(t, totals["alice"].LastSeenAt.IsZero()) + + s2 := tracker.Start("alice", "10.0.0.3") + s2.AddUpload(30) + s2.AddDownload(50) + totals = tracker.TotalsByUser() + assert.Equal(t, uint64(150), totals["alice"].UploadBytes) + assert.Equal(t, uint64(350), totals["alice"].DownloadBytes) + assert.Equal(t, "10.0.0.3", totals["alice"].LastClientIP) +} + +func TestTracker_TotalsByUser_ComputesRateAcrossPolls(t *testing.T) { + tracker := NewTracker() + + // Prime the rate baseline. + _ = tracker.TotalsByUser() + + s := tracker.Start("alice", "10.0.0.2") + s.AddUpload(1024) + s.AddDownload(2048) + s.Close() + + time.Sleep(20 * time.Millisecond) + totals := tracker.TotalsByUser() + + assert.Equal(t, uint64(1024), totals["alice"].UploadBytes) + assert.Equal(t, uint64(2048), totals["alice"].DownloadBytes) + assert.Greater(t, totals["alice"].UploadBPS, uint64(0)) + assert.Greater(t, totals["alice"].DownloadBPS, uint64(0)) +} diff --git a/systemd/nanoproxy.service b/systemd/nanoproxy.service index 6fabb9d..1f6279a 100644 --- a/systemd/nanoproxy.service +++ b/systemd/nanoproxy.service @@ -3,7 +3,13 @@ Description=NanoProxy is a simple reverse proxy written in Go After=network.target [Service] -EnvironmentFile=/etc/nanoproxy/nanoproxy +Environment=ADDR=:1080 +Environment=ADDR_HTTP=:8080 +Environment=ADDR_ADMIN=:9090 +Environment=NETWORK=tcp +Environment=TZ=Local +Environment=USER_STORE_PATH=/var/lib/nanoproxy/data.db +ExecStartPre=/bin/mkdir -p /var/lib/nanoproxy ExecStart=/usr/bin/nanoproxy WorkingDirectory=/usr/bin Restart=always