-
Notifications
You must be signed in to change notification settings - Fork 8
Update newpool to add IAM auth beforeConnect #32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
401e9a7
02903c3
e3bffd3
ef65e44
b66b204
3e01779
7d3b4dd
7c43cf9
96678e6
b473e89
b34cd9c
837bc43
ad81ebb
5d44793
e3fcf9e
d8fedab
c00ee81
eb97c09
7265ebc
7129bea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| package drivers | ||
|
|
||
| import ( | ||
| "context" | ||
| "fmt" | ||
| "log/slog" | ||
| "net" | ||
| "net/url" | ||
| "strings" | ||
|
|
||
| awsConfig "github.com/aws/aws-sdk-go-v2/config" | ||
| "github.com/aws/aws-sdk-go-v2/feature/rds/auth" | ||
| ) | ||
|
|
||
| type DatabaseConfiguration struct { | ||
| Connection string `json:"connection"` | ||
| Address string `json:"addr"` | ||
| Database string `json:"database"` | ||
| Username string `json:"username"` | ||
| Secret string `json:"secret"` | ||
| MaxConcurrentSessions int `json:"max_concurrent_sessions"` | ||
| EnableRDSIAMAuth bool `json:"enable_rds_iam_auth"` | ||
| Endpoint string | ||
| } | ||
|
|
||
| func (s DatabaseConfiguration) defaultPostgreSQLConnectionString() string { | ||
| if s.Connection != "" { | ||
| return s.Connection | ||
| } | ||
|
|
||
| return fmt.Sprintf("postgresql://%s:%s@%s/%s", s.Username, url.QueryEscape(s.Secret), s.Address, s.Database) | ||
| } | ||
|
|
||
| // Looks up CNAME record to get an RDS instance identifier and builds an IAM auth token, returning a connection string | ||
| func (s DatabaseConfiguration) RDSIAMAuthConnectionString() string { | ||
| if cfg, err := awsConfig.LoadDefaultConfig(context.TODO()); err != nil { | ||
| slog.Error("AWS Config Loading Error", slog.String("err", err.Error())) | ||
| } else { | ||
| // Must use instance endpoint with IAM auth | ||
| var endpoint string | ||
| if s.Endpoint != "" { | ||
| endpoint = s.Endpoint | ||
| } else { | ||
| endpoint = s.LookupEndpoint() | ||
| } | ||
|
|
||
| slog.Info("Requesting RDS IAM Auth Token") | ||
| if authenticationToken, err := auth.BuildAuthToken(context.TODO(), endpoint, cfg.Region, s.Username, cfg.Credentials); err != nil { | ||
| slog.Error("RDS IAM Auth Token Request Error", slog.String("err", err.Error())) | ||
| } else { | ||
| slog.Info("RDS IAM Auth Token Created") | ||
| return fmt.Sprintf("postgresql://%s:%s@%s/%s", s.Username, url.QueryEscape(authenticationToken), endpoint, s.Database) | ||
| } | ||
| } | ||
|
|
||
| slog.Warn("Failed to create IAM auth token. Falling back to default Postgres connection string") | ||
| return s.defaultPostgreSQLConnectionString() | ||
| } | ||
|
|
||
| func (s DatabaseConfiguration) PostgreSQLConnectionString() string { | ||
| if s.EnableRDSIAMAuth { | ||
| return s.RDSIAMAuthConnectionString() | ||
| } | ||
|
|
||
| return s.defaultPostgreSQLConnectionString() | ||
| } | ||
|
|
||
| func (s DatabaseConfiguration) Neo4jConnectionString() string { | ||
| if s.Connection == "" { | ||
| return fmt.Sprintf("neo4j://%s:%s@%s/%s", s.Username, s.Secret, s.Address, s.Database) | ||
| } | ||
|
|
||
| return s.Connection | ||
| } | ||
|
|
||
| func (s DatabaseConfiguration) LookupEndpoint() string { | ||
| host := s.Address | ||
| if hostCName, err := net.LookupCNAME(s.Address); err != nil { | ||
| slog.Warn("Error looking up CNAME for DB host. Using original address.", slog.String("err", err.Error())) | ||
| } else { | ||
| host = hostCName | ||
| } | ||
|
|
||
| // Instance endpoint always returns with a trailing '.' | ||
| return strings.TrimSuffix(host, ".") + ":5432" | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ import ( | |
| "github.com/jackc/pgx/v5/pgxpool" | ||
| "github.com/specterops/dawgs" | ||
| "github.com/specterops/dawgs/cypher/models/pgsql" | ||
| "github.com/specterops/dawgs/drivers" | ||
| "github.com/specterops/dawgs/graph" | ||
| ) | ||
|
|
||
|
|
@@ -50,15 +51,11 @@ func afterPooledConnectionRelease(conn *pgx.Conn) bool { | |
| return true | ||
| } | ||
|
|
||
| func NewPool(connectionString string) (*pgxpool.Pool, error) { | ||
| if connectionString == "" { | ||
| return nil, fmt.Errorf("graph connection requires a connection url to be set") | ||
| } | ||
|
|
||
| func NewPool(cfg drivers.DatabaseConfiguration) (*pgxpool.Pool, error) { | ||
| poolCtx, done := context.WithTimeout(context.Background(), poolInitConnectionTimeout) | ||
| defer done() | ||
|
|
||
| poolCfg, err := pgxpool.ParseConfig(connectionString) | ||
| poolCfg, err := pgxpool.ParseConfig(cfg.PostgreSQLConnectionString()) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
@@ -73,6 +70,20 @@ func NewPool(connectionString string) (*pgxpool.Pool, error) { | |
| poolCfg.AfterConnect = afterPooledConnectionEstablished | ||
| poolCfg.AfterRelease = afterPooledConnectionRelease | ||
|
|
||
| if cfg.EnableRDSIAMAuth { | ||
| // Only enable the BeforeConnect handler if RDS IAM Auth is enabled | ||
| cfg.Endpoint = cfg.LookupEndpoint() | ||
| poolCfg.BeforeConnect = func(ctx context.Context, connCfg *pgx.ConnConfig) error { | ||
| if newPoolCfg, err := pgxpool.ParseConfig(cfg.PostgreSQLConnectionString()); err != nil { | ||
| return err | ||
| } else { | ||
| connCfg.Password = newPoolCfg.ConnConfig.Password | ||
| } | ||
|
|
||
| return nil | ||
| } | ||
|
Comment on lines
+73
to
+84
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: # Find the relevant files
find . -name "pg.go" -path "*/drivers/*" | head -20Repository: SpecterOps/DAWGS Length of output: 78 🏁 Script executed: # Find config.go in drivers
find . -name "config.go" -path "*/drivers/*" | head -20Repository: SpecterOps/DAWGS Length of output: 79 🏁 Script executed: # Check overall structure
ls -la drivers/ 2>/dev/null || echo "drivers not at root"Repository: SpecterOps/DAWGS Length of output: 402 🏁 Script executed: cat -n drivers/pg/pg.goRepository: SpecterOps/DAWGS Length of output: 3664 🏁 Script executed: cat -n drivers/config.goRepository: SpecterOps/DAWGS Length of output: 3021 Cache static connection parameters during pool setup; refresh only auth tokens in
🤖 Prompt for AI Agents
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added LookupEndpoint() to run before newpool creation and if endpoint field is empty. This should reduce the number of lookups, while still allowing IAM PostgreSQLConnectionString() calls outside of pool context. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: Length of output: 3340
CNAME lookup — ✅ resolved
The fix strikes a reasonable balance: the expensive/durable operation (CNAME lookup) is cached, while the security-sensitive operation (credential + token refresh) stays live per-connection. ✏️ Learnings added
🧠 Learnings used |
||
| } | ||
|
|
||
| pool, err := pgxpool.NewWithConfig(poolCtx, poolCfg) | ||
| if err != nil { | ||
| return nil, err | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.