Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion authentication/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strings"
"time"

"github.com/agentuity/go-common/crypto"
"github.com/xhit/go-str2duration/v2"
)

Expand Down Expand Up @@ -82,7 +83,40 @@ func NewBearerToken(sharedSecret string, opts ...TokenOpt) (string, error) {
return nonce + "." + tok2, nil
}

// NewBearerTokenV2 generates a v2 bearer token using HKDF-derived key.
// The token format is "v2.bearer-token.<nonce>.<hash>" (non-expiring)
// or "v2.bearer-token.<duration>.<timestamp>.<hash>" (expiring).
func NewBearerTokenV2(sharedSecret string, opts ...TokenOpt) (string, error) {
derivedKey, err := crypto.DeriveKey([]byte(sharedSecret), crypto.ContextBearerToken)
if err != nil {
return "", fmt.Errorf("failed to derive key: %w", err)
}
// Generate the inner token using the derived key (same algorithm as v1)
innerToken, err := NewBearerToken(string(derivedKey), opts...)
if err != nil {
return "", err
}
return crypto.FormatV2Token(crypto.ContextBearerToken, innerToken), nil
}

func ValidateToken(sharedSecret string, auth string) error {
version, context, payload := crypto.DetectTokenVersion(auth)
if version == "v2" {
if context != crypto.ContextBearerToken {
return ErrInvalidToken
}
derivedKey, err := crypto.DeriveKey([]byte(sharedSecret), crypto.ContextBearerToken)
if err != nil {
return ErrInvalidToken
}
return validateTokenInner(string(derivedKey), payload)
}
// v1 legacy: use raw shared secret
return validateTokenInner(sharedSecret, auth)
}

// validateTokenInner contains the core token validation logic shared by v1 and v2.
func validateTokenInner(key string, auth string) error {
if len(auth) < 32 {
return ErrInvalidToken
}
Expand Down Expand Up @@ -127,7 +161,7 @@ func ValidateToken(sharedSecret string, auth string) error {

// see if we can hash the token with our shared secret to get the same value as the second token
hash := sha256.New()
hash.Write([]byte(sharedSecret + "." + token))
hash.Write([]byte(key + "." + token))
secret := hash.Sum(nil)

// if the two values are not the same, return an error
Expand Down
81 changes: 81 additions & 0 deletions authentication/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,84 @@ func TestWithExpirationValidRanges(t *testing.T) {
})
}
}

// V2 Bearer Token Tests

func TestNewBearerTokenV2(t *testing.T) {
sharedSecret := "test-secret"

token, err := NewBearerTokenV2(sharedSecret)
assert.NoError(t, err)
assert.NotEmpty(t, token)

// v2 token should validate
err = ValidateToken(sharedSecret, token)
assert.NoError(t, err)
}

func TestNewBearerTokenV2WithExpiration(t *testing.T) {
sharedSecret := "test-secret"
expiration := time.Now().Add(2 * time.Hour)

token, err := NewBearerTokenV2(sharedSecret, WithExpiration(expiration))
assert.NoError(t, err)
assert.NotEmpty(t, token)

// v2 expiring token should validate
err = ValidateToken(sharedSecret, token)
assert.NoError(t, err)
}

func TestV2TokenPrefix(t *testing.T) {
sharedSecret := "test-secret"

token, err := NewBearerTokenV2(sharedSecret)
assert.NoError(t, err)
assert.True(t, strings.HasPrefix(token, "v2.bearer-token."), "v2 token should start with 'v2.bearer-token.' but got: %s", token)
}

func TestV1TokenStillValidates(t *testing.T) {
sharedSecret := "test-secret"

// Generate a v1 token
token, err := NewBearerToken(sharedSecret)
assert.NoError(t, err)
assert.NotEmpty(t, token)

// v1 token should still validate through the updated ValidateToken
err = ValidateToken(sharedSecret, token)
assert.NoError(t, err)
}

func TestV2TokenNotValidWithWrongSecret(t *testing.T) {
token, err := NewBearerTokenV2("correct-secret")
assert.NoError(t, err)

err = ValidateToken("wrong-secret", token)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrInvalidToken)
}

func TestCrossVersionIsolation(t *testing.T) {
sharedSecret := "test-secret"

// Generate v1 and v2 tokens
v1Token, err := NewBearerToken(sharedSecret)
assert.NoError(t, err)

v2Token, err := NewBearerTokenV2(sharedSecret)
assert.NoError(t, err)

// Extract the inner payload of the v2 token (strip "v2.bearer-token." prefix)
v2Payload := strings.TrimPrefix(v2Token, "v2.bearer-token.")
assert.NotEqual(t, v2Token, v2Payload, "v2 token should have the prefix")

// v2 payload should NOT validate as v1 (because it was hashed with derived key, not raw secret)
err = validateTokenInner(sharedSecret, v2Payload)
assert.Error(t, err, "v2 token payload should not validate as v1 with raw shared secret")

// v1 token wrapped as v2 should NOT validate (because ValidateToken will use derived key)
fakeV2 := "v2.bearer-token." + v1Token
err = ValidateToken(sharedSecret, fakeV2)
assert.Error(t, err, "v1 token wrapped as v2 should not validate")
}
64 changes: 64 additions & 0 deletions crypto/derive.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package crypto

import (
"crypto/sha256"
"fmt"
"io"
"strings"

"golang.org/x/crypto/hkdf"
)

const (
// KeyDerivationVersion is the current version of key derivation.
KeyDerivationVersion = "v2"

// keyDerivationSalt includes the version to ensure different versions produce different keys.
keyDerivationSalt = "agentuity-key-derivation-" + KeyDerivationVersion

// Context constants for different key derivation purposes.
// Context strings MUST NOT contain '.' characters, as this would break
// DetectTokenVersion's parsing which uses SplitN(token, ".", 3) to separate
// the version prefix, context, and payload.
ContextBearerToken = "bearer-token"
ContextStickySession = "sticky-session"
ContextPostgresInternal = "postgres-internal"
ContextGravityJWT = "gravity-jwt"
ContextS3Webhook = "s3-webhook"
)

// DeriveKey derives a purpose-specific 32-byte key from a master secret using HKDF-SHA256.
// The context parameter provides domain separation so the same master secret produces
// different keys for different purposes.
func DeriveKey(masterSecret []byte, context string) ([]byte, error) {
if len(masterSecret) == 0 {
return nil, fmt.Errorf("master secret cannot be empty")
}
if context == "" {
return nil, fmt.Errorf("context cannot be empty")
}
reader := hkdf.New(sha256.New, masterSecret, []byte(keyDerivationSalt), []byte(context))
key := make([]byte, 32)
if _, err := io.ReadFull(reader, key); err != nil {
return nil, fmt.Errorf("failed to derive key for context %q: %w", context, err)
}
return key, nil
}

// DetectTokenVersion inspects a token string and returns its version.
// v2 tokens have the format "v2.<context>.<payload>".
// All other tokens are assumed to be v1 (legacy format).
func DetectTokenVersion(token string) (version string, context string, payload string) {
if strings.HasPrefix(token, "v2.") {
parts := strings.SplitN(token, ".", 3)
if len(parts) == 3 {
return parts[0], parts[1], parts[2]
}
}
return "v1", "", token
}

// FormatV2Token creates a v2 prefixed token string: "v2.<context>.<payload>"
func FormatV2Token(context string, payload string) string {
return fmt.Sprintf("v2.%s.%s", context, payload)
}
122 changes: 122 additions & 0 deletions crypto/derive_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package crypto

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestDeriveKeyDeterministic(t *testing.T) {
secret := []byte("my-master-secret")
key1, err := DeriveKey(secret, ContextBearerToken)
require.NoError(t, err)

key2, err := DeriveKey(secret, ContextBearerToken)
require.NoError(t, err)

assert.Equal(t, key1, key2, "DeriveKey should produce consistent output for same inputs")
}

func TestDeriveKeyDomainSeparation(t *testing.T) {
secret := []byte("my-master-secret")

key1, err := DeriveKey(secret, ContextBearerToken)
require.NoError(t, err)

key2, err := DeriveKey(secret, ContextStickySession)
require.NoError(t, err)

assert.NotEqual(t, key1, key2, "DeriveKey should produce different output for different contexts")
}

func TestDeriveKeyDifferentSecrets(t *testing.T) {
key1, err := DeriveKey([]byte("secret-one"), ContextBearerToken)
require.NoError(t, err)

key2, err := DeriveKey([]byte("secret-two"), ContextBearerToken)
require.NoError(t, err)

assert.NotEqual(t, key1, key2, "DeriveKey should produce different output for different master secrets")
}

func TestDeriveKeyEmptyMasterSecret(t *testing.T) {
key, err := DeriveKey([]byte{}, ContextBearerToken)
assert.Error(t, err)
assert.Nil(t, key)
assert.Contains(t, err.Error(), "master secret cannot be empty")
}

func TestDeriveKeyNilMasterSecret(t *testing.T) {
key, err := DeriveKey(nil, ContextBearerToken)
assert.Error(t, err)
assert.Nil(t, key)
assert.Contains(t, err.Error(), "master secret cannot be empty")
}

func TestDeriveKeyEmptyContext(t *testing.T) {
key, err := DeriveKey([]byte("secret"), "")
assert.Error(t, err)
assert.Nil(t, key)
assert.Contains(t, err.Error(), "context cannot be empty")
}

func TestDeriveKeyOutputLength(t *testing.T) {
key, err := DeriveKey([]byte("my-master-secret"), ContextBearerToken)
require.NoError(t, err)
assert.Len(t, key, 32, "DeriveKey output should be exactly 32 bytes")
}

func TestDetectTokenVersionV2(t *testing.T) {
version, context, payload := DetectTokenVersion("v2.bearer-token.somePayload")
assert.Equal(t, "v2", version)
assert.Equal(t, "bearer-token", context)
assert.Equal(t, "somePayload", payload)
}

func TestDetectTokenVersionV1(t *testing.T) {
version, context, payload := DetectTokenVersion("oldStyleToken.hash")
assert.Equal(t, "v1", version)
assert.Equal(t, "", context)
assert.Equal(t, "oldStyleToken.hash", payload)
}

func TestDetectTokenVersionMalformedV2(t *testing.T) {
// "v2.incomplete" has the v2 prefix but only 2 parts, not 3
version, context, payload := DetectTokenVersion("v2.incomplete")
assert.Equal(t, "v1", version)
assert.Equal(t, "", context)
assert.Equal(t, "v2.incomplete", payload)
}

func TestFormatV2Token(t *testing.T) {
result := FormatV2Token("bearer-token", "payload123")
assert.Equal(t, "v2.bearer-token.payload123", result)
}

func TestDeriveKeyAllContexts(t *testing.T) {
secret := []byte("test-secret")
contexts := []string{
ContextBearerToken,
ContextStickySession,
ContextPostgresInternal,
ContextGravityJWT,
ContextS3Webhook,
}

keys := make(map[string][]byte)
for _, ctx := range contexts {
key, err := DeriveKey(secret, ctx)
require.NoError(t, err)
assert.Len(t, key, 32)
keys[ctx] = key
}

// Verify all keys are unique
for i, ctx1 := range contexts {
for _, ctx2 := range contexts[i+1:] {
assert.NotEqual(t, keys[ctx1], keys[ctx2],
"keys for %q and %q should be different", ctx1, ctx2)
}
}
}
15 changes: 8 additions & 7 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ require (
go.opentelemetry.io/otel/sdk/log v0.14.0
go.opentelemetry.io/otel/trace v1.40.0
go.uber.org/zap v1.27.0
golang.org/x/net v0.46.0
golang.org/x/sync v0.17.0
golang.org/x/term v0.36.0
golang.org/x/crypto v0.49.0
golang.org/x/net v0.51.0
golang.org/x/sync v0.20.0
golang.org/x/term v0.41.0
google.golang.org/grpc v1.75.0
google.golang.org/protobuf v1.36.8
gopkg.in/yaml.v3 v3.0.1
Expand Down Expand Up @@ -103,10 +104,10 @@ require (
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.29.0 // indirect
golang.org/x/sys v0.40.0 // indirect
golang.org/x/text v0.30.0 // indirect
golang.org/x/tools v0.38.0 // indirect
golang.org/x/mod v0.33.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0 // indirect
golang.org/x/tools v0.42.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
Expand Down
Loading
Loading