From 5f52c9531f1623e37581fc9296e0e6a0f5f213ce Mon Sep 17 00:00:00 2001 From: Brian O'Kelley Date: Fri, 27 Mar 2026 07:28:44 +0900 Subject: [PATCH] feat: identity agent Store abstraction with in-memory fallback Extract Store interface from go-redis direct usage. RedisStore wraps go-redis; InMemoryStore provides fallback when Valkey is unavailable. Pipeline batching via StorePipeline interface. Health endpoint. Auto-detect Valkey availability on startup and fall back gracefully. Closes #5 Co-Authored-By: Claude Opus 4.6 (1M context) --- reference/identity-agent/agent.go | 48 ++--- reference/identity-agent/agent_test.go | 103 +++++++--- reference/identity-agent/main.go | 36 +++- reference/identity-agent/store.go | 262 +++++++++++++++++++++++++ 4 files changed, 395 insertions(+), 54 deletions(-) create mode 100644 reference/identity-agent/store.go diff --git a/reference/identity-agent/agent.go b/reference/identity-agent/agent.go index 9f1ee4e..843d32e 100644 --- a/reference/identity-agent/agent.go +++ b/reference/identity-agent/agent.go @@ -6,10 +6,10 @@ import ( "encoding/hex" "fmt" "math" + "strconv" "time" "github.com/adcontextprotocol/adcp-go/tmp" - "github.com/redis/go-redis/v9" ) // FrequencyRule defines a sliding window frequency cap. @@ -33,15 +33,15 @@ type CampaignConfig struct { FrequencyRules []FrequencyRule // Campaign-level caps (all must pass) } -// IdentityAgent evaluates user eligibility using Valkey/Redis. +// IdentityAgent evaluates user eligibility using a Store (Redis or in-memory). type IdentityAgent struct { - rdb *redis.Client + store Store packages map[string]PackageConfig campaigns map[string]CampaignConfig } -// NewIdentityAgent creates an agent with the given Redis client and configs. -func NewIdentityAgent(rdb *redis.Client, packages []PackageConfig, campaigns []CampaignConfig) *IdentityAgent { +// NewIdentityAgent creates an agent with the given store and configs. +func NewIdentityAgent(store Store, packages []PackageConfig, campaigns []CampaignConfig) *IdentityAgent { pkgMap := make(map[string]PackageConfig, len(packages)) for _, p := range packages { pkgMap[p.PackageID] = p @@ -50,7 +50,7 @@ func NewIdentityAgent(rdb *redis.Client, packages []PackageConfig, campaigns []C for _, c := range campaigns { campMap[c.CampaignID] = c } - return &IdentityAgent{rdb: rdb, packages: pkgMap, campaigns: campMap} + return &IdentityAgent{store: store, packages: pkgMap, campaigns: campMap} } // IdentityMatch evaluates a user against all requested packages. @@ -130,8 +130,7 @@ func (a *IdentityAgent) IdentityMatch(ctx context.Context, req *tmp.IdentityMatc } // Expose records that a user was shown an ad for a package. -// Adds a timestamped entry to sorted sets for both package and campaign frequency. -// Uses sorted sets for sliding window frequency capping. +// Uses pipeline to batch Redis commands for efficiency. 
func (a *IdentityAgent) Expose(ctx context.Context, req *tmp.ExposeRequest) (*tmp.ExposeResponse, error) { tokenHash := hashToken(req.UserToken) pkg, ok := a.packages[req.PackageID] @@ -141,14 +140,13 @@ func (a *IdentityAgent) Expose(ctx context.Context, req *tmp.ExposeRequest) (*tm now := time.Now() ts := float64(now.UnixMilli()) - member := fmt.Sprintf("%d:%s", now.UnixNano(), req.PackageID) // Unique per exposure + member := fmt.Sprintf("%d:%s", now.UnixNano(), req.PackageID) - pipe := a.rdb.Pipeline() + pipe := a.store.Pipeline(ctx) // Add to package-level sorted set pkgKey := fmt.Sprintf("freq:pkg:%s:%s", req.PackageID, tokenHash) - pipe.ZAdd(ctx, pkgKey, redis.Z{Score: ts, Member: member}) - // Set TTL to longest window + buffer to auto-cleanup + pipe.ZAdd(ctx, pkgKey, ts, member) if len(pkg.FrequencyRules) > 0 { maxWindow := maxRuleWindow(pkg.FrequencyRules) pipe.Expire(ctx, pkgKey, maxWindow+time.Hour) @@ -163,7 +161,7 @@ func (a *IdentityAgent) Expose(ctx context.Context, req *tmp.ExposeRequest) (*tm var campKey string if campaignID != "" { campKey = fmt.Sprintf("freq:campaign:%s:%s", campaignID, tokenHash) - pipe.ZAdd(ctx, campKey, redis.Z{Score: ts, Member: member}) + pipe.ZAdd(ctx, campKey, ts, member) if camp, ok := a.campaigns[campaignID]; ok && len(camp.FrequencyRules) > 0 { maxWindow := maxRuleWindow(camp.FrequencyRules) pipe.Expire(ctx, campKey, maxWindow+time.Hour) @@ -174,8 +172,7 @@ func (a *IdentityAgent) Expose(ctx context.Context, req *tmp.ExposeRequest) (*tm intentKey := fmt.Sprintf("intent:%s:%s", req.PackageID, tokenHash) pipe.Set(ctx, intentKey, now.Unix(), 7*24*time.Hour) - _, err := pipe.Exec(ctx) - if err != nil { + if err := pipe.Exec(ctx); err != nil { return nil, err } @@ -186,7 +183,7 @@ func (a *IdentityAgent) Expose(ctx context.Context, req *tmp.ExposeRequest) (*tm if camp, ok := a.campaigns[campaignID]; ok && len(camp.FrequencyRules) > 0 { shortestRule := camp.FrequencyRules[0] cutoff := float64(now.Add(-shortestRule.Window).UnixMilli()) - count, _ := a.rdb.ZCount(ctx, campKey, fmt.Sprintf("%f", cutoff), "+inf").Result() + count, _ := a.store.ZCount(ctx, campKey, cutoff, 1e18) resp.CampaignCount = int(count) resp.CampaignRemaining = shortestRule.MaxCount - int(count) if resp.CampaignRemaining < 0 { @@ -200,13 +197,12 @@ func (a *IdentityAgent) Expose(ctx context.Context, req *tmp.ExposeRequest) (*tm // checkFrequencyRules checks all frequency rules against a sorted set. // Returns true (capped) if ANY rule is exceeded. -// Each rule is a sliding window: count entries within [now-window, now]. 
func (a *IdentityAgent) checkFrequencyRules(ctx context.Context, key string, rules []FrequencyRule) (bool, error) { now := time.Now() for _, rule := range rules { cutoff := float64(now.Add(-rule.Window).UnixMilli()) - count, err := a.rdb.ZCount(ctx, key, fmt.Sprintf("%f", cutoff), "+inf").Result() - if err != nil && err != redis.Nil { + count, err := a.store.ZCount(ctx, key, cutoff, 1e18) + if err != nil { return false, err } if int(count) >= rule.MaxCount { @@ -229,7 +225,7 @@ func maxRuleWindow(rules []FrequencyRule) time.Duration { func (a *IdentityAgent) checkAudienceMatch(ctx context.Context, tokenHash string, segments []string) (bool, error) { for _, seg := range segments { key := fmt.Sprintf("audience:%s", seg) - member, err := a.rdb.SIsMember(ctx, key, tokenHash).Result() + member, err := a.store.SIsMember(ctx, key, tokenHash) if err != nil { return false, err } @@ -242,12 +238,16 @@ func (a *IdentityAgent) checkAudienceMatch(ctx context.Context, tokenHash string func (a *IdentityAgent) computeIntentScore(ctx context.Context, tokenHash, packageID string) (float64, error) { key := fmt.Sprintf("intent:%s:%s", packageID, tokenHash) - ts, err := a.rdb.Get(ctx, key).Int64() - if err == redis.Nil { + val, err := a.store.Get(ctx, key) + if err != nil { + return 0, err + } + if val == "" { return 0, nil } + ts, err := strconv.ParseInt(val, 10, 64) if err != nil { - return 0, err + return 0, nil } hoursSince := time.Since(time.Unix(ts, 0)).Hours() score := 1.0 - (hoursSince / 168.0) @@ -263,7 +263,7 @@ func (a *IdentityAgent) LoadAudienceSegment(ctx context.Context, segmentID strin for i, tok := range userTokens { members[i] = hashToken(tok) } - return a.rdb.SAdd(ctx, key, members...).Err() + return a.store.SAdd(ctx, key, members...) } func hashToken(token string) string { diff --git a/reference/identity-agent/agent_test.go b/reference/identity-agent/agent_test.go index 5aeb8ed..c633604 100644 --- a/reference/identity-agent/agent_test.go +++ b/reference/identity-agent/agent_test.go @@ -18,7 +18,8 @@ func setupTest(t *testing.T) (*IdentityAgent, *miniredis.Miniredis) { } rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) - agent := NewIdentityAgent(rdb, + store := NewRedisStore(rdb) + agent := NewIdentityAgent(store, []PackageConfig{ { PackageID: "pkg-display-001", @@ -39,8 +40,8 @@ func setupTest(t *testing.T) (*IdentityAgent, *miniredis.Miniredis) { PackageID: "pkg-multi-rule", CampaignID: "campaign-acme", FrequencyRules: []FrequencyRule{ - {MaxCount: 2, Window: 12 * time.Hour}, // 2 per 12h - {MaxCount: 5, Window: 7 * 24 * time.Hour}, // AND 5 per week + {MaxCount: 2, Window: 12 * time.Hour}, + {MaxCount: 5, Window: 7 * 24 * time.Hour}, }, }, { @@ -86,14 +87,13 @@ func TestExpose_CampaignFrequencyCap(t *testing.T) { defer mr.Close() ctx := context.Background() - agent.LoadAudienceSegment(ctx, "cooking", []string{"user-abc"}) + _ = agent.LoadAudienceSegment(ctx, "cooking", []string{"user-abc"}) - // 5 exposures across two packages in campaign-acme (campaign cap is 5) for i := 0; i < 3; i++ { - agent.Expose(ctx, &tmp.ExposeRequest{UserToken: "user-abc", PackageID: "pkg-display-001"}) + _, _ = agent.Expose(ctx, &tmp.ExposeRequest{UserToken: "user-abc", PackageID: "pkg-display-001"}) } for i := 0; i < 2; i++ { - agent.Expose(ctx, &tmp.ExposeRequest{UserToken: "user-abc", PackageID: "pkg-display-002"}) + _, _ = agent.Expose(ctx, &tmp.ExposeRequest{UserToken: "user-abc", PackageID: "pkg-display-002"}) } resp, err := agent.IdentityMatch(ctx, &tmp.IdentityMatchRequest{ @@ -117,11 +117,10 @@ 
func TestExpose_PackageCappedButCampaignNot(t *testing.T) { defer mr.Close() ctx := context.Background() - agent.LoadAudienceSegment(ctx, "cooking", []string{"user-abc"}) + _ = agent.LoadAudienceSegment(ctx, "cooking", []string{"user-abc"}) - // 3 exposures on pkg-display-001 (package cap=3, campaign cap=5) for i := 0; i < 3; i++ { - agent.Expose(ctx, &tmp.ExposeRequest{UserToken: "user-abc", PackageID: "pkg-display-001"}) + _, _ = agent.Expose(ctx, &tmp.ExposeRequest{UserToken: "user-abc", PackageID: "pkg-display-001"}) } resp, err := agent.IdentityMatch(ctx, &tmp.IdentityMatchRequest{ @@ -151,10 +150,8 @@ func TestMultipleFrequencyRules(t *testing.T) { defer mr.Close() ctx := context.Background() - // pkg-multi-rule: 2 per 12h AND 5 per 7d - // Expose 2 times — should hit the 12h cap - agent.Expose(ctx, &tmp.ExposeRequest{UserToken: "user-abc", PackageID: "pkg-multi-rule"}) - agent.Expose(ctx, &tmp.ExposeRequest{UserToken: "user-abc", PackageID: "pkg-multi-rule"}) + _, _ = agent.Expose(ctx, &tmp.ExposeRequest{UserToken: "user-abc", PackageID: "pkg-multi-rule"}) + _, _ = agent.Expose(ctx, &tmp.ExposeRequest{UserToken: "user-abc", PackageID: "pkg-multi-rule"}) resp, err := agent.IdentityMatch(ctx, &tmp.IdentityMatchRequest{ RequestID: "id-test-multi", @@ -175,14 +172,12 @@ func TestSlidingWindow_OldExposuresExpire(t *testing.T) { defer mr.Close() ctx := context.Background() - agent.LoadAudienceSegment(ctx, "cooking", []string{"user-abc"}) + _ = agent.LoadAudienceSegment(ctx, "cooking", []string{"user-abc"}) - // Expose 3 times (hits package cap of 3 per 24h) for i := 0; i < 3; i++ { - agent.Expose(ctx, &tmp.ExposeRequest{UserToken: "user-abc", PackageID: "pkg-display-001"}) + _, _ = agent.Expose(ctx, &tmp.ExposeRequest{UserToken: "user-abc", PackageID: "pkg-display-001"}) } - // Should be capped now resp, _ := agent.IdentityMatch(ctx, &tmp.IdentityMatchRequest{ RequestID: "id-before", UserToken: "user-abc", PackageIDs: []string{"pkg-display-001"}, }) @@ -190,11 +185,8 @@ func TestSlidingWindow_OldExposuresExpire(t *testing.T) { t.Error("should be capped (3/3 in 24h)") } - // Fast-forward miniredis by 25 hours — exposures fall outside the 24h window mr.FastForward(25 * time.Hour) - // The sorted set entries still exist but their timestamps are now >24h old. - // ZCOUNT with the sliding window cutoff should return 0. 
resp, _ = agent.IdentityMatch(ctx, &tmp.IdentityMatchRequest{ RequestID: "id-after", UserToken: "user-abc", PackageIDs: []string{"pkg-display-001"}, }) @@ -208,8 +200,8 @@ func TestExpose_IntentScoreUpdated(t *testing.T) { defer mr.Close() ctx := context.Background() - agent.LoadAudienceSegment(ctx, "cooking", []string{"user-abc"}) - agent.Expose(ctx, &tmp.ExposeRequest{UserToken: "user-abc", PackageID: "pkg-display-001"}) + _ = agent.LoadAudienceSegment(ctx, "cooking", []string{"user-abc"}) + _, _ = agent.Expose(ctx, &tmp.ExposeRequest{UserToken: "user-abc", PackageID: "pkg-display-001"}) resp, err := agent.IdentityMatch(ctx, &tmp.IdentityMatchRequest{ RequestID: "id-intent", UserToken: "user-abc", PackageIDs: []string{"pkg-display-001"}, @@ -264,3 +256,68 @@ func TestUnknownPackage(t *testing.T) { t.Error("unknown package should not be eligible") } } + +// --- In-Memory Store Tests --- + +func TestInMemoryStore_FullFlow(t *testing.T) { + store := NewInMemoryStore() + agent := NewIdentityAgent(store, + []PackageConfig{ + {PackageID: "pkg-1", CampaignID: "camp-1", FrequencyRules: []FrequencyRule{{MaxCount: 2, Window: time.Hour}}}, + }, + []CampaignConfig{ + {CampaignID: "camp-1", FrequencyRules: []FrequencyRule{{MaxCount: 3, Window: 24 * time.Hour}}}, + }, + ) + ctx := context.Background() + + // Two exposures should work + _, err := agent.Expose(ctx, &tmp.ExposeRequest{UserToken: "user-1", PackageID: "pkg-1"}) + if err != nil { + t.Fatal(err) + } + _, err = agent.Expose(ctx, &tmp.ExposeRequest{UserToken: "user-1", PackageID: "pkg-1"}) + if err != nil { + t.Fatal(err) + } + + // Should now be capped + resp, err := agent.IdentityMatch(ctx, &tmp.IdentityMatchRequest{ + RequestID: "test", UserToken: "user-1", PackageIDs: []string{"pkg-1"}, + }) + if err != nil { + t.Fatal(err) + } + if resp.Eligibility[0].Eligible { + t.Error("should be capped after 2 exposures (in-memory store)") + } +} + +func TestInMemoryStore_AudienceSegments(t *testing.T) { + store := NewInMemoryStore() + agent := NewIdentityAgent(store, + []PackageConfig{ + {PackageID: "pkg-1", TargetSegments: []string{"vip"}}, + }, + nil, + ) + ctx := context.Background() + + // Not in segment + resp, _ := agent.IdentityMatch(ctx, &tmp.IdentityMatchRequest{ + RequestID: "t1", UserToken: "user-1", PackageIDs: []string{"pkg-1"}, + }) + if resp.Eligibility[0].Eligible { + t.Error("should not be eligible (not in segment)") + } + + // Load segment + _ = agent.LoadAudienceSegment(ctx, "vip", []string{"user-1"}) + + resp, _ = agent.IdentityMatch(ctx, &tmp.IdentityMatchRequest{ + RequestID: "t2", UserToken: "user-1", PackageIDs: []string{"pkg-1"}, + }) + if !resp.Eligibility[0].Eligible { + t.Error("should be eligible after segment load") + } +} diff --git a/reference/identity-agent/main.go b/reference/identity-agent/main.go index b40d7e0..6570ae2 100644 --- a/reference/identity-agent/main.go +++ b/reference/identity-agent/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "flag" "log" @@ -18,7 +19,19 @@ func main() { rdb := redis.NewClient(&redis.Options{Addr: *redisAddr}) - agent := NewIdentityAgent(rdb, + // Try to connect to Valkey; fall back to in-memory if unavailable + var store Store + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := rdb.Ping(ctx).Err(); err != nil { + log.Printf("Valkey unavailable (%v), using in-memory store", err) + store = NewInMemoryStore() + } else { + log.Printf("Connected to Valkey at %s", *redisAddr) + store = NewRedisStore(rdb) + } + + 
agent := NewIdentityAgent(store, []PackageConfig{ {PackageID: "pkg-display-0041", CampaignID: "campaign-acme-q1", FrequencyRules: []FrequencyRule{{MaxCount: 5, Window: 24 * time.Hour}}, TargetSegments: []string{"cooking_enthusiast", "home_improvement"}}, {PackageID: "pkg-display-0042", CampaignID: "campaign-acme-q1", FrequencyRules: []FrequencyRule{{MaxCount: 3, Window: 12 * time.Hour}}}, @@ -36,34 +49,43 @@ func main() { var req tmp.IdentityMatchRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(tmp.ErrorResponse{Code: tmp.ErrorCodeInvalidRequest, Message: err.Error()}) + _ = json.NewEncoder(w).Encode(tmp.ErrorResponse{Code: tmp.ErrorCodeInvalidRequest, Message: err.Error()}) return } resp, err := agent.IdentityMatch(r.Context(), &req) if err != nil { w.WriteHeader(http.StatusInternalServerError) - json.NewEncoder(w).Encode(tmp.ErrorResponse{RequestID: req.RequestID, Code: tmp.ErrorCodeInternalError, Message: err.Error()}) + _ = json.NewEncoder(w).Encode(tmp.ErrorResponse{RequestID: req.RequestID, Code: tmp.ErrorCodeInternalError, Message: err.Error()}) return } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) + _ = json.NewEncoder(w).Encode(resp) }) mux.HandleFunc("POST /tmp/expose", func(w http.ResponseWriter, r *http.Request) { var req tmp.ExposeRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(tmp.ErrorResponse{Code: tmp.ErrorCodeInvalidRequest, Message: err.Error()}) + _ = json.NewEncoder(w).Encode(tmp.ErrorResponse{Code: tmp.ErrorCodeInvalidRequest, Message: err.Error()}) return } resp, err := agent.Expose(r.Context(), &req) if err != nil { w.WriteHeader(http.StatusInternalServerError) - json.NewEncoder(w).Encode(tmp.ErrorResponse{Code: tmp.ErrorCodeInternalError, Message: err.Error()}) + _ = json.NewEncoder(w).Encode(tmp.ErrorResponse{Code: tmp.ErrorCodeInternalError, Message: err.Error()}) return } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) + _ = json.NewEncoder(w).Encode(resp) + }) + + mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) { + if err := store.Ping(r.Context()); err != nil { + w.WriteHeader(http.StatusServiceUnavailable) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "degraded", "store": "in-memory"}) + return + } + _ = json.NewEncoder(w).Encode(map[string]string{"status": "ok", "store": "valkey"}) }) log.Printf("Identity Agent listening on %s, Valkey at %s", *addr, *redisAddr) diff --git a/reference/identity-agent/store.go b/reference/identity-agent/store.go new file mode 100644 index 0000000..53e96aa --- /dev/null +++ b/reference/identity-agent/store.go @@ -0,0 +1,262 @@ +package main + +import ( + "context" + "fmt" + "sort" + "sync" + "time" + + "github.com/redis/go-redis/v9" +) + +// Store abstracts the sorted-set and set operations needed by the identity agent. +// RedisStore wraps go-redis; InMemoryStore provides a fallback when Valkey is down. 
+type Store interface { + ZAdd(ctx context.Context, key string, score float64, member string) error + ZCount(ctx context.Context, key string, min, max float64) (int64, error) + Expire(ctx context.Context, key string, ttl time.Duration) error + SIsMember(ctx context.Context, key, member string) (bool, error) + SAdd(ctx context.Context, key string, members ...interface{}) error + Get(ctx context.Context, key string) (string, error) + Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error + Pipeline(ctx context.Context) StorePipeline + Ping(ctx context.Context) error +} + +// StorePipeline batches commands for execution. +type StorePipeline interface { + ZAdd(ctx context.Context, key string, score float64, member string) + Expire(ctx context.Context, key string, ttl time.Duration) + Set(ctx context.Context, key string, value interface{}, ttl time.Duration) + Exec(ctx context.Context) error +} + +// --- Redis Implementation --- + +// RedisStore wraps go-redis. +type RedisStore struct { + rdb *redis.Client +} + +func NewRedisStore(rdb *redis.Client) *RedisStore { + return &RedisStore{rdb: rdb} +} + +func (s *RedisStore) ZAdd(ctx context.Context, key string, score float64, member string) error { + return s.rdb.ZAdd(ctx, key, redis.Z{Score: score, Member: member}).Err() +} + +func (s *RedisStore) ZCount(ctx context.Context, key string, min, max float64) (int64, error) { + minStr := fmt.Sprintf("%f", min) + maxStr := "+inf" + if max < 1e18 { + maxStr = fmt.Sprintf("%f", max) + } + return s.rdb.ZCount(ctx, key, minStr, maxStr).Result() +} + +func (s *RedisStore) Expire(ctx context.Context, key string, ttl time.Duration) error { + return s.rdb.Expire(ctx, key, ttl).Err() +} + +func (s *RedisStore) SIsMember(ctx context.Context, key, member string) (bool, error) { + return s.rdb.SIsMember(ctx, key, member).Result() +} + +func (s *RedisStore) SAdd(ctx context.Context, key string, members ...interface{}) error { + return s.rdb.SAdd(ctx, key, members...).Err() +} + +func (s *RedisStore) Get(ctx context.Context, key string) (string, error) { + val, err := s.rdb.Get(ctx, key).Result() + if err == redis.Nil { + return "", nil + } + return val, err +} + +func (s *RedisStore) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error { + return s.rdb.Set(ctx, key, value, ttl).Err() +} + +func (s *RedisStore) Pipeline(_ context.Context) StorePipeline { + return &RedisPipeline{pipe: s.rdb.Pipeline()} +} + +func (s *RedisStore) Ping(ctx context.Context) error { + return s.rdb.Ping(ctx).Err() +} + +// RedisPipeline wraps go-redis pipeline. +type RedisPipeline struct { + pipe redis.Pipeliner +} + +func (p *RedisPipeline) ZAdd(ctx context.Context, key string, score float64, member string) { + p.pipe.ZAdd(ctx, key, redis.Z{Score: score, Member: member}) +} + +func (p *RedisPipeline) Expire(ctx context.Context, key string, ttl time.Duration) { + p.pipe.Expire(ctx, key, ttl) +} + +func (p *RedisPipeline) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) { + p.pipe.Set(ctx, key, value, ttl) +} + +func (p *RedisPipeline) Exec(ctx context.Context) error { + _, err := p.pipe.Exec(ctx) + return err +} + +// --- In-Memory Implementation --- + +type zsetEntry struct { + score float64 + member string +} + +type memKey struct { + value interface{} + expires time.Time +} + +// InMemoryStore provides a fallback when Valkey is unavailable. 
+type InMemoryStore struct { + mu sync.RWMutex + zsets map[string][]zsetEntry + sets map[string]map[string]struct{} + kvs map[string]memKey +} + +func NewInMemoryStore() *InMemoryStore { + return &InMemoryStore{ + zsets: make(map[string][]zsetEntry), + sets: make(map[string]map[string]struct{}), + kvs: make(map[string]memKey), + } +} + +func (s *InMemoryStore) ZAdd(_ context.Context, key string, score float64, member string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.zsets[key] = append(s.zsets[key], zsetEntry{score: score, member: member}) + return nil +} + +func (s *InMemoryStore) ZCount(_ context.Context, key string, min, max float64) (int64, error) { + s.mu.RLock() + defer s.mu.RUnlock() + var count int64 + for _, e := range s.zsets[key] { + if e.score >= min && (max >= 1e18 || e.score <= max) { + count++ + } + } + return count, nil +} + +func (s *InMemoryStore) Expire(_ context.Context, _ string, _ time.Duration) error { + return nil // TTL not implemented for in-memory (acceptable for fallback) +} + +func (s *InMemoryStore) SIsMember(_ context.Context, key, member string) (bool, error) { + s.mu.RLock() + defer s.mu.RUnlock() + set, ok := s.sets[key] + if !ok { + return false, nil + } + _, found := set[member] + return found, nil +} + +func (s *InMemoryStore) SAdd(_ context.Context, key string, members ...interface{}) error { + s.mu.Lock() + defer s.mu.Unlock() + set, ok := s.sets[key] + if !ok { + set = make(map[string]struct{}) + s.sets[key] = set + } + for _, m := range members { + set[fmt.Sprintf("%v", m)] = struct{}{} + } + return nil +} + +func (s *InMemoryStore) Get(_ context.Context, key string) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + kv, ok := s.kvs[key] + if !ok { + return "", nil + } + if !kv.expires.IsZero() && time.Now().After(kv.expires) { + return "", nil + } + return fmt.Sprintf("%v", kv.value), nil +} + +func (s *InMemoryStore) Set(_ context.Context, key string, value interface{}, ttl time.Duration) error { + s.mu.Lock() + defer s.mu.Unlock() + var expires time.Time + if ttl > 0 { + expires = time.Now().Add(ttl) + } + s.kvs[key] = memKey{value: value, expires: expires} + return nil +} + +func (s *InMemoryStore) Pipeline(_ context.Context) StorePipeline { + return &InMemoryPipeline{store: s} +} + +func (s *InMemoryStore) Ping(_ context.Context) error { + return nil +} + +// ZEntries returns sorted entries for a key (for testing). +func (s *InMemoryStore) ZEntries(key string) []zsetEntry { + s.mu.RLock() + defer s.mu.RUnlock() + entries := make([]zsetEntry, len(s.zsets[key])) + copy(entries, s.zsets[key]) + sort.Slice(entries, func(i, j int) bool { return entries[i].score < entries[j].score }) + return entries +} + +// InMemoryPipeline collects operations and executes them sequentially. 
+type InMemoryPipeline struct { + store *InMemoryStore + ops []func(context.Context) error +} + +func (p *InMemoryPipeline) ZAdd(_ context.Context, key string, score float64, member string) { + p.ops = append(p.ops, func(ctx context.Context) error { + return p.store.ZAdd(ctx, key, score, member) + }) +} + +func (p *InMemoryPipeline) Expire(_ context.Context, key string, ttl time.Duration) { + p.ops = append(p.ops, func(ctx context.Context) error { + return p.store.Expire(ctx, key, ttl) + }) +} + +func (p *InMemoryPipeline) Set(_ context.Context, key string, value interface{}, ttl time.Duration) { + p.ops = append(p.ops, func(ctx context.Context) error { + return p.store.Set(ctx, key, value, ttl) + }) +} + +func (p *InMemoryPipeline) Exec(ctx context.Context) error { + for _, op := range p.ops { + if err := op(ctx); err != nil { + return err + } + } + return nil +}
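
Usage sketch (companion to the patch, not part of it): a minimal example of wiring the agent to the in-memory Store fallback, mirroring the flow exercised in agent_test.go. It assumes a file living alongside the patched sources in package main, and reuses the PackageConfig / FrequencyRule / CampaignConfig and tmp request/response types shown in the diff above; the function name and the IDs are illustrative only.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/adcontextprotocol/adcp-go/tmp"
)

// exampleInMemoryFlow shows the Store seam end to end without Valkey:
// seed a segment, record an exposure, then check eligibility.
func exampleInMemoryFlow() error {
	ctx := context.Background()

	// InMemoryStore satisfies the same Store interface as RedisStore.
	store := NewInMemoryStore()
	agent := NewIdentityAgent(store,
		[]PackageConfig{{
			PackageID:      "pkg-example",
			CampaignID:     "campaign-example",
			FrequencyRules: []FrequencyRule{{MaxCount: 2, Window: time.Hour}},
			TargetSegments: []string{"example_segment"},
		}},
		[]CampaignConfig{{
			CampaignID:     "campaign-example",
			FrequencyRules: []FrequencyRule{{MaxCount: 3, Window: 24 * time.Hour}},
		}},
	)

	// Put the user in the targeted segment.
	if err := agent.LoadAudienceSegment(ctx, "example_segment", []string{"user-42"}); err != nil {
		return err
	}

	// Record one exposure (goes through the StorePipeline batching path).
	if _, err := agent.Expose(ctx, &tmp.ExposeRequest{UserToken: "user-42", PackageID: "pkg-example"}); err != nil {
		return err
	}

	// Eligibility now reflects segment membership and the sliding-window caps.
	resp, err := agent.IdentityMatch(ctx, &tmp.IdentityMatchRequest{
		RequestID:  "req-1",
		UserToken:  "user-42",
		PackageIDs: []string{"pkg-example"},
	})
	if err != nil {
		return err
	}
	fmt.Println("eligible:", resp.Eligibility[0].Eligible)
	return nil
}

The same wiring runs against Valkey by swapping NewInMemoryStore() for NewRedisStore(redis.NewClient(...)), which is the path main.go now selects automatically when the startup ping succeeds.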