Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client/cmd/testutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion client/internal/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1671,7 +1671,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, "", err
}
Expand Down
2 changes: 1 addition & 1 deletion client/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, "", err
}
Expand Down
2 changes: 1 addition & 1 deletion management/internals/server/boot.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
}

gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider())
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider(), s.SessionStore())
if err != nil {
log.Fatalf("failed to create management server: %v", err)
}
Expand Down
7 changes: 7 additions & 0 deletions management/internals/server/controllers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus"

"github.com/netbirdio/management-integrations/integrations"

"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"

Expand Down Expand Up @@ -66,6 +67,12 @@ func (s *BaseServer) SecretsManager() grpc.SecretsManager {
})
}

func (s *BaseServer) SessionStore() *auth.SessionStore {
return Create(s, func() *auth.SessionStore {
return auth.NewSessionStore(s.CacheStore())
})
}

func (s *BaseServer) AuthManager() auth.Manager {
audiences := s.Config.GetAuthAudiences()
audience := s.Config.HttpConfig.AuthAudience
Expand Down
37 changes: 35 additions & 2 deletions management/internals/shared/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"sync/atomic"
"time"

jwtv5 "github.com/golang-jwt/jwt/v5"
pb "github.com/golang/protobuf/proto" // nolint
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
Expand Down Expand Up @@ -67,6 +68,7 @@ type Server struct {
appMetrics telemetry.AppMetrics
peerLocks sync.Map
authManager auth.Manager
sessionStore *auth.SessionStore

logBlockedPeers bool
blockPeersWithSameConfig bool
Expand Down Expand Up @@ -98,6 +100,7 @@ func NewServer(
integratedPeerValidator integrated_validator.IntegratedValidator,
networkMapController network_map.Controller,
oAuthConfigProvider idp.OAuthConfigProvider,
sessionStore *auth.SessionStore,
) (*Server, error) {
if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams
Expand Down Expand Up @@ -140,6 +143,7 @@ func NewServer(
integratedPeerValidator: integratedPeerValidator,
networkMapController: networkMapController,
oAuthConfigProvider: oAuthConfigProvider,
sessionStore: sessionStore,

loginFilter: newLoginFilter(),

Expand Down Expand Up @@ -535,7 +539,7 @@ func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID st
log.WithContext(ctx).Debugf("peer %s has been disconnected", peer.Key)
}

func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) {
func (s *Server) validateToken(ctx context.Context, peerKey, jwtToken string) (string, error) {
if s.authManager == nil {
return "", status.Errorf(codes.Internal, "missing auth manager")
}
Expand All @@ -545,6 +549,10 @@ func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, er
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
}

if err := s.claimLoginToken(ctx, peerKey, jwtToken, token); err != nil {
return "", err
}

// we need to call this method because if user is new, we will automatically add it to existing or create a new account
accountId, _, err := s.accountManager.GetAccountIDFromUserAuth(ctx, userAuth)
if err != nil {
Expand Down Expand Up @@ -828,6 +836,31 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
return loginResp, nil
}

func (s *Server) claimLoginToken(ctx context.Context, peerKey, jwtToken string, token *jwtv5.Token) error {
if s.sessionStore == nil || token == nil {
return nil
}

exp, err := token.Claims.GetExpirationTime()
if err != nil || exp == nil {
log.WithContext(ctx).Warnf("JWT has no usable exp claim for peer %s", peerKey)
return status.Error(codes.Unauthenticated, "jwt token has no expiration")
}

err = s.sessionStore.RegisterToken(ctx, jwtToken, exp.Time)
if err == nil {
return nil
}

if errors.Is(err, auth.ErrTokenAlreadyUsed) || errors.Is(err, auth.ErrTokenExpired) {
log.WithContext(ctx).Warnf("%v for peer %s", err, peerKey)
return status.Error(codes.Unauthenticated, err.Error())
}

log.WithContext(ctx).Warnf("failed to claim JWT for peer %s: %v", peerKey, err)
return status.Error(codes.Unavailable, "failed to claim jwt token")
}

// processJwtToken validates the existence of a JWT token in the login request, and returns the corresponding user ID if
// the token is valid.
//
Expand All @@ -838,7 +871,7 @@ func (s *Server) processJwtToken(ctx context.Context, loginReq *proto.LoginReque
if loginReq.GetJwtToken() != "" {
var err error
for i := 0; i < 3; i++ {
userID, err = s.validateToken(ctx, loginReq.GetJwtToken())
userID, err = s.validateToken(ctx, peerKey.String(), loginReq.GetJwtToken())
if err == nil {
break
}
Expand Down
61 changes: 61 additions & 0 deletions management/server/auth/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package auth

import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"time"

"github.com/eko/gocache/lib/v4/cache"
"github.com/eko/gocache/lib/v4/store"
)

const (
usedTokenKeyPrefix = "jwt-used:"
usedTokenMarker = "1"
)

var (
ErrTokenAlreadyUsed = errors.New("JWT already used")
ErrTokenExpired = errors.New("JWT expired")
)

type SessionStore struct {
cache *cache.Cache[string]
}

func NewSessionStore(cacheStore store.StoreInterface) *SessionStore {
return &SessionStore{cache: cache.New[string](cacheStore)}
}

// RegisterToken records a JWT until its exp time and rejects reuse.
func (s *SessionStore) RegisterToken(ctx context.Context, token string, expiresAt time.Time) error {
ttl := time.Until(expiresAt)
if ttl <= 0 {
return ErrTokenExpired
}

key := usedTokenKeyPrefix + hashToken(token)
_, err := s.cache.Get(ctx, key)
if err == nil {
return ErrTokenAlreadyUsed
}

var notFound *store.NotFound
if !errors.As(err, &notFound) {
return fmt.Errorf("failed to lookup used token entry: %w", err)
}

if err := s.cache.Set(ctx, key, usedTokenMarker, store.WithExpiration(ttl)); err != nil {
return fmt.Errorf("failed to store used token entry: %w", err)
}

return nil
}
Comment thread
bcmmbaga marked this conversation as resolved.

func hashToken(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}
82 changes: 82 additions & 0 deletions management/server/auth/session_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package auth

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

nbcache "github.com/netbirdio/netbird/management/server/cache"
)

func newTestSessionStore(t *testing.T) *SessionStore {
t.Helper()
cacheStore, err := nbcache.NewStore(context.Background(), time.Hour, time.Hour, 100)
require.NoError(t, err)
return NewSessionStore(cacheStore)
}

func TestSessionStore_FirstRegisterSucceeds(t *testing.T) {
s := newTestSessionStore(t)
ctx := context.Background()

require.NoError(t, s.RegisterToken(ctx, "token", time.Now().Add(time.Hour)))
}

func TestSessionStore_RegisterSameTokenTwiceIsRejected(t *testing.T) {
s := newTestSessionStore(t)
ctx := context.Background()
token := "token"
exp := time.Now().Add(time.Hour)

require.NoError(t, s.RegisterToken(ctx, token, exp))

err := s.RegisterToken(ctx, token, exp)
require.Error(t, err)
assert.ErrorIs(t, err, ErrTokenAlreadyUsed)
}

func TestSessionStore_RegisterDifferentTokensAreIndependent(t *testing.T) {
s := newTestSessionStore(t)
ctx := context.Background()
exp := time.Now().Add(time.Hour)

require.NoError(t, s.RegisterToken(ctx, "tokenA", exp))
require.NoError(t, s.RegisterToken(ctx, "tokenB", exp))
}

func TestSessionStore_RegisterWithPastExpiryIsRejected(t *testing.T) {
s := newTestSessionStore(t)
ctx := context.Background()
token := "token"

err := s.RegisterToken(ctx, token, time.Now().Add(-time.Second))
require.Error(t, err)
assert.ErrorIs(t, err, ErrTokenExpired)
}

func TestSessionStore_EntryEvictsAtTTLAndAllowsReRegistration(t *testing.T) {
s := newTestSessionStore(t)
ctx := context.Background()
token := "token"

require.NoError(t, s.RegisterToken(ctx, token, time.Now().Add(50*time.Millisecond)))

err := s.RegisterToken(ctx, token, time.Now().Add(50*time.Millisecond))
assert.ErrorIs(t, err, ErrTokenAlreadyUsed)

time.Sleep(120 * time.Millisecond)

require.NoError(t, s.RegisterToken(ctx, token, time.Now().Add(time.Hour)))
}

func TestHashToken_StableAndDoesNotLeak(t *testing.T) {
a := hashToken("tokenA")
b := hashToken("tokenB")
assert.Equal(t, a, hashToken("tokenA"), "hash must be deterministic")
assert.NotEqual(t, a, b, "different tokens must hash differently")
assert.Len(t, a, 64, "sha256 hex must be 64 chars")
assert.NotContains(t, a, "tokenA", "raw token must not appear in hash")
}
2 changes: 1 addition & 1 deletion management/server/management_proto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
return nil, nil, "", cleanup, err
}

mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, nil, "", cleanup, err
}
Expand Down
1 change: 1 addition & 0 deletions management/server/management_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ func startServer(
server.MockIntegratedValidator{},
networkMapController,
nil,
nil,
)
if err != nil {
t.Fatalf("failed creating management server: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion shared/management/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
t.Fatal(err)
}
Expand Down
Loading