From 121350fb6a562875267704054dbd5f22d3ad6346 Mon Sep 17 00:00:00 2001 From: vietddude Date: Tue, 24 Mar 2026 14:44:13 +0700 Subject: [PATCH 1/9] fix(client): scope result routing by client ID --- cmd/mpcium/main.go | 12 +- pkg/client/client.go | 97 +++++++++--- pkg/client/client_test.go | 211 +++++++++++++++++++++++--- pkg/event/result_topics.go | 67 ++++++++ pkg/event/result_topics_test.go | 85 +++++++++++ pkg/eventconsumer/event_consumer.go | 58 ++++--- pkg/eventconsumer/keygen_consumer.go | 14 +- pkg/eventconsumer/sign_consumer.go | 18 ++- pkg/eventconsumer/timeout_consumer.go | 7 +- pkg/messaging/jetstream_broker.go | 13 +- pkg/messaging/message_queue.go | 14 +- pkg/messaging/pubsub.go | 22 ++- pkg/mpc/ecdsa_signing_session.go | 5 +- pkg/mpc/eddsa_signing_session.go | 10 +- pkg/mpc/key_exchange_session.go | 2 +- pkg/mpc/node.go | 3 + pkg/mpc/session.go | 3 +- 17 files changed, 532 insertions(+), 109 deletions(-) create mode 100644 pkg/event/result_topics.go create mode 100644 pkg/event/result_topics_test.go diff --git a/cmd/mpcium/main.go b/cmd/mpcium/main.go index 418caae4..a5043107 100644 --- a/cmd/mpcium/main.go +++ b/cmd/mpcium/main.go @@ -216,17 +216,13 @@ func runNode(ctx context.Context, c *cli.Command) error { } directMessaging := messaging.NewNatsDirectMessaging(natsConn) - mqManager := messaging.NewNATsMessageQueueManager("mpc", []string{ - "mpc.mpc_keygen_result.*", - event.SigningResultTopic, - "mpc.mpc_reshare_result.*", - }, natsConn) + mqManager := messaging.NewNATsMessageQueueManager("mpc", event.ResultStreamSubjects(), natsConn) - genKeyResultQueue := mqManager.NewMessageQueue("mpc_keygen_result") + genKeyResultQueue := mqManager.NewMessageQueue("mpc_keygen_result", event.KeygenResultSubscriptionSubject("")) defer genKeyResultQueue.Close() - singingResultQueue := mqManager.NewMessageQueue("mpc_signing_result") + singingResultQueue := mqManager.NewMessageQueue("mpc_signing_result", event.SigningResultSubscriptionSubject("")) defer singingResultQueue.Close() - reshareResultQueue := mqManager.NewMessageQueue("mpc_reshare_result") + reshareResultQueue := mqManager.NewMessageQueue("mpc_reshare_result", event.ReshareResultSubscriptionSubject("")) defer reshareResultQueue.Close() logger.Info("Starting mpcium node", "version", Version, "ID", nodeID, "name", nodeName) diff --git a/pkg/client/client.go b/pkg/client/client.go index 6c64c140..5b5c240f 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/eventconsumer" @@ -13,11 +14,6 @@ import ( "github.com/nats-io/nats.go" ) -const ( - GenerateWalletSuccessTopic = "mpc.mpc_keygen_result.*" // wildcard to listen to all success events - ResharingSuccessTopic = "mpc.mpc_reshare_result.*" // wildcard to listen to all success events -) - type MPCClient interface { CreateWallet(walletID string) error CreateWalletWithAuthorizers(walletID string, authorizerSignatures []types.AuthorizerSignature) error @@ -38,6 +34,7 @@ type mpcClient struct { signResultQueue messaging.MessageQueue reshareSuccessQueue messaging.MessageQueue signer Signer + clientID string } // Options defines configuration options for creating a new MPCClient @@ -49,13 +46,42 @@ type Options struct { Signer Signer } +type ClientOption func(*clientConfig) + +type clientConfig struct { + clientID string +} + +type clientResultRouting struct { + keygenConsumerName string + keygenSubject string + signingConsumerName string + signingSubject string + reshareConsumerName string + reshareSubject string +} + +func WithClientID(id string) ClientOption { + return func(cfg *clientConfig) { + cfg.clientID = id + } +} + // NewMPCClient creates a new MPC client using the provided options. // The signer must be provided to handle message signing. -func NewMPCClient(opts Options) MPCClient { +func NewMPCClient(opts Options, clientOptions ...ClientOption) MPCClient { if opts.Signer == nil { logger.Fatal("Signer is required", nil) } + cfg := clientConfig{} + for _, opt := range clientOptions { + opt(&cfg) + } + if err := validateClientID(cfg.clientID); err != nil { + logger.Fatal("Invalid client ID", err) + } + // 2) Create the PubSub for both publish & subscribe signingBroker, err := messaging.NewJetStreamBroker( context.Background(), @@ -82,15 +108,12 @@ func NewMPCClient(opts Options) MPCClient { pubsub := messaging.NewNATSPubSub(opts.NatsConn) - manager := messaging.NewNATsMessageQueueManager("mpc", []string{ - "mpc.mpc_keygen_result.*", - "mpc.mpc_signing_result.*", - "mpc.mpc_reshare_result.*", - }, opts.NatsConn) + manager := messaging.NewNATsMessageQueueManager("mpc", event.ResultStreamSubjects(), opts.NatsConn) + routing := buildClientResultRouting(cfg.clientID) - genKeySuccessQueue := manager.NewMessageQueue("mpc_keygen_result") - signResultQueue := manager.NewMessageQueue("mpc_signing_result") - reshareSuccessQueue := manager.NewMessageQueue("mpc_reshare_result") + genKeySuccessQueue := manager.NewMessageQueue(routing.keygenConsumerName, routing.keygenSubject) + signResultQueue := manager.NewMessageQueue(routing.signingConsumerName, routing.signingSubject) + reshareSuccessQueue := manager.NewMessageQueue(routing.reshareConsumerName, routing.reshareSubject) return &mpcClient{ signingBroker: signingBroker, @@ -100,6 +123,7 @@ func NewMPCClient(opts Options) MPCClient { signResultQueue: signResultQueue, reshareSuccessQueue: reshareSuccessQueue, signer: opts.Signer, + clientID: cfg.clientID, } } @@ -131,7 +155,7 @@ func (c *mpcClient) CreateWalletWithAuthorizers(walletID string, authorizerSigna return fmt.Errorf("CreateWallet: marshal error: %w", err) } - if err := c.keygenBroker.PublishMessage(context.Background(), event.KeygenRequestTopic, bytes); err != nil { + if err := c.keygenBroker.PublishMessage(context.Background(), event.KeygenRequestTopic, bytes, c.requestHeaders()); err != nil { return fmt.Errorf("CreateWallet: publish error: %w", err) } return nil @@ -139,7 +163,7 @@ func (c *mpcClient) CreateWalletWithAuthorizers(walletID string, authorizerSigna // The callback will be invoked whenever a wallet creation result is received. func (c *mpcClient) OnWalletCreationResult(callback func(event event.KeygenResultEvent)) error { - err := c.genKeySuccessQueue.Dequeue(GenerateWalletSuccessTopic, func(msg []byte) error { + err := c.genKeySuccessQueue.Dequeue(event.KeygenResultSubscriptionSubject(c.clientID), func(msg []byte) error { var event event.KeygenResultEvent err := json.Unmarshal(msg, &event) if err != nil { @@ -174,14 +198,14 @@ func (c *mpcClient) SignTransaction(msg *types.SignTxMessage) error { return fmt.Errorf("SignTransaction: marshal error: %w", err) } - if err := c.signingBroker.PublishMessage(context.Background(), event.SigningRequestTopic, bytes); err != nil { + if err := c.signingBroker.PublishMessage(context.Background(), event.SigningRequestTopic, bytes, c.requestHeaders()); err != nil { return fmt.Errorf("SignTransaction: publish error: %w", err) } return nil } func (c *mpcClient) OnSignResult(callback func(event event.SigningResultEvent)) error { - err := c.signResultQueue.Dequeue(event.SigningResultCompleteTopic, func(msg []byte) error { + err := c.signResultQueue.Dequeue(event.SigningResultSubscriptionSubject(c.clientID), func(msg []byte) error { var event event.SigningResultEvent err := json.Unmarshal(msg, &event) if err != nil { @@ -215,14 +239,14 @@ func (c *mpcClient) Resharing(msg *types.ResharingMessage) error { return fmt.Errorf("Resharing: marshal error: %w", err) } - if err := c.pubsub.Publish(eventconsumer.MPCReshareEvent, bytes); err != nil { + if err := c.pubsub.Publish(eventconsumer.MPCReshareEvent, bytes, c.requestHeaders()); err != nil { return fmt.Errorf("Resharing: publish error: %w", err) } return nil } func (c *mpcClient) OnResharingResult(callback func(event event.ResharingResultEvent)) error { - err := c.reshareSuccessQueue.Dequeue(ResharingSuccessTopic, func(msg []byte) error { + err := c.reshareSuccessQueue.Dequeue(event.ReshareResultSubscriptionSubject(c.clientID), func(msg []byte) error { logger.Info("Received reshare success message", "raw", string(msg)) var event event.ResharingResultEvent err := json.Unmarshal(msg, &event) @@ -241,3 +265,36 @@ func (c *mpcClient) OnResharingResult(callback func(event event.ResharingResultE return nil } + +func (c *mpcClient) requestHeaders() map[string]string { + if c.clientID == "" { + return nil + } + return map[string]string{ + event.ClientIDHeader: c.clientID, + } +} + +func buildClientResultRouting(clientID string) clientResultRouting { + return clientResultRouting{ + keygenConsumerName: event.ResultConsumerName("mpc_keygen_result", clientID), + keygenSubject: event.KeygenResultSubscriptionSubject(clientID), + signingConsumerName: event.ResultConsumerName("mpc_signing_result", clientID), + signingSubject: event.SigningResultSubscriptionSubject(clientID), + reshareConsumerName: event.ResultConsumerName("mpc_reshare_result", clientID), + reshareSubject: event.ReshareResultSubscriptionSubject(clientID), + } +} + +func validateClientID(clientID string) error { + if clientID == "" { + return nil + } + if strings.TrimSpace(clientID) == "" { + return fmt.Errorf("client ID cannot be blank") + } + if strings.ContainsAny(clientID, " \t\r\n.*>") { + return fmt.Errorf("client ID must be a single NATS subject token") + } + return nil +} diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index 36ac514c..64b350e6 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -1,11 +1,16 @@ package client import ( + "context" "errors" "testing" + "github.com/fystack/mpcium/pkg/event" + "github.com/fystack/mpcium/pkg/eventconsumer" + "github.com/fystack/mpcium/pkg/messaging" "github.com/fystack/mpcium/pkg/types" "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -31,6 +36,71 @@ func (m *MockSigner) PublicKey() (string, error) { return args.String(0), args.Error(1) } +type recordingBroker struct { + subject string + data []byte + headers map[string]string + calls int +} + +func (b *recordingBroker) PublishMessage(_ context.Context, subject string, data []byte, headers map[string]string) error { + b.subject = subject + b.data = append([]byte(nil), data...) + b.headers = copyHeaders(headers) + b.calls++ + return nil +} + +func (b *recordingBroker) CreateSubscription(context.Context, string, string, func(jetstream.Msg)) (messaging.MessageSubscription, error) { + return nil, nil +} + +func (b *recordingBroker) GetStreamInfo(context.Context) (*jetstream.StreamInfo, error) { + return nil, nil +} + +func (b *recordingBroker) FetchMessages(context.Context, string, string, int, func(jetstream.Msg)) error { + return nil +} + +func (b *recordingBroker) Close() error { + return nil +} + +type recordingPubSub struct { + topic string + message []byte + headers map[string]string + calls int +} + +func (p *recordingPubSub) Publish(topic string, message []byte, headers map[string]string) error { + p.topic = topic + p.message = append([]byte(nil), message...) + p.headers = copyHeaders(headers) + p.calls++ + return nil +} + +func (p *recordingPubSub) PublishWithReply(string, string, []byte, map[string]string) error { + return nil +} + +func (p *recordingPubSub) Subscribe(string, func(msg *nats.Msg)) (messaging.Subscription, error) { + return nil, nil +} + +func copyHeaders(headers map[string]string) map[string]string { + if headers == nil { + return nil + } + cloned := make(map[string]string, len(headers)) + for k, v := range headers { + cloned[k] = v + } + return cloned +} + // MockNATSConn creates a mock NATS connection for testing func MockNATSConn() *nats.Conn { // For unit tests, we can return nil and handle it appropriately in tests @@ -66,7 +136,7 @@ func TestNewMPCClient_NoSigner(t *testing.T) { func TestMPCClient_CreateWallet(t *testing.T) { mockSigner := &MockSigner{} - + // Set up expectations testSignature := []byte("test-signature") mockSigner.On("Sign", mock.AnythingOfType("[]uint8")).Return(testSignature, nil) @@ -79,20 +149,20 @@ func TestMPCClient_CreateWallet(t *testing.T) { // Test CreateWallet - this will test the signing logic // Note: This test would require mocking the messaging broker as well // For now, we test that the signer is called correctly - + walletID := "test-wallet-123" - + // We can't fully test CreateWallet without mocking the broker, // but we can test the signing part by calling it directly - + // Simulate what CreateWallet does with signing msg := &types.GenerateKeyMessage{ WalletID: walletID, } - + raw, err := msg.Raw() require.NoError(t, err) - + signature, err := client.signer.Sign(raw) require.NoError(t, err) assert.Equal(t, testSignature, signature) @@ -103,7 +173,7 @@ func TestMPCClient_CreateWallet(t *testing.T) { func TestMPCClient_CreateWallet_SigningError(t *testing.T) { mockSigner := &MockSigner{} - + // Set up signer to return error mockSigner.On("Sign", mock.AnythingOfType("[]uint8")).Return([]byte(nil), errors.New("signing failed")) @@ -115,10 +185,10 @@ func TestMPCClient_CreateWallet_SigningError(t *testing.T) { msg := &types.GenerateKeyMessage{ WalletID: "test-wallet", } - + raw, err := msg.Raw() require.NoError(t, err) - + signature, err := client.signer.Sign(raw) assert.Error(t, err) assert.Nil(t, signature) @@ -129,7 +199,7 @@ func TestMPCClient_CreateWallet_SigningError(t *testing.T) { func TestMPCClient_SignTransaction(t *testing.T) { mockSigner := &MockSigner{} - + // Set up expectations testSignature := []byte("test-transaction-signature") mockSigner.On("Sign", mock.AnythingOfType("[]uint8")).Return(testSignature, nil) @@ -146,10 +216,10 @@ func TestMPCClient_SignTransaction(t *testing.T) { TxID: "test-tx-123", Tx: []byte("test transaction data"), } - + raw, err := msg.Raw() require.NoError(t, err) - + signature, err := client.signer.Sign(raw) require.NoError(t, err) assert.Equal(t, testSignature, signature) @@ -159,7 +229,7 @@ func TestMPCClient_SignTransaction(t *testing.T) { func TestMPCClient_Resharing(t *testing.T) { mockSigner := &MockSigner{} - + // Set up expectations testSignature := []byte("test-resharing-signature") mockSigner.On("Sign", mock.AnythingOfType("[]uint8")).Return(testSignature, nil) @@ -176,10 +246,10 @@ func TestMPCClient_Resharing(t *testing.T) { KeyType: types.KeyTypeSecp256k1, WalletID: "test-wallet", } - + raw, err := msg.Raw() require.NoError(t, err) - + signature, err := client.signer.Sign(raw) require.NoError(t, err) assert.Equal(t, testSignature, signature) @@ -187,10 +257,113 @@ func TestMPCClient_Resharing(t *testing.T) { mockSigner.AssertExpectations(t) } +func TestMPCClient_CreateWallet_PublishesClientIDHeader(t *testing.T) { + mockSigner := &MockSigner{} + mockSigner.On("Sign", mock.AnythingOfType("[]uint8")).Return([]byte("sig"), nil) + + broker := &recordingBroker{} + client := &mpcClient{ + signer: mockSigner, + keygenBroker: broker, + clientID: "svc-a", + } + + err := client.CreateWallet("wallet-1") + require.NoError(t, err) + + assert.Equal(t, 1, broker.calls) + assert.Equal(t, event.KeygenRequestTopic, broker.subject) + assert.Equal(t, map[string]string{event.ClientIDHeader: "svc-a"}, broker.headers) + mockSigner.AssertExpectations(t) +} + +func TestMPCClient_SignTransaction_PublishesClientIDHeader(t *testing.T) { + mockSigner := &MockSigner{} + mockSigner.On("Sign", mock.AnythingOfType("[]uint8")).Return([]byte("sig"), nil) + + broker := &recordingBroker{} + client := &mpcClient{ + signer: mockSigner, + signingBroker: broker, + clientID: "svc-a", + } + + err := client.SignTransaction(&types.SignTxMessage{ + KeyType: types.KeyTypeSecp256k1, + WalletID: "wallet-1", + NetworkInternalCode: "eth-mainnet", + TxID: "tx-1", + Tx: []byte("payload"), + }) + require.NoError(t, err) + + assert.Equal(t, 1, broker.calls) + assert.Equal(t, event.SigningRequestTopic, broker.subject) + assert.Equal(t, map[string]string{event.ClientIDHeader: "svc-a"}, broker.headers) + mockSigner.AssertExpectations(t) +} + +func TestMPCClient_Resharing_PublishesClientIDHeader(t *testing.T) { + mockSigner := &MockSigner{} + mockSigner.On("Sign", mock.AnythingOfType("[]uint8")).Return([]byte("sig"), nil) + + pubsub := &recordingPubSub{} + client := &mpcClient{ + signer: mockSigner, + pubsub: pubsub, + clientID: "svc-a", + } + + err := client.Resharing(&types.ResharingMessage{ + SessionID: "reshare-1", + NodeIDs: []string{"node-1", "node-2", "node-3"}, + NewThreshold: 2, + KeyType: types.KeyTypeSecp256k1, + WalletID: "wallet-1", + }) + require.NoError(t, err) + + assert.Equal(t, 1, pubsub.calls) + assert.Equal(t, eventconsumer.MPCReshareEvent, pubsub.topic) + assert.Equal(t, map[string]string{event.ClientIDHeader: "svc-a"}, pubsub.headers) + mockSigner.AssertExpectations(t) +} + +func TestBuildClientResultRouting(t *testing.T) { + t.Run("legacy", func(t *testing.T) { + routing := buildClientResultRouting("") + assert.Equal(t, "mpc_keygen_result", routing.keygenConsumerName) + assert.Equal(t, "mpc.mpc_keygen_result.*", routing.keygenSubject) + assert.Equal(t, "mpc_signing_result", routing.signingConsumerName) + assert.Equal(t, "mpc.mpc_signing_result.complete", routing.signingSubject) + assert.Equal(t, "mpc_reshare_result", routing.reshareConsumerName) + assert.Equal(t, "mpc.mpc_reshare_result.*", routing.reshareSubject) + }) + + t.Run("scoped", func(t *testing.T) { + routing := buildClientResultRouting("svc-a") + assert.Equal(t, "mpc_keygen_result.svc-a", routing.keygenConsumerName) + assert.Equal(t, "mpc.mpc_keygen_result.svc-a.*", routing.keygenSubject) + assert.Equal(t, "mpc_signing_result.svc-a", routing.signingConsumerName) + assert.Equal(t, "mpc.mpc_signing_result.svc-a.complete", routing.signingSubject) + assert.Equal(t, "mpc_reshare_result.svc-a", routing.reshareConsumerName) + assert.Equal(t, "mpc.mpc_reshare_result.svc-a.*", routing.reshareSubject) + }) +} + +func TestValidateClientID(t *testing.T) { + assert.NoError(t, validateClientID("")) + assert.NoError(t, validateClientID("svc-a")) + + assert.Error(t, validateClientID("svc a")) + assert.Error(t, validateClientID("svc.a")) + assert.Error(t, validateClientID("svc*")) +} + func TestSignerInterface_Compliance(t *testing.T) { // Test that our mock signer implements the interface correctly mockSigner := &MockSigner{} - + // Set up mock expectations mockSigner.On("Algorithm").Return(types.EventInitiatorKeyTypeP256) mockSigner.On("PublicKey").Return("mock-public-key-hex", nil) @@ -215,7 +388,7 @@ func TestSignerInterface_Compliance(t *testing.T) { func TestSignerInterface_ErrorHandling(t *testing.T) { mockSigner := &MockSigner{} - + // Set up error cases mockSigner.On("PublicKey").Return("", errors.New("public key error")) mockSigner.On("Sign", mock.Anything).Return([]byte(nil), errors.New("signing error")) @@ -288,7 +461,7 @@ func createTestMPCClient(signer Signer) *mpcClient { func TestCreateTestMPCClient(t *testing.T) { mockSigner := &MockSigner{} client := createTestMPCClient(mockSigner) - + assert.NotNil(t, client) assert.Equal(t, mockSigner, client.signer) -} \ No newline at end of file +} diff --git a/pkg/event/result_topics.go b/pkg/event/result_topics.go new file mode 100644 index 00000000..a78442c3 --- /dev/null +++ b/pkg/event/result_topics.go @@ -0,0 +1,67 @@ +package event + +import "strings" + +const ClientIDHeader = "ClientID" + +const ( + keygenResultSubjectPrefix = "mpc.mpc_keygen_result" + signingResultSubjectPrefix = "mpc.mpc_signing_result" + reshareResultSubjectPrefix = "mpc.mpc_reshare_result" + signingResultCompleteToken = "complete" +) + +func ResultStreamSubjects() []string { + return []string{ + keygenResultSubjectPrefix + ".>", + signingResultSubjectPrefix + ".>", + reshareResultSubjectPrefix + ".>", + } +} + +func KeygenResultSubject(clientID, walletID string) string { + return scopedSubject(keygenResultSubjectPrefix, clientID, walletID) +} + +func KeygenResultSubscriptionSubject(clientID string) string { + return scopedSubject(keygenResultSubjectPrefix, clientID, "*") +} + +func SigningResultSubject(clientID string) string { + return scopedSubject(signingResultSubjectPrefix, clientID, signingResultCompleteToken) +} + +func SigningResultSubscriptionSubject(clientID string) string { + return SigningResultSubject(clientID) +} + +func ReshareResultSubject(clientID, sessionID string) string { + return scopedSubject(reshareResultSubjectPrefix, clientID, sessionID) +} + +func ReshareResultSubscriptionSubject(clientID string) string { + return scopedSubject(reshareResultSubjectPrefix, clientID, "*") +} + +func ResultConsumerName(base, clientID string) string { + if clientID == "" { + return base + } + return base + "." + clientID +} + +func ScopedOperationID(clientID, operationID string) string { + if clientID == "" { + return operationID + } + return clientID + ":" + operationID +} + +func scopedSubject(prefix, clientID, tail string) string { + parts := []string{prefix} + if clientID != "" { + parts = append(parts, clientID) + } + parts = append(parts, tail) + return strings.Join(parts, ".") +} diff --git a/pkg/event/result_topics_test.go b/pkg/event/result_topics_test.go new file mode 100644 index 00000000..deeb3437 --- /dev/null +++ b/pkg/event/result_topics_test.go @@ -0,0 +1,85 @@ +package event + +import "testing" + +func TestResultStreamSubjects(t *testing.T) { + subjects := ResultStreamSubjects() + expected := []string{ + "mpc.mpc_keygen_result.>", + "mpc.mpc_signing_result.>", + "mpc.mpc_reshare_result.>", + } + + if len(subjects) != len(expected) { + t.Fatalf("unexpected subject count: got %d want %d", len(subjects), len(expected)) + } + + for i := range expected { + if subjects[i] != expected[i] { + t.Fatalf("unexpected subject at index %d: got %q want %q", i, subjects[i], expected[i]) + } + } +} + +func TestScopedResultSubjects(t *testing.T) { + tests := []struct { + name string + clientID string + keygenResult string + keygenSubscription string + signingResult string + reshareResult string + reshareSubscription string + keygenConsumerName string + scopedOperationIdentifier string + }{ + { + name: "legacy", + clientID: "", + keygenResult: "mpc.mpc_keygen_result.wallet-1", + keygenSubscription: "mpc.mpc_keygen_result.*", + signingResult: "mpc.mpc_signing_result.complete", + reshareResult: "mpc.mpc_reshare_result.session-1", + reshareSubscription: "mpc.mpc_reshare_result.*", + keygenConsumerName: "mpc_keygen_result", + scopedOperationIdentifier: "wallet-1", + }, + { + name: "scoped", + clientID: "svc-a", + keygenResult: "mpc.mpc_keygen_result.svc-a.wallet-1", + keygenSubscription: "mpc.mpc_keygen_result.svc-a.*", + signingResult: "mpc.mpc_signing_result.svc-a.complete", + reshareResult: "mpc.mpc_reshare_result.svc-a.session-1", + reshareSubscription: "mpc.mpc_reshare_result.svc-a.*", + keygenConsumerName: "mpc_keygen_result.svc-a", + scopedOperationIdentifier: "svc-a:wallet-1", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := KeygenResultSubject(tc.clientID, "wallet-1"); got != tc.keygenResult { + t.Fatalf("unexpected keygen result subject: got %q want %q", got, tc.keygenResult) + } + if got := KeygenResultSubscriptionSubject(tc.clientID); got != tc.keygenSubscription { + t.Fatalf("unexpected keygen subscription subject: got %q want %q", got, tc.keygenSubscription) + } + if got := SigningResultSubject(tc.clientID); got != tc.signingResult { + t.Fatalf("unexpected signing result subject: got %q want %q", got, tc.signingResult) + } + if got := ReshareResultSubject(tc.clientID, "session-1"); got != tc.reshareResult { + t.Fatalf("unexpected reshare result subject: got %q want %q", got, tc.reshareResult) + } + if got := ReshareResultSubscriptionSubject(tc.clientID); got != tc.reshareSubscription { + t.Fatalf("unexpected reshare subscription subject: got %q want %q", got, tc.reshareSubscription) + } + if got := ResultConsumerName("mpc_keygen_result", tc.clientID); got != tc.keygenConsumerName { + t.Fatalf("unexpected consumer name: got %q want %q", got, tc.keygenConsumerName) + } + if got := ScopedOperationID(tc.clientID, "wallet-1"); got != tc.scopedOperationIdentifier { + t.Fatalf("unexpected scoped operation id: got %q want %q", got, tc.scopedOperationIdentifier) + } + }) + } +} diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 9e2ecb30..4b94e2d6 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -25,9 +25,9 @@ const ( MPCSignEvent = "mpc:sign" MPCReshareEvent = "mpc:reshare" - DefaultConcurrentKeygen = 2 - DefaultConcurrentSigning = 20 - KeyGenTimeOut = 30 * time.Second + DefaultConcurrentKeygen = 2 + DefaultConcurrentSigning = 20 + KeyGenTimeOut = 30 * time.Second ) type EventConsumer interface { @@ -272,7 +272,7 @@ func (ec *eventConsumer) handleKeyGenEvent(natMsg *nats.Msg) { return } - key := fmt.Sprintf(mpc.TypeGenerateWalletResultFmt, walletID) + key := event.KeygenResultSubject(natMsg.Header.Get(event.ClientIDHeader), walletID) if err := ec.genKeyResultQueue.Enqueue( key, payload, @@ -305,7 +305,7 @@ func (ec *eventConsumer) handleKeygenSessionError(walletID string, err error, co return } - key := fmt.Sprintf(mpc.TypeGenerateWalletResultFmt, walletID) + key := event.KeygenResultSubject(natMsg.Header.Get(event.ClientIDHeader), walletID) err = ec.genKeyResultQueue.Enqueue(key, keygenResultBytes, &messaging.EnqueueOptions{ IdempotententKey: composeKeygenIdempotentKey(walletID, natMsg), }) @@ -405,6 +405,7 @@ func (ec *eventConsumer) handleSigningEvent(natMsg *nats.Msg) { var session mpc.SigningSession idempotentKey := composeSigningIdempotentKey(msg.TxID, natMsg) + resultTopic := event.SigningResultSubject(natMsg.Header.Get(event.ClientIDHeader)) var sessionErr error switch msg.KeyType { case types.KeyTypeSecp256k1: @@ -413,6 +414,7 @@ func (ec *eventConsumer) handleSigningEvent(natMsg *nats.Msg) { msg.WalletID, msg.TxID, msg.NetworkInternalCode, + resultTopic, ec.signingResultQueue, msg.DerivationPath, idempotentKey, @@ -423,6 +425,7 @@ func (ec *eventConsumer) handleSigningEvent(natMsg *nats.Msg) { msg.WalletID, msg.TxID, msg.NetworkInternalCode, + resultTopic, ec.signingResultQueue, msg.DerivationPath, idempotentKey, @@ -568,7 +571,7 @@ func (ec *eventConsumer) handleSigningSessionError(walletID, txID, networkIntern ) return } - err = ec.signingResultQueue.Enqueue(event.SigningResultCompleteTopic, signingResultBytes, &messaging.EnqueueOptions{ + err = ec.signingResultQueue.Enqueue(event.SigningResultSubject(natMsg.Header.Get(event.ClientIDHeader)), signingResultBytes, &messaging.EnqueueOptions{ IdempotententKey: composeSigningIdempotentKey(txID, natMsg), }) if err != nil { @@ -589,7 +592,7 @@ func (ec *eventConsumer) sendReplyToRemoveMsg(natMsg *nats.Msg) { return } - err := ec.pubsub.Publish(natMsg.Reply, msg) + err := ec.pubsub.Publish(natMsg.Reply, msg, nil) if err != nil { logger.Error("Failed to reply message", err, "reply", natMsg.Reply) return @@ -601,12 +604,13 @@ func (ec *eventConsumer) consumeReshareEvent() error { var msg types.ResharingMessage if err := json.Unmarshal(natMsg.Data, &msg); err != nil { logger.Error("Failed to unmarshal resharing message", err) - ec.handleReshareSessionError(msg.WalletID, msg.KeyType, msg.NewThreshold, err, "Failed to unmarshal resharing message", natMsg) + ec.handleReshareSessionError(msg.SessionID, msg.WalletID, msg.KeyType, msg.NewThreshold, err, "Failed to unmarshal resharing message", natMsg) return } if msg.SessionID == "" { ec.handleReshareSessionError( + msg.SessionID, msg.WalletID, msg.KeyType, msg.NewThreshold, @@ -619,13 +623,13 @@ func (ec *eventConsumer) consumeReshareEvent() error { if err := ec.identityStore.VerifyInitiatorMessage(&msg); err != nil { logger.Error("Failed to verify initiator message", err) - ec.handleReshareSessionError(msg.WalletID, msg.KeyType, msg.NewThreshold, err, "Failed to verify initiator message", natMsg) + ec.handleReshareSessionError(msg.SessionID, msg.WalletID, msg.KeyType, msg.NewThreshold, err, "Failed to verify initiator message", natMsg) return } if err := ec.identityStore.AuthorizeInitiatorMessage(&msg); err != nil { logger.Error("Failed to authorize initiator message", err) - ec.handleReshareSessionError(msg.WalletID, msg.KeyType, msg.NewThreshold, err, "Failed to authorize initiator message", natMsg) + ec.handleReshareSessionError(msg.SessionID, msg.WalletID, msg.KeyType, msg.NewThreshold, err, "Failed to authorize initiator message", natMsg) return } @@ -635,7 +639,7 @@ func (ec *eventConsumer) consumeReshareEvent() error { sessionType, err := sessionTypeFromKeyType(keyType) if err != nil { logger.Error("Failed to get session type", err) - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to get session type", natMsg) + ec.handleReshareSessionError(msg.SessionID, walletID, keyType, msg.NewThreshold, err, "Failed to get session type", natMsg) return } @@ -653,13 +657,13 @@ func (ec *eventConsumer) consumeReshareEvent() error { oldSession, err := createSession(false) if err != nil { logger.Error("Failed to create old reshare session", err, "walletID", walletID) - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to create old reshare session", natMsg) + ec.handleReshareSessionError(msg.SessionID, walletID, keyType, msg.NewThreshold, err, "Failed to create old reshare session", natMsg) return } newSession, err := createSession(true) if err != nil { logger.Error("Failed to create new reshare session", err, "walletID", walletID) - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to create new reshare session", natMsg) + ec.handleReshareSessionError(msg.SessionID, walletID, keyType, msg.NewThreshold, err, "Failed to create new reshare session", natMsg) return } @@ -681,7 +685,7 @@ func (ec *eventConsumer) consumeReshareEvent() error { if oldSession != nil { err := oldSession.Init() if err != nil { - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to init old reshare session", natMsg) + ec.handleReshareSessionError(msg.SessionID, walletID, keyType, msg.NewThreshold, err, "Failed to init old reshare session", natMsg) return } oldSession.ListenToIncomingMessageAsync() @@ -690,7 +694,7 @@ func (ec *eventConsumer) consumeReshareEvent() error { if newSession != nil { err := newSession.Init() if err != nil { - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to init new reshare session", natMsg) + ec.handleReshareSessionError(msg.SessionID, walletID, keyType, msg.NewThreshold, err, "Failed to init new reshare session", natMsg) return } newSession.ListenToIncomingMessageAsync() @@ -727,7 +731,7 @@ func (ec *eventConsumer) consumeReshareEvent() error { } reshareBarrierWg.Wait() if reshareBarrierErr != nil { - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, reshareBarrierErr, "Peers not ready before resharing", natMsg) + ec.handleReshareSessionError(msg.SessionID, walletID, keyType, msg.NewThreshold, reshareBarrierErr, "Peers not ready before resharing", natMsg) return } @@ -742,7 +746,7 @@ func (ec *eventConsumer) consumeReshareEvent() error { return case err := <-oldSession.ErrChan(): logger.Error("Old reshare session error", err) - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Old reshare session error", natMsg) + ec.handleReshareSessionError(msg.SessionID, walletID, keyType, msg.NewThreshold, err, "Old reshare session error", natMsg) doneOld() return } @@ -761,7 +765,7 @@ func (ec *eventConsumer) consumeReshareEvent() error { return case err := <-newSession.ErrChan(): logger.Error("New reshare session error", err) - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "New reshare session error", natMsg) + ec.handleReshareSessionError(msg.SessionID, walletID, keyType, msg.NewThreshold, err, "New reshare session error", natMsg) doneNew() return } @@ -776,11 +780,11 @@ func (ec *eventConsumer) consumeReshareEvent() error { successBytes, err := json.Marshal(successEvent) if err != nil { logger.Error("Failed to marshal reshare success event", err) - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to marshal reshare success event", natMsg) + ec.handleReshareSessionError(msg.SessionID, walletID, keyType, msg.NewThreshold, err, "Failed to marshal reshare success event", natMsg) return } - key := fmt.Sprintf(mpc.TypeReshareWalletResultFmt, msg.SessionID) + key := event.ReshareResultSubject(natMsg.Header.Get(event.ClientIDHeader), msg.SessionID) err = ec.reshareResultQueue.Enqueue( key, successBytes, @@ -789,7 +793,7 @@ func (ec *eventConsumer) consumeReshareEvent() error { }) if err != nil { logger.Error("Failed to publish reshare success message", err) - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to publish reshare success message", natMsg) + ec.handleReshareSessionError(msg.SessionID, walletID, keyType, msg.NewThreshold, err, "Failed to publish reshare success message", natMsg) return } logger.Info("[COMPLETED RESHARE] Successfully published", "walletID", walletID) @@ -804,6 +808,7 @@ func (ec *eventConsumer) consumeReshareEvent() error { // handleReshareSessionError handles errors that occur during reshare operations func (ec *eventConsumer) handleReshareSessionError( + sessionID string, walletID string, keyType types.KeyType, newThreshold int, @@ -840,9 +845,14 @@ func (ec *eventConsumer) handleReshareSessionError( return } - key := fmt.Sprintf(mpc.TypeReshareWalletResultFmt, walletID) + if sessionID == "" { + logger.Warn("Skipping reshare result publish because session ID is empty", "walletID", walletID) + return + } + + key := event.ReshareResultSubject(natMsg.Header.Get(event.ClientIDHeader), sessionID) err = ec.reshareResultQueue.Enqueue(key, reshareResultBytes, &messaging.EnqueueOptions{ - IdempotententKey: composeReshareIdempotentKey(walletID, natMsg), + IdempotententKey: composeReshareIdempotentKey(sessionID, natMsg), }) if err != nil { logger.Error("Failed to enqueue reshare result event", err, @@ -954,7 +964,7 @@ func composeIdempotentKey(baseID string, natMsg *nats.Msg, formatTemplate string } else { uniqueKey = baseID } - return fmt.Sprintf(formatTemplate, uniqueKey) + return fmt.Sprintf(formatTemplate, event.ScopedOperationID(natMsg.Header.Get(event.ClientIDHeader), uniqueKey)) } func composeKeygenIdempotentKey(walletID string, natMsg *nats.Msg) string { diff --git a/pkg/eventconsumer/keygen_consumer.go b/pkg/eventconsumer/keygen_consumer.go index e2f07146..630ee133 100644 --- a/pkg/eventconsumer/keygen_consumer.go +++ b/pkg/eventconsumer/keygen_consumer.go @@ -132,18 +132,19 @@ func (sc *keygenConsumer) handleKeygenEvent(msg jetstream.Msg) { raw := msg.Data() var keygenMsg types.GenerateKeyMessage sessionID := msg.Headers().Get("SessionID") + clientID := msg.Headers().Get(event.ClientIDHeader) err := json.Unmarshal(raw, &keygenMsg) if err != nil { logger.Error("SigningConsumer: Failed to unmarshal keygen message", err) - sc.handleKeygenError(keygenMsg, event.ErrorCodeUnmarshalFailure, err, sessionID) + sc.handleKeygenError(keygenMsg, event.ErrorCodeUnmarshalFailure, err, sessionID, clientID) _ = msg.Ack() return } if !sc.peerRegistry.ArePeersReady() { logger.Warn("KeygenConsumer: Not all peers are ready to gen key, skipping message processing") - sc.handleKeygenError(keygenMsg, event.ErrorCodeClusterNotReady, errors.New("not all peers are ready"), sessionID) + sc.handleKeygenError(keygenMsg, event.ErrorCodeClusterNotReady, errors.New("not all peers are ready"), sessionID, clientID) _ = msg.Ack() return } @@ -168,6 +169,9 @@ func (sc *keygenConsumer) handleKeygenEvent(msg jetstream.Msg) { headers := map[string]string{ "SessionID": uuid.New().String(), } + if clientID != "" { + headers[event.ClientIDHeader] = clientID + } if err := sc.pubsub.PublishWithReply(MPCGenerateEvent, replyInbox, msg.Data(), headers); err != nil { logger.Error("KeygenConsumer: Failed to publish keygen event with reply", err) _ = msg.Nak() @@ -206,7 +210,7 @@ func (sc *keygenConsumer) handleKeygenEvent(msg jetstream.Msg) { _ = msg.Nak() } -func (sc *keygenConsumer) handleKeygenError(keygenMsg types.GenerateKeyMessage, errorCode event.ErrorCode, err error, sessionID string) { +func (sc *keygenConsumer) handleKeygenError(keygenMsg types.GenerateKeyMessage, errorCode event.ErrorCode, err error, sessionID, clientID string) { keygenResult := event.KeygenResultEvent{ ResultType: event.ResultTypeError, ErrorCode: string(errorCode), @@ -222,9 +226,9 @@ func (sc *keygenConsumer) handleKeygenError(keygenMsg types.GenerateKeyMessage, return } - topic := fmt.Sprintf(mpc.TypeGenerateWalletResultFmt, keygenResult.WalletID) + topic := event.KeygenResultSubject(clientID, keygenResult.WalletID) err = sc.keygenResultQueue.Enqueue(topic, keygenResultBytes, &messaging.EnqueueOptions{ - IdempotententKey: buildIdempotentKey(keygenMsg.WalletID, sessionID, mpc.TypeGenerateWalletResultFmt), + IdempotententKey: buildIdempotentKey(keygenMsg.WalletID, clientID, sessionID, mpc.TypeGenerateWalletResultFmt), }) if err != nil { logger.Error("Failed to enqueue keygen result event", err, diff --git a/pkg/eventconsumer/sign_consumer.go b/pkg/eventconsumer/sign_consumer.go index b202a337..cab1d4f7 100644 --- a/pkg/eventconsumer/sign_consumer.go +++ b/pkg/eventconsumer/sign_consumer.go @@ -152,11 +152,12 @@ func (sc *signingConsumer) handleSigningEvent(msg jetstream.Msg) { raw := msg.Data() var signingMsg types.SignTxMessage sessionID := msg.Headers().Get("SessionID") + clientID := msg.Headers().Get(event.ClientIDHeader) err := json.Unmarshal(raw, &signingMsg) if err != nil { logger.Error("SigningConsumer: Failed to unmarshal signing message", err) - sc.handleSigningError(signingMsg, event.ErrorCodeUnmarshalFailure, err, sessionID) + sc.handleSigningError(signingMsg, event.ErrorCodeUnmarshalFailure, err, sessionID, clientID) _ = msg.Ack() return } @@ -164,7 +165,7 @@ func (sc *signingConsumer) handleSigningEvent(msg jetstream.Msg) { if !sc.peerRegistry.AreMajorityReady() { requiredPeers := int64(sc.mpcThreshold + 1) err := fmt.Errorf("not enough peers to process signing request: ready=%d, required=%d", sc.peerRegistry.GetReadyPeersCount(), requiredPeers) - sc.handleSigningError(signingMsg, event.ErrorCodeNotMajority, err, sessionID) + sc.handleSigningError(signingMsg, event.ErrorCodeNotMajority, err, sessionID, clientID) _ = msg.Ack() return } @@ -188,6 +189,9 @@ func (sc *signingConsumer) handleSigningEvent(msg jetstream.Msg) { headers := map[string]string{ "SessionID": uuid.New().String(), } + if clientID != "" { + headers[event.ClientIDHeader] = clientID + } if err := sc.pubsub.PublishWithReply(MPCSignEvent, replyInbox, msg.Data(), headers); err != nil { logger.Error("SigningConsumer: Failed to publish signing event with reply", err) _ = msg.Nak() @@ -227,7 +231,7 @@ func (sc *signingConsumer) handleSigningEvent(msg jetstream.Msg) { _ = msg.Nak() } -func (sc *signingConsumer) handleSigningError(signMsg types.SignTxMessage, errorCode event.ErrorCode, err error, sessionID string) { +func (sc *signingConsumer) handleSigningError(signMsg types.SignTxMessage, errorCode event.ErrorCode, err error, sessionID, clientID string) { signingResult := event.SigningResultEvent{ ResultType: event.ResultTypeError, ErrorCode: errorCode, @@ -246,8 +250,8 @@ func (sc *signingConsumer) handleSigningError(signMsg types.SignTxMessage, error return } - err = sc.signingResultQueue.Enqueue(event.SigningResultCompleteTopic, signingResultBytes, &messaging.EnqueueOptions{ - IdempotententKey: buildIdempotentKey(signMsg.TxID, sessionID, mpc.TypeSigningResultFmt), + err = sc.signingResultQueue.Enqueue(event.SigningResultSubject(clientID), signingResultBytes, &messaging.EnqueueOptions{ + IdempotententKey: buildIdempotentKey(signMsg.TxID, clientID, sessionID, mpc.TypeSigningResultFmt), }) if err != nil { logger.Error("Failed to enqueue signing result event", err, @@ -269,12 +273,12 @@ func (sc *signingConsumer) Close() error { return nil } -func buildIdempotentKey(baseID string, sessionID string, formatTemplate string) string { +func buildIdempotentKey(baseID, clientID, sessionID, formatTemplate string) string { var uniqueKey string if sessionID != "" { uniqueKey = fmt.Sprintf("%s:%s", baseID, sessionID) } else { uniqueKey = baseID } - return fmt.Sprintf(formatTemplate, uniqueKey) + return fmt.Sprintf(formatTemplate, event.ScopedOperationID(clientID, uniqueKey)) } diff --git a/pkg/eventconsumer/timeout_consumer.go b/pkg/eventconsumer/timeout_consumer.go index bd911700..73569a53 100644 --- a/pkg/eventconsumer/timeout_consumer.go +++ b/pkg/eventconsumer/timeout_consumer.go @@ -2,10 +2,12 @@ package eventconsumer import ( "encoding/json" + "fmt" "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/mpc" "github.com/nats-io/nats.go" ) @@ -50,6 +52,7 @@ func (tc *timeOutConsumer) Run() { logger.Error("Failed to retrieve message", err) return } + clientID := failedMsg.Header.Get(event.ClientIDHeader) data := failedMsg.Data var signErrorResult event.SigningResultEvent @@ -71,8 +74,8 @@ func (tc *timeOutConsumer) Run() { return } - err = tc.resultQueue.Enqueue(event.SigningResultTopic, signErrorResultBytes, &messaging.EnqueueOptions{ - IdempotententKey: signErrorResult.TxID, + err = tc.resultQueue.Enqueue(event.SigningResultSubject(clientID), signErrorResultBytes, &messaging.EnqueueOptions{ + IdempotententKey: fmt.Sprintf(mpc.TypeSigningResultFmt, event.ScopedOperationID(clientID, signErrorResult.TxID)), }) if err != nil { logger.Error("Failed to publish signing result event", err) diff --git a/pkg/messaging/jetstream_broker.go b/pkg/messaging/jetstream_broker.go index 8f8a6069..ed0127ac 100644 --- a/pkg/messaging/jetstream_broker.go +++ b/pkg/messaging/jetstream_broker.go @@ -33,7 +33,7 @@ const ( ) type MessageBroker interface { - PublishMessage(ctx context.Context, subject string, data []byte) error + PublishMessage(ctx context.Context, subject string, data []byte, headers map[string]string) error CreateSubscription(ctx context.Context, consumerName, subject string, handler func(msg jetstream.Msg)) (MessageSubscription, error) GetStreamInfo(ctx context.Context) (*jetstream.StreamInfo, error) FetchMessages(ctx context.Context, consumerName, subject string, batchSize int, handler func(msg jetstream.Msg)) error @@ -211,12 +211,19 @@ func (b *jetStreamBroker) ensureStreamExists(ctx context.Context) error { return nil } -func (b *jetStreamBroker) PublishMessage(ctx context.Context, subject string, data []byte) error { +func (b *jetStreamBroker) PublishMessage(ctx context.Context, subject string, data []byte, headers map[string]string) error { if b.conn.IsClosed() { return ErrConnectionClosed } - _, err := b.js.Publish(ctx, subject, data) + msg := &nats.Msg{ + Subject: subject, + Data: data, + Header: nats.Header{}, + } + applyHeaders(msg.Header, headers) + + _, err := b.js.PublishMsg(ctx, msg) if err != nil { return fmt.Errorf("failed to publish message to subject %s: %w", subject, err) } diff --git a/pkg/messaging/message_queue.go b/pkg/messaging/message_queue.go index 6eec56eb..445f4704 100644 --- a/pkg/messaging/message_queue.go +++ b/pkg/messaging/message_queue.go @@ -73,25 +73,25 @@ func NewNATsMessageQueueManager(queueName string, subjectWildCards []string, nc } } -func (m *NATsMessageQueueManager) NewMessageQueue(consumerName string) MessageQueue { +func (m *NATsMessageQueueManager) NewMessageQueue(consumerName, filterSubject string) MessageQueue { + sanitizedConsumerName := sanitizeConsumerName(consumerName) mq := &msgQueue{ - consumerName: consumerName, + consumerName: sanitizedConsumerName, js: m.js, } - consumerWildCard := fmt.Sprintf("%s.%s.*", m.queueName, consumerName) cfg := jetstream.ConsumerConfig{ - Name: consumerName, - Durable: consumerName, + Name: sanitizedConsumerName, + Durable: sanitizedConsumerName, MaxAckPending: 1000, // If a message isn't acked within AckWait, it will be redelivered up to MaxDelive AckWait: 30 * time.Second, AckPolicy: jetstream.AckExplicitPolicy, FilterSubjects: []string{ - consumerWildCard, + filterSubject, }, MaxDeliver: 3, } - logger.Info("Creating consumer for subject", "consumerName", consumerName, "queueName", m.queueName, "filterSubject", consumerWildCard, "config", cfg) + logger.Info("Creating consumer for subject", "consumerName", sanitizedConsumerName, "queueName", m.queueName, "filterSubject", filterSubject, "config", cfg) consumer, err := m.js.CreateOrUpdateConsumer(context.Background(), m.queueName, cfg) if err != nil { logger.Fatal("Error creating JetStream consumer: ", err) diff --git a/pkg/messaging/pubsub.go b/pkg/messaging/pubsub.go index 8e4fd0ea..1b613541 100644 --- a/pkg/messaging/pubsub.go +++ b/pkg/messaging/pubsub.go @@ -10,7 +10,7 @@ type Subscription interface { } type PubSub interface { - Publish(topic string, message []byte) error + Publish(topic string, message []byte, headers map[string]string) error PublishWithReply(topic, reply string, data []byte, headers map[string]string) error Subscribe(topic string, handler func(msg *nats.Msg)) (Subscription, error) } @@ -31,9 +31,15 @@ func NewNATSPubSub(natsConn *nats.Conn) PubSub { return &natsPubSub{natsConn} } -func (n *natsPubSub) Publish(topic string, message []byte) error { +func (n *natsPubSub) Publish(topic string, message []byte, headers map[string]string) error { logger.Debug("[NATS] Publishing message", "topic", topic) - return n.natsConn.Publish(topic, message) + msg := &nats.Msg{ + Subject: topic, + Data: message, + Header: nats.Header{}, + } + applyHeaders(msg.Header, headers) + return n.natsConn.PublishMsg(msg) } func (n *natsPubSub) PublishWithReply(topic, reply string, data []byte, headers map[string]string) error { @@ -43,9 +49,7 @@ func (n *natsPubSub) PublishWithReply(topic, reply string, data []byte, headers Data: data, Header: nats.Header{}, } - for k, v := range headers { - msg.Header.Set(k, v) - } + applyHeaders(msg.Header, headers) err := n.natsConn.PublishMsg(msg) return err } @@ -61,3 +65,9 @@ func (n *natsPubSub) Subscribe(topic string, handler func(msg *nats.Msg)) (Subsc return &natsSubscription{subscription: sub}, nil } + +func applyHeaders(dst nats.Header, headers map[string]string) { + for k, v := range headers { + dst.Set(k, v) + } +} diff --git a/pkg/mpc/ecdsa_signing_session.go b/pkg/mpc/ecdsa_signing_session.go index 9664e088..5a555d57 100644 --- a/pkg/mpc/ecdsa_signing_session.go +++ b/pkg/mpc/ecdsa_signing_session.go @@ -51,6 +51,7 @@ func newECDSASigningSession( walletID string, txID string, networkInternalCode string, + resultTopic string, pubSub messaging.PubSub, direct messaging.DirectMessaging, participantPeerIDs []string, @@ -95,6 +96,7 @@ func newECDSASigningSession( }, getRoundFunc: GetEcdsaMsgRound, resultQueue: resultQueue, + resultTopic: resultTopic, identityStore: identityStore, idempotentKey: idempotentKey, }, @@ -145,7 +147,6 @@ func (s *ecdsaSigningSession) Init(tx *big.Int) error { if err != nil { return errors.Wrap(err, "Failed to unmarshal wallet data") } - if len(s.derivationPath) > 0 { il, extendedChildPk, errorDerivation := s.ckd.Derive(s.walletID, data.ECDSAPub, s.derivationPath, tss.S256()) @@ -215,7 +216,7 @@ func (s *ecdsaSigningSession) Sign(onSuccess func(data []byte)) { return } - err = s.resultQueue.Enqueue(event.SigningResultCompleteTopic, bytes, &messaging.EnqueueOptions{ + err = s.resultQueue.Enqueue(s.resultTopic, bytes, &messaging.EnqueueOptions{ IdempotententKey: s.idempotentKey, }) if err != nil { diff --git a/pkg/mpc/eddsa_signing_session.go b/pkg/mpc/eddsa_signing_session.go index 2c11fda4..8741c856 100644 --- a/pkg/mpc/eddsa_signing_session.go +++ b/pkg/mpc/eddsa_signing_session.go @@ -36,6 +36,7 @@ func newEDDSASigningSession( walletID string, txID string, networkInternalCode string, + resultTopic string, pubSub messaging.PubSub, direct messaging.DirectMessaging, participantPeerIDs []string, @@ -62,8 +63,8 @@ func newEDDSASigningSession( outCh: make(chan tss.Message), ErrCh: make(chan error, 1), doneCh: make(chan struct{}), - kvstore: kvstore, - keyinfoStore: keyinfoStore, + kvstore: kvstore, + keyinfoStore: keyinfoStore, topicComposer: &TopicComposer{ ComposeBroadcastTopic: func() string { return fmt.Sprintf("sign:eddsa:broadcast:%s:%s", walletID, txID) @@ -77,6 +78,7 @@ func newEDDSASigningSession( }, getRoundFunc: GetEddsaMsgRound, resultQueue: resultQueue, + resultTopic: resultTopic, identityStore: identityStore, idempotentKey: idempotentKey, }, @@ -126,7 +128,6 @@ func (s *eddsaSigningSession) Init(tx *big.Int) error { if err != nil { return errors.Wrap(err, "Failed to unmarshal wallet data") } - if len(s.derivationPath) > 0 { il, extendedChildPk, errorDerivation := s.ckd.Derive(s.walletID, data.EDDSAPub, s.derivationPath, tss.Edwards()) @@ -194,7 +195,7 @@ func (s *eddsaSigningSession) Sign(onSuccess func(data []byte)) { return } - err = s.resultQueue.Enqueue(event.SigningResultCompleteTopic, bytes, &messaging.EnqueueOptions{ + err = s.resultQueue.Enqueue(s.resultTopic, bytes, &messaging.EnqueueOptions{ IdempotententKey: s.idempotentKey, }) if err != nil { @@ -214,6 +215,7 @@ func (s *eddsaSigningSession) Sign(onSuccess func(data []byte)) { } } } + // Close cleans up the EDDSA signing session by zeroing all sensitive data. func (s *eddsaSigningSession) Close() error { if s == nil { diff --git a/pkg/mpc/key_exchange_session.go b/pkg/mpc/key_exchange_session.go index 8da1774b..d8065540 100644 --- a/pkg/mpc/key_exchange_session.go +++ b/pkg/mpc/key_exchange_session.go @@ -167,7 +167,7 @@ func (e *ecdhSession) BroadcastPublicKey() error { signedMsgBytes, _ := json.Marshal(msg) logger.Info("Starting to broadcast DH key", "nodeID", e.nodeID) - if err := e.pubSub.Publish(ECDHExchangeTopic, signedMsgBytes); err != nil { + if err := e.pubSub.Publish(ECDHExchangeTopic, signedMsgBytes, nil); err != nil { return fmt.Errorf("%s failed to publish DH message because %w", e.nodeID, err) } return nil diff --git a/pkg/mpc/node.go b/pkg/mpc/node.go index faa842a9..e28cf3f0 100644 --- a/pkg/mpc/node.go +++ b/pkg/mpc/node.go @@ -147,6 +147,7 @@ func (p *Node) CreateSigningSession( walletID string, txID string, networkInternalCode string, + resultTopic string, resultQueue messaging.MessageQueue, derivationPath []uint32, idempotentKey string, @@ -185,6 +186,7 @@ func (p *Node) CreateSigningSession( walletID, txID, networkInternalCode, + resultTopic, p.pubSub, p.direct, readyParticipantIDs, @@ -206,6 +208,7 @@ func (p *Node) CreateSigningSession( walletID, txID, networkInternalCode, + resultTopic, p.pubSub, p.direct, readyParticipantIDs, diff --git a/pkg/mpc/session.go b/pkg/mpc/session.go index c75e05b6..1510848b 100644 --- a/pkg/mpc/session.go +++ b/pkg/mpc/session.go @@ -80,6 +80,7 @@ type session struct { barrierSub messaging.Subscription resultQueue messaging.MessageQueue + resultTopic string identityStore identity.Store topicComposer *TopicComposer @@ -167,7 +168,7 @@ func (s *session) handleTssMessage(keyshare tss.Message) { return } - err = s.pubSub.Publish(s.topicComposer.ComposeBroadcastTopic(), msg) + err = s.pubSub.Publish(s.topicComposer.ComposeBroadcastTopic(), msg, nil) if err != nil { s.sendErr(err) return From 5be6479361997f26d9535fe608aaf908c9b39f65 Mon Sep 17 00:00:00 2001 From: vietddude Date: Tue, 24 Mar 2026 14:44:35 +0700 Subject: [PATCH 2/9] chore(examples): accept client-id for scoped routing --- e2e/base_test.go | 2 +- examples/authorizers/generate/main.go | 2 +- examples/authorizers/sign/main.go | 6 +++++- examples/ckd/main.go | 6 +++++- examples/generate/kms/main.go | 3 ++- examples/generate/main.go | 3 ++- examples/hdwallet/ecdsa/main.go | 6 +++++- examples/hdwallet/eddsa/main.go | 6 +++++- examples/reshare/main.go | 6 +++++- examples/sign/main.go | 10 +++++++--- 10 files changed, 38 insertions(+), 12 deletions(-) diff --git a/e2e/base_test.go b/e2e/base_test.go index 1ad62347..bcc010fb 100644 --- a/e2e/base_test.go +++ b/e2e/base_test.go @@ -196,7 +196,7 @@ func (s *E2ETestSuite) SetupMPCClient(t *testing.T) { mpcClient := client.NewMPCClient(client.Options{ NatsConn: s.natsConn, Signer: localSigner, - }) + }, client.WithClientID("e2e-suite")) s.mpcClient = mpcClient t.Log("MPC client created") } diff --git a/examples/authorizers/generate/main.go b/examples/authorizers/generate/main.go index 0cfec028..08da5e01 100644 --- a/examples/authorizers/generate/main.go +++ b/examples/authorizers/generate/main.go @@ -89,7 +89,7 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: localSigner, - }) + }, client.WithClientID(*clientID)) var walletStartTimes sync.Map var walletIDs []string diff --git a/examples/authorizers/sign/main.go b/examples/authorizers/sign/main.go index 6f5d3ef5..a5e14eed 100644 --- a/examples/authorizers/sign/main.go +++ b/examples/authorizers/sign/main.go @@ -2,6 +2,7 @@ package main import ( "encoding/hex" + "flag" "fmt" "os" "os/signal" @@ -23,6 +24,9 @@ var requiredAuthorizers = []string{"authorizer1", "authorizer2"} func main() { const environment = "dev" + clientID := flag.String("client-id", "example-authorizers-sign", "Client ID used to scope result routing") + flag.Parse() + config.InitViperConfig("") logger.Init(environment, true) @@ -80,7 +84,7 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: localSigner, - }) + }, client.WithClientID(*clientID)) // Create a signing request with authorizers txID := uuid.New().String() diff --git a/examples/ckd/main.go b/examples/ckd/main.go index 29c193fc..621f19b4 100644 --- a/examples/ckd/main.go +++ b/examples/ckd/main.go @@ -1,6 +1,7 @@ package main import ( + "flag" "fmt" "os" "os/signal" @@ -19,6 +20,9 @@ import ( func main() { const environment = "dev" + clientID := flag.String("client-id", "example-ckd", "Client ID used to scope result routing") + flag.Parse() + config.InitViperConfig("") logger.Init(environment, true) @@ -63,7 +67,7 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: localSigner, - }) + }, client.WithClientID(*clientID)) // 2) Once wallet exists, immediately fire a SignTransaction txID := uuid.New().String() diff --git a/examples/generate/kms/main.go b/examples/generate/kms/main.go index 9bf40250..91cc4358 100644 --- a/examples/generate/kms/main.go +++ b/examples/generate/kms/main.go @@ -27,6 +27,7 @@ func main() { const kmsKeyID = "48e76117-fd08-4dc0-bd10-b1c7d01de748" numWallets := flag.Int("n", 1, "Number of wallets to generate") + clientID := flag.String("client-id", "example-generate-kms", "Client ID used to scope result routing") flag.Parse() @@ -65,7 +66,7 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: kmsSigner, - }) + }, client.WithClientID(*clientID)) var walletStartTimes sync.Map var walletIDs []string diff --git a/examples/generate/main.go b/examples/generate/main.go index 3f0135f1..7209ebfc 100644 --- a/examples/generate/main.go +++ b/examples/generate/main.go @@ -25,6 +25,7 @@ import ( func main() { const environment = "development" numWallets := flag.Int("n", 1, "Number of wallets to generate") + clientID := flag.String("client-id", "example-generate", "Client ID used to scope result routing") flag.Parse() @@ -72,7 +73,7 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: localSigner, - }) + }, client.WithClientID(*clientID)) var walletStartTimes sync.Map var walletIDs []string diff --git a/examples/hdwallet/ecdsa/main.go b/examples/hdwallet/ecdsa/main.go index f73774f9..2e3365f6 100644 --- a/examples/hdwallet/ecdsa/main.go +++ b/examples/hdwallet/ecdsa/main.go @@ -3,6 +3,7 @@ package main import ( "crypto/ecdsa" "encoding/hex" + "flag" "fmt" "math/big" "os" @@ -49,6 +50,9 @@ func main() { fmt.Println() const environment = "dev" + clientID := flag.String("client-id", "example-hdwallet-ecdsa", "Client ID used to scope result routing") + flag.Parse() + config.InitViperConfig("") logger.Init(environment, true) @@ -93,7 +97,7 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: localSigner, - }) + }, client.WithClientID(*clientID)) // Step 1: Generate ONE master wallet fmt.Println("Step 1: Generating master MPC wallet...") diff --git a/examples/hdwallet/eddsa/main.go b/examples/hdwallet/eddsa/main.go index 2654cf75..7950a297 100644 --- a/examples/hdwallet/eddsa/main.go +++ b/examples/hdwallet/eddsa/main.go @@ -2,6 +2,7 @@ package main import ( "encoding/hex" + "flag" "fmt" "os" "os/signal" @@ -52,6 +53,9 @@ func main() { fmt.Println() const environment = "dev" + clientID := flag.String("client-id", "example-hdwallet-eddsa", "Client ID used to scope result routing") + flag.Parse() + config.InitViperConfig("") logger.Init(environment, true) @@ -96,7 +100,7 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: localSigner, - }) + }, client.WithClientID(*clientID)) // Step 1: Generate ONE master wallet fmt.Println("Step 1: Generating master MPC wallet...") diff --git a/examples/reshare/main.go b/examples/reshare/main.go index 47c4d858..9d3a34bb 100644 --- a/examples/reshare/main.go +++ b/examples/reshare/main.go @@ -1,6 +1,7 @@ package main import ( + "flag" "fmt" "os" "os/signal" @@ -19,6 +20,9 @@ import ( func main() { const environment = "dev" + clientID := flag.String("client-id", "example-reshare", "Client ID used to scope result routing") + flag.Parse() + config.InitViperConfig("") logger.Init(environment, true) @@ -63,7 +67,7 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: localSigner, - }) + }, client.WithClientID(*clientID)) // 3) Listen for signing results err = mpcClient.OnResharingResult(func(evt event.ResharingResultEvent) { diff --git a/examples/sign/main.go b/examples/sign/main.go index 3424610f..a8a39ced 100644 --- a/examples/sign/main.go +++ b/examples/sign/main.go @@ -1,6 +1,7 @@ package main import ( + "flag" "fmt" "os" "os/signal" @@ -19,6 +20,9 @@ import ( func main() { const environment = "dev" + clientID := flag.String("client-id", "example-sign", "Client ID used to scope result routing") + flag.Parse() + config.InitViperConfig("") logger.Init(environment, true) @@ -63,15 +67,15 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: localSigner, - }) + }, client.WithClientID(*clientID)) // 2) Once wallet exists, immediately fire a SignTransaction txID := uuid.New().String() dummyTx := []byte("deadbeef") // replace with real transaction bytes txMsg := &types.SignTxMessage{ - KeyType: types.KeyTypeEd25519, - WalletID: "ad24f678-b04b-4149-bcf6-bf9c90df8e63", // Use the generated wallet ID + KeyType: types.KeyTypeSecp256k1, + WalletID: "b8a32a42-b5ea-4c80-a489-d2ec9e873cdf", // Use the generated wallet ID NetworkInternalCode: "solana-devnet", TxID: txID, Tx: dummyTx, From a5e406486c4933b2f2b5ff8c124d3ee7bd4a1080 Mon Sep 17 00:00:00 2001 From: vietddude Date: Tue, 24 Mar 2026 14:44:45 +0700 Subject: [PATCH 3/9] test(scripts): add client routing checker --- scripts/check-client-routing/main.go | 347 +++++++++++++++++++++++++++ 1 file changed, 347 insertions(+) create mode 100644 scripts/check-client-routing/main.go diff --git a/scripts/check-client-routing/main.go b/scripts/check-client-routing/main.go new file mode 100644 index 00000000..eb691ae7 --- /dev/null +++ b/scripts/check-client-routing/main.go @@ -0,0 +1,347 @@ +package main + +import ( + "flag" + "fmt" + "os" + "slices" + "sync" + "time" + + "github.com/fystack/mpcium/pkg/client" + "github.com/fystack/mpcium/pkg/config" + "github.com/fystack/mpcium/pkg/event" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/types" + "github.com/google/uuid" + "github.com/nats-io/nats.go" + "github.com/spf13/viper" +) + +type clientStats struct { + name string + clientID string + requested map[string]struct{} + received map[string]event.KeygenResultEvent + misrouted map[string]event.KeygenResultEvent + untracked map[string]event.KeygenResultEvent +} + +type routingState struct { + mu sync.Mutex + clients map[string]*clientStats + totalWanted int + totalResults int + doneCh chan struct{} + doneOnce sync.Once +} + +func main() { + if err := run(); err != nil { + fmt.Fprintf(os.Stderr, "client routing check failed: %v\n", err) + os.Exit(1) + } +} + +func run() error { + clientAID := flag.String("client-a-id", "svc-a", "Client ID for client A") + clientBID := flag.String("client-b-id", "svc-b", "Client ID for client B") + keyPath := flag.String("key-path", "./event_initiator.key", "Path to the event initiator private key") + natsURLFlag := flag.String("nats-url", "", "NATS URL override (defaults to config nats.url)") + algorithmFlag := flag.String("algorithm", "", "Initiator signing algorithm override (ed25519 or p256)") + walletsPerClient := flag.Int("wallets-per-client", 3, "Number of wallet creation requests per client") + timeout := flag.Duration("timeout", 90*time.Second, "Max time to wait for all results") + listenerWarmup := flag.Duration("listener-warmup", 3*time.Second, "Delay after listener setup before sending requests") + legacyMode := flag.Bool("legacy", false, "Create both clients without client IDs to reproduce the old shared-queue behavior") + flag.Parse() + + if *walletsPerClient <= 0 { + return fmt.Errorf("wallets-per-client must be > 0") + } + + config.InitViperConfig("") + logger.Init("dev", true) + + algorithm := *algorithmFlag + if algorithm == "" { + algorithm = viper.GetString("event_initiator_algorithm") + } + if algorithm == "" { + algorithm = string(types.EventInitiatorKeyTypeEd25519) + } + if !slices.Contains( + []string{ + string(types.EventInitiatorKeyTypeEd25519), + string(types.EventInitiatorKeyTypeP256), + }, + algorithm, + ) { + return fmt.Errorf( + "invalid algorithm %q: must be %s or %s", + algorithm, + types.EventInitiatorKeyTypeEd25519, + types.EventInitiatorKeyTypeP256, + ) + } + + natsURL := *natsURLFlag + if natsURL == "" { + natsURL = viper.GetString("nats.url") + } + if natsURL == "" { + return fmt.Errorf("nats url is required") + } + + natsConn, err := nats.Connect(natsURL) + if err != nil { + return fmt.Errorf("connect nats: %w", err) + } + defer natsConn.Drain() + defer natsConn.Close() + + signer, err := client.NewLocalSigner(types.EventInitiatorKeyType(algorithm), client.LocalSignerOptions{ + KeyPath: *keyPath, + }) + if err != nil { + return fmt.Errorf("create local signer: %w", err) + } + + clientA := newMPCClient(natsConn, signer, *clientAID, *legacyMode) + clientB := newMPCClient(natsConn, signer, *clientBID, *legacyMode) + + effectiveClientAID := *clientAID + effectiveClientBID := *clientBID + if *legacyMode { + effectiveClientAID = "" + effectiveClientBID = "" + } + + state := &routingState{ + clients: map[string]*clientStats{ + "A": { + name: "A", + clientID: effectiveClientAID, + requested: make(map[string]struct{}), + received: make(map[string]event.KeygenResultEvent), + misrouted: make(map[string]event.KeygenResultEvent), + untracked: make(map[string]event.KeygenResultEvent), + }, + "B": { + name: "B", + clientID: effectiveClientBID, + requested: make(map[string]struct{}), + received: make(map[string]event.KeygenResultEvent), + misrouted: make(map[string]event.KeygenResultEvent), + untracked: make(map[string]event.KeygenResultEvent), + }, + }, + totalWanted: *walletsPerClient * 2, + doneCh: make(chan struct{}), + } + + if err := clientA.OnWalletCreationResult(func(result event.KeygenResultEvent) { + state.record("A", result) + }); err != nil { + return fmt.Errorf("subscribe client A: %w", err) + } + if err := clientB.OnWalletCreationResult(func(result event.KeygenResultEvent) { + state.record("B", result) + }); err != nil { + return fmt.Errorf("subscribe client B: %w", err) + } + + fmt.Printf("listeners ready, waiting %s before publishing requests\n", listenerWarmup.String()) + time.Sleep(*listenerWarmup) + + requestsA := make([]string, 0, *walletsPerClient) + requestsB := make([]string, 0, *walletsPerClient) + + for i := 0; i < *walletsPerClient; i++ { + walletID := "route-a-" + uuid.NewString() + state.clients["A"].requested[walletID] = struct{}{} + requestsA = append(requestsA, walletID) + } + for i := 0; i < *walletsPerClient; i++ { + walletID := "route-b-" + uuid.NewString() + state.clients["B"].requested[walletID] = struct{}{} + requestsB = append(requestsB, walletID) + } + + fmt.Printf("mode=%s clientA=%q clientB=%q wallets-per-client=%d\n", + modeName(*legacyMode), effectiveClientAID, effectiveClientBID, *walletsPerClient) + fmt.Printf("client A requested wallets: %v\n", requestsA) + fmt.Printf("client B requested wallets: %v\n", requestsB) + + var publishWG sync.WaitGroup + publishWG.Add(2) + go func() { + defer publishWG.Done() + for _, walletID := range requestsA { + if err := clientA.CreateWallet(walletID); err != nil { + logger.Error("Client A create wallet failed", err, "walletID", walletID) + } + } + }() + go func() { + defer publishWG.Done() + for _, walletID := range requestsB { + if err := clientB.CreateWallet(walletID); err != nil { + logger.Error("Client B create wallet failed", err, "walletID", walletID) + } + } + }() + publishWG.Wait() + + select { + case <-state.doneCh: + case <-time.After(*timeout): + fmt.Printf("timed out after %s waiting for results\n", timeout.String()) + } + + printSummary(state) + + if err := state.validate(); err != nil { + return err + } + + fmt.Println("routing check passed: no client received another client's result") + return nil +} + +func newMPCClient(natsConn *nats.Conn, signer client.Signer, clientID string, legacy bool) client.MPCClient { + opts := client.Options{ + NatsConn: natsConn, + Signer: signer, + } + if legacy { + return client.NewMPCClient(opts) + } + return client.NewMPCClient(opts, client.WithClientID(clientID)) +} + +func (s *routingState) record(clientName string, result event.KeygenResultEvent) { + s.mu.Lock() + defer s.mu.Unlock() + + stats := s.clients[clientName] + if _, exists := stats.received[result.WalletID]; exists { + return + } + + stats.received[result.WalletID] = result + s.totalResults++ + + if _, ok := stats.requested[result.WalletID]; ok { + if s.totalResults >= s.totalWanted { + s.doneOnce.Do(func() { + close(s.doneCh) + }) + } + return + } + + if otherName := otherClientName(clientName); otherName != "" { + if _, ok := s.clients[otherName].requested[result.WalletID]; ok { + stats.misrouted[result.WalletID] = result + } else { + stats.untracked[result.WalletID] = result + } + } + + if s.totalResults >= s.totalWanted { + s.doneOnce.Do(func() { + close(s.doneCh) + }) + } +} + +func (s *routingState) validate() error { + s.mu.Lock() + defer s.mu.Unlock() + + var reasons []string + for _, name := range []string{"A", "B"} { + stats := s.clients[name] + if len(stats.misrouted) > 0 { + reasons = append(reasons, fmt.Sprintf("client %s received %d misrouted result(s)", name, len(stats.misrouted))) + } + if len(stats.untracked) > 0 { + reasons = append(reasons, fmt.Sprintf("client %s received %d unexpected result(s)", name, len(stats.untracked))) + } + if missing := missingWallets(stats); len(missing) > 0 { + reasons = append(reasons, fmt.Sprintf("client %s is missing %d expected result(s): %v", name, len(missing), missing)) + } + } + + if len(reasons) == 0 { + return nil + } + return fmt.Errorf("%v", reasons) +} + +func printSummary(state *routingState) { + state.mu.Lock() + defer state.mu.Unlock() + + fmt.Println("---- routing summary ----") + for _, name := range []string{"A", "B"} { + stats := state.clients[name] + fmt.Printf("client %s (clientID=%q): requested=%d received=%d misrouted=%d unexpected=%d missing=%d\n", + stats.name, + stats.clientID, + len(stats.requested), + len(stats.received), + len(stats.misrouted), + len(stats.untracked), + len(missingWallets(stats)), + ) + if len(stats.misrouted) > 0 { + fmt.Printf(" misrouted wallets: %v\n", sortedEventKeys(stats.misrouted)) + } + if len(stats.untracked) > 0 { + fmt.Printf(" unexpected wallets: %v\n", sortedEventKeys(stats.untracked)) + } + if missing := missingWallets(stats); len(missing) > 0 { + fmt.Printf(" missing wallets: %v\n", missing) + } + } + fmt.Println("-------------------------") +} + +func missingWallets(stats *clientStats) []string { + missing := make([]string, 0) + for walletID := range stats.requested { + if _, ok := stats.received[walletID]; !ok { + missing = append(missing, walletID) + } + } + slices.Sort(missing) + return missing +} + +func sortedEventKeys(events map[string]event.KeygenResultEvent) []string { + keys := make([]string, 0, len(events)) + for walletID := range events { + keys = append(keys, walletID) + } + slices.Sort(keys) + return keys +} + +func otherClientName(name string) string { + switch name { + case "A": + return "B" + case "B": + return "A" + default: + return "" + } +} + +func modeName(legacy bool) string { + if legacy { + return "legacy" + } + return "scoped" +} From d195ba1609a87e808212957d47b63be8fe1f97b0 Mon Sep 17 00:00:00 2001 From: vietddude Date: Tue, 24 Mar 2026 15:04:54 +0700 Subject: [PATCH 4/9] fix(examples): add client-id flag to authorizers generate --- examples/authorizers/generate/main.go | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/authorizers/generate/main.go b/examples/authorizers/generate/main.go index 08da5e01..21109066 100644 --- a/examples/authorizers/generate/main.go +++ b/examples/authorizers/generate/main.go @@ -29,6 +29,7 @@ var requiredAuthorizers = []string{"authorizer1", "authorizer2"} func main() { const environment = "development" numWallets := flag.Int("n", 1, "Number of wallets to generate") + clientID := flag.String("client-id", "example-authorizers-generate", "Client ID used to scope result routing") flag.Parse() From ce5d4db3abd11bad2284286f7bb02da5f7737d74 Mon Sep 17 00:00:00 2001 From: vietddude Date: Tue, 24 Mar 2026 15:05:45 +0700 Subject: [PATCH 5/9] test(e2e): cover multi-client result routing --- e2e/go.mod | 2 +- e2e/multi_client_routing_test.go | 410 +++++++++++++++++++++++++++++++ 2 files changed, 411 insertions(+), 1 deletion(-) create mode 100644 e2e/multi_client_routing_test.go diff --git a/e2e/go.mod b/e2e/go.mod index ac74a6d1..75856958 100644 --- a/e2e/go.mod +++ b/e2e/go.mod @@ -1,6 +1,6 @@ module github.com/fystack/mpcium/e2e -go 1.25.5 +go 1.25.8 require ( github.com/bnb-chain/tss-lib/v2 v2.0.2 diff --git a/e2e/multi_client_routing_test.go b/e2e/multi_client_routing_test.go new file mode 100644 index 00000000..09356858 --- /dev/null +++ b/e2e/multi_client_routing_test.go @@ -0,0 +1,410 @@ +package e2e + +import ( + "fmt" + "path/filepath" + "slices" + "strings" + "sync" + "testing" + "time" + + "github.com/fystack/mpcium/pkg/client" + "github.com/fystack/mpcium/pkg/event" + "github.com/fystack/mpcium/pkg/types" + "github.com/google/uuid" + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const listenerSetupDelay = 5 * time.Second + +type multiClientObserver struct { + name string + + mu sync.Mutex + + expectedWallets map[string]struct{} + keygenResults map[string]event.KeygenResultEvent + unexpectedWallet map[string]event.KeygenResultEvent + + expectedTxs map[string]struct{} + signResults map[string]event.SigningResultEvent + unexpectedTx map[string]event.SigningResultEvent +} + +func TestMultiClientResultRouting(t *testing.T) { + suite := NewE2ETestSuite(".") + + t.Log("Performing pre-test cleanup...") + suite.CleanupTestEnvironment(t) + + defer func() { + t.Log("Performing post-test cleanup...") + suite.Cleanup(t) + }() + + t.Run("Setup", func(t *testing.T) { + t.Log("Running make clean to ensure clean build...") + err := suite.RunMakeClean() + require.NoError(t, err, "Failed to run make clean") + + suite.SetupInfrastructure(t) + suite.SetupTestNodes(t) + + err = suite.LoadConfig() + require.NoError(t, err, "Failed to load config after setup") + + suite.RegisterPeers(t) + suite.SeedPreParams(t) + suite.StartNodes(t) + suite.WaitForNodesReady(t) + }) + + t.Run("ScopedKeygenAndSigning", func(t *testing.T) { + testScopedKeygenAndSigningRouting(t, suite) + }) +} + +func testScopedKeygenAndSigningRouting(t *testing.T, suite *E2ETestSuite) { + clientA, connA := newScopedMPCClient(t, suite, "svc-a") + defer connA.Close() + + clientB, connB := newScopedMPCClient(t, suite, "svc-b") + defer connB.Close() + + observerA := newMultiClientObserver("A") + observerB := newMultiClientObserver("B") + + walletA := "route-a-" + uuid.NewString() + walletB := "route-b-" + uuid.NewString() + observerA.expectWallet(walletA) + observerB.expectWallet(walletB) + + require.NoError(t, clientA.OnWalletCreationResult(observerA.recordKeygen), "Failed to subscribe client A keygen results") + require.NoError(t, clientB.OnWalletCreationResult(observerB.recordKeygen), "Failed to subscribe client B keygen results") + require.NoError(t, clientA.OnSignResult(observerA.recordSigning), "Failed to subscribe client A signing results") + require.NoError(t, clientB.OnSignResult(observerB.recordSigning), "Failed to subscribe client B signing results") + + time.Sleep(listenerSetupDelay) + + var createWG sync.WaitGroup + createErrCh := make(chan error, 2) + createWG.Add(2) + go func() { + defer createWG.Done() + createErrCh <- clientA.CreateWallet(walletA) + }() + go func() { + defer createWG.Done() + createErrCh <- clientB.CreateWallet(walletB) + }() + createWG.Wait() + close(createErrCh) + for err := range createErrCh { + require.NoError(t, err, "Scoped client failed to create wallet") + } + + waitForKeygenRouting(t, observerA, observerB) + + observerA.assertNoUnexpectedKeygen(t) + observerB.assertNoUnexpectedKeygen(t) + observerA.assertKeygenSuccess(t, walletA) + observerB.assertKeygenSuccess(t, walletB) + + txA := uuid.NewString() + txB := uuid.NewString() + observerA.expectTx(txA) + observerB.expectTx(txB) + + var signWG sync.WaitGroup + signErrCh := make(chan error, 2) + signWG.Add(2) + go func() { + defer signWG.Done() + signErrCh <- clientA.SignTransaction(&types.SignTxMessage{ + WalletID: walletA, + TxID: txA, + Tx: []byte("route-a-signing-payload"), + KeyType: types.KeyTypeEd25519, + NetworkInternalCode: "test", + }) + }() + go func() { + defer signWG.Done() + signErrCh <- clientB.SignTransaction(&types.SignTxMessage{ + WalletID: walletB, + TxID: txB, + Tx: []byte("route-b-signing-payload"), + KeyType: types.KeyTypeEd25519, + NetworkInternalCode: "test", + }) + }() + signWG.Wait() + close(signErrCh) + for err := range signErrCh { + require.NoError(t, err, "Scoped client failed to sign transaction") + } + + waitForSigningRouting(t, observerA, observerB) + + observerA.assertNoUnexpectedSigning(t) + observerB.assertNoUnexpectedSigning(t) + observerA.assertSigningSuccess(t, txA) + observerB.assertSigningSuccess(t, txB) +} + +func newScopedMPCClient(t *testing.T, suite *E2ETestSuite, clientID string) (client.MPCClient, *nats.Conn) { + t.Helper() + + keyPath := filepath.Join(suite.testDir, "test_event_initiator.key") + signer, err := client.NewLocalSigner(types.EventInitiatorKeyTypeEd25519, client.LocalSignerOptions{ + KeyPath: keyPath, + }) + require.NoError(t, err, "Failed to create signer for client %s", clientID) + + natsConn, err := nats.Connect(suite.natsConn.ConnectedUrl()) + require.NoError(t, err, "Failed to connect scoped client %s to NATS", clientID) + + mpcClient := client.NewMPCClient(client.Options{ + NatsConn: natsConn, + Signer: signer, + }, client.WithClientID(clientID)) + + return mpcClient, natsConn +} + +func newMultiClientObserver(name string) *multiClientObserver { + return &multiClientObserver{ + name: name, + expectedWallets: make(map[string]struct{}), + keygenResults: make(map[string]event.KeygenResultEvent), + unexpectedWallet: make(map[string]event.KeygenResultEvent), + expectedTxs: make(map[string]struct{}), + signResults: make(map[string]event.SigningResultEvent), + unexpectedTx: make(map[string]event.SigningResultEvent), + } +} + +func (o *multiClientObserver) expectWallet(walletID string) { + o.mu.Lock() + defer o.mu.Unlock() + o.expectedWallets[walletID] = struct{}{} +} + +func (o *multiClientObserver) expectTx(txID string) { + o.mu.Lock() + defer o.mu.Unlock() + o.expectedTxs[txID] = struct{}{} +} + +func (o *multiClientObserver) recordKeygen(result event.KeygenResultEvent) { + o.mu.Lock() + defer o.mu.Unlock() + + if _, ok := o.expectedWallets[result.WalletID]; ok { + if _, exists := o.keygenResults[result.WalletID]; !exists { + o.keygenResults[result.WalletID] = result + } + return + } + + o.unexpectedWallet[result.WalletID] = result +} + +func (o *multiClientObserver) recordSigning(result event.SigningResultEvent) { + o.mu.Lock() + defer o.mu.Unlock() + + if _, ok := o.expectedTxs[result.TxID]; ok { + if _, exists := o.signResults[result.TxID]; !exists { + o.signResults[result.TxID] = result + } + return + } + + o.unexpectedTx[result.TxID] = result +} + +func (o *multiClientObserver) keygenComplete() bool { + o.mu.Lock() + defer o.mu.Unlock() + return len(o.keygenResults) == len(o.expectedWallets) +} + +func (o *multiClientObserver) signingComplete() bool { + o.mu.Lock() + defer o.mu.Unlock() + return len(o.signResults) == len(o.expectedTxs) +} + +func (o *multiClientObserver) hasUnexpectedKeygen() bool { + o.mu.Lock() + defer o.mu.Unlock() + return len(o.unexpectedWallet) > 0 +} + +func (o *multiClientObserver) hasUnexpectedSigning() bool { + o.mu.Lock() + defer o.mu.Unlock() + return len(o.unexpectedTx) > 0 +} + +func (o *multiClientObserver) assertNoUnexpectedKeygen(t *testing.T) { + t.Helper() + + o.mu.Lock() + defer o.mu.Unlock() + assert.Empty(t, sortedKeygenKeys(o.unexpectedWallet), "client %s received unexpected keygen results", o.name) +} + +func (o *multiClientObserver) assertNoUnexpectedSigning(t *testing.T) { + t.Helper() + + o.mu.Lock() + defer o.mu.Unlock() + assert.Empty(t, sortedSigningKeys(o.unexpectedTx), "client %s received unexpected signing results", o.name) +} + +func (o *multiClientObserver) assertKeygenSuccess(t *testing.T, walletID string) { + t.Helper() + + o.mu.Lock() + defer o.mu.Unlock() + + result, ok := o.keygenResults[walletID] + require.True(t, ok, "client %s missing keygen result for wallet %s", o.name, walletID) + assert.Equal(t, event.ResultTypeSuccess, result.ResultType, "client %s keygen result for wallet %s should succeed", o.name, walletID) + assert.NotEmpty(t, result.ECDSAPubKey, "client %s ECDSA pubkey should not be empty", o.name) + assert.NotEmpty(t, result.EDDSAPubKey, "client %s EdDSA pubkey should not be empty", o.name) +} + +func (o *multiClientObserver) assertSigningSuccess(t *testing.T, txID string) { + t.Helper() + + o.mu.Lock() + defer o.mu.Unlock() + + result, ok := o.signResults[txID] + require.True(t, ok, "client %s missing signing result for tx %s", o.name, txID) + assert.Equal(t, event.ResultTypeSuccess, result.ResultType, "client %s signing result for tx %s should succeed", o.name, txID) + assert.NotEmpty(t, result.Signature, "client %s signature should not be empty", o.name) +} + +func (o *multiClientObserver) keygenSnapshot() string { + o.mu.Lock() + defer o.mu.Unlock() + + return fmt.Sprintf( + "client %s expected wallets=%v received wallets=%v unexpected wallets=%v", + o.name, + sortedStringSet(o.expectedWallets), + sortedKeygenKeys(o.keygenResults), + sortedKeygenKeys(o.unexpectedWallet), + ) +} + +func (o *multiClientObserver) signingSnapshot() string { + o.mu.Lock() + defer o.mu.Unlock() + + return fmt.Sprintf( + "client %s expected txs=%v received txs=%v unexpected txs=%v", + o.name, + sortedStringSet(o.expectedTxs), + sortedSigningKeys(o.signResults), + sortedSigningKeys(o.unexpectedTx), + ) +} + +func waitForKeygenRouting(t *testing.T, observers ...*multiClientObserver) { + t.Helper() + + waitForPhase( + t, + keygenTimeout, + func(o *multiClientObserver) bool { return o.keygenComplete() }, + func(o *multiClientObserver) bool { return o.hasUnexpectedKeygen() }, + func(o *multiClientObserver) string { return o.keygenSnapshot() }, + observers..., + ) +} + +func waitForSigningRouting(t *testing.T, observers ...*multiClientObserver) { + t.Helper() + + waitForPhase( + t, + signingTimeout, + func(o *multiClientObserver) bool { return o.signingComplete() }, + func(o *multiClientObserver) bool { return o.hasUnexpectedSigning() }, + func(o *multiClientObserver) string { return o.signingSnapshot() }, + observers..., + ) +} + +func waitForPhase( + t *testing.T, + timeout time.Duration, + isComplete func(*multiClientObserver) bool, + hasUnexpected func(*multiClientObserver) bool, + snapshot func(*multiClientObserver) string, + observers ...*multiClientObserver, +) { + t.Helper() + + deadline := time.Now().Add(timeout) + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + allComplete := true + for _, observer := range observers { + if hasUnexpected(observer) { + t.Fatalf("unexpected routed result detected: %s", snapshot(observer)) + } + if !isComplete(observer) { + allComplete = false + } + } + if allComplete { + return + } + if time.Now().After(deadline) { + snapshots := make([]string, 0, len(observers)) + for _, observer := range observers { + snapshots = append(snapshots, snapshot(observer)) + } + t.Fatalf("timed out waiting for routed results: %s", strings.Join(snapshots, " | ")) + } + <-ticker.C + } +} + +func sortedStringSet(values map[string]struct{}) []string { + keys := make([]string, 0, len(values)) + for value := range values { + keys = append(keys, value) + } + slices.Sort(keys) + return keys +} + +func sortedKeygenKeys(values map[string]event.KeygenResultEvent) []string { + keys := make([]string, 0, len(values)) + for key := range values { + keys = append(keys, key) + } + slices.Sort(keys) + return keys +} + +func sortedSigningKeys(values map[string]event.SigningResultEvent) []string { + keys := make([]string, 0, len(values)) + for key := range values { + keys = append(keys, key) + } + slices.Sort(keys) + return keys +} From 94b4d38d02d39273018688d9808ac70d3e64286d Mon Sep 17 00:00:00 2001 From: vietddude Date: Tue, 24 Mar 2026 18:44:41 +0700 Subject: [PATCH 6/9] refactor(client): update MPCClient initialization to use ClientID directly in options --- e2e/base_test.go | 3 +- e2e/multi_client_routing_test.go | 3 +- examples/authorizers/generate/main.go | 3 +- examples/authorizers/sign/main.go | 3 +- examples/ckd/main.go | 3 +- examples/generate/kms/main.go | 3 +- examples/generate/main.go | 3 +- examples/hdwallet/ecdsa/main.go | 3 +- examples/hdwallet/eddsa/main.go | 3 +- examples/reshare/main.go | 3 +- examples/sign/main.go | 3 +- pkg/client/client.go | 26 +- scripts/check-client-routing/main.go | 347 -------------------------- 13 files changed, 28 insertions(+), 378 deletions(-) delete mode 100644 scripts/check-client-routing/main.go diff --git a/e2e/base_test.go b/e2e/base_test.go index bcc010fb..194b5f1b 100644 --- a/e2e/base_test.go +++ b/e2e/base_test.go @@ -196,7 +196,8 @@ func (s *E2ETestSuite) SetupMPCClient(t *testing.T) { mpcClient := client.NewMPCClient(client.Options{ NatsConn: s.natsConn, Signer: localSigner, - }, client.WithClientID("e2e-suite")) + ClientID: "e2e-suite", + }) s.mpcClient = mpcClient t.Log("MPC client created") } diff --git a/e2e/multi_client_routing_test.go b/e2e/multi_client_routing_test.go index 09356858..d28d2e7d 100644 --- a/e2e/multi_client_routing_test.go +++ b/e2e/multi_client_routing_test.go @@ -170,7 +170,8 @@ func newScopedMPCClient(t *testing.T, suite *E2ETestSuite, clientID string) (cli mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: signer, - }, client.WithClientID(clientID)) + ClientID: clientID, + }) return mpcClient, natsConn } diff --git a/examples/authorizers/generate/main.go b/examples/authorizers/generate/main.go index 21109066..8504731d 100644 --- a/examples/authorizers/generate/main.go +++ b/examples/authorizers/generate/main.go @@ -90,7 +90,8 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: localSigner, - }, client.WithClientID(*clientID)) + ClientID: *clientID, + }) var walletStartTimes sync.Map var walletIDs []string diff --git a/examples/authorizers/sign/main.go b/examples/authorizers/sign/main.go index a5e14eed..0c248b55 100644 --- a/examples/authorizers/sign/main.go +++ b/examples/authorizers/sign/main.go @@ -84,7 +84,8 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: localSigner, - }, client.WithClientID(*clientID)) + ClientID: *clientID, + }) // Create a signing request with authorizers txID := uuid.New().String() diff --git a/examples/ckd/main.go b/examples/ckd/main.go index 621f19b4..1aeef636 100644 --- a/examples/ckd/main.go +++ b/examples/ckd/main.go @@ -67,7 +67,8 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: localSigner, - }, client.WithClientID(*clientID)) + ClientID: *clientID, + }) // 2) Once wallet exists, immediately fire a SignTransaction txID := uuid.New().String() diff --git a/examples/generate/kms/main.go b/examples/generate/kms/main.go index 91cc4358..b9d08e5f 100644 --- a/examples/generate/kms/main.go +++ b/examples/generate/kms/main.go @@ -66,7 +66,8 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: kmsSigner, - }, client.WithClientID(*clientID)) + ClientID: *clientID, + }) var walletStartTimes sync.Map var walletIDs []string diff --git a/examples/generate/main.go b/examples/generate/main.go index 7209ebfc..24916356 100644 --- a/examples/generate/main.go +++ b/examples/generate/main.go @@ -73,7 +73,8 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: localSigner, - }, client.WithClientID(*clientID)) + ClientID: *clientID, + }) var walletStartTimes sync.Map var walletIDs []string diff --git a/examples/hdwallet/ecdsa/main.go b/examples/hdwallet/ecdsa/main.go index 2e3365f6..817abc67 100644 --- a/examples/hdwallet/ecdsa/main.go +++ b/examples/hdwallet/ecdsa/main.go @@ -97,7 +97,8 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: localSigner, - }, client.WithClientID(*clientID)) + ClientID: *clientID, + }) // Step 1: Generate ONE master wallet fmt.Println("Step 1: Generating master MPC wallet...") diff --git a/examples/hdwallet/eddsa/main.go b/examples/hdwallet/eddsa/main.go index 7950a297..6ab69b27 100644 --- a/examples/hdwallet/eddsa/main.go +++ b/examples/hdwallet/eddsa/main.go @@ -100,7 +100,8 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: localSigner, - }, client.WithClientID(*clientID)) + ClientID: *clientID, + }) // Step 1: Generate ONE master wallet fmt.Println("Step 1: Generating master MPC wallet...") diff --git a/examples/reshare/main.go b/examples/reshare/main.go index 9d3a34bb..3995b1b5 100644 --- a/examples/reshare/main.go +++ b/examples/reshare/main.go @@ -67,7 +67,8 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: localSigner, - }, client.WithClientID(*clientID)) + ClientID: *clientID, + }) // 3) Listen for signing results err = mpcClient.OnResharingResult(func(evt event.ResharingResultEvent) { diff --git a/examples/sign/main.go b/examples/sign/main.go index a8a39ced..9064e310 100644 --- a/examples/sign/main.go +++ b/examples/sign/main.go @@ -67,7 +67,8 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: localSigner, - }, client.WithClientID(*clientID)) + ClientID: *clientID, + }) // 2) Once wallet exists, immediately fire a SignTransaction txID := uuid.New().String() diff --git a/pkg/client/client.go b/pkg/client/client.go index 5b5c240f..1926e0ec 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -44,12 +44,9 @@ type Options struct { // Signer for signing messages Signer Signer -} - -type ClientOption func(*clientConfig) -type clientConfig struct { - clientID string + // ClientID optionally scopes result routing for this client instance. + ClientID string } type clientResultRouting struct { @@ -61,24 +58,13 @@ type clientResultRouting struct { reshareSubject string } -func WithClientID(id string) ClientOption { - return func(cfg *clientConfig) { - cfg.clientID = id - } -} - // NewMPCClient creates a new MPC client using the provided options. // The signer must be provided to handle message signing. -func NewMPCClient(opts Options, clientOptions ...ClientOption) MPCClient { +func NewMPCClient(opts Options) MPCClient { if opts.Signer == nil { logger.Fatal("Signer is required", nil) } - - cfg := clientConfig{} - for _, opt := range clientOptions { - opt(&cfg) - } - if err := validateClientID(cfg.clientID); err != nil { + if err := validateClientID(opts.ClientID); err != nil { logger.Fatal("Invalid client ID", err) } @@ -109,7 +95,7 @@ func NewMPCClient(opts Options, clientOptions ...ClientOption) MPCClient { pubsub := messaging.NewNATSPubSub(opts.NatsConn) manager := messaging.NewNATsMessageQueueManager("mpc", event.ResultStreamSubjects(), opts.NatsConn) - routing := buildClientResultRouting(cfg.clientID) + routing := buildClientResultRouting(opts.ClientID) genKeySuccessQueue := manager.NewMessageQueue(routing.keygenConsumerName, routing.keygenSubject) signResultQueue := manager.NewMessageQueue(routing.signingConsumerName, routing.signingSubject) @@ -123,7 +109,7 @@ func NewMPCClient(opts Options, clientOptions ...ClientOption) MPCClient { signResultQueue: signResultQueue, reshareSuccessQueue: reshareSuccessQueue, signer: opts.Signer, - clientID: cfg.clientID, + clientID: opts.ClientID, } } diff --git a/scripts/check-client-routing/main.go b/scripts/check-client-routing/main.go deleted file mode 100644 index eb691ae7..00000000 --- a/scripts/check-client-routing/main.go +++ /dev/null @@ -1,347 +0,0 @@ -package main - -import ( - "flag" - "fmt" - "os" - "slices" - "sync" - "time" - - "github.com/fystack/mpcium/pkg/client" - "github.com/fystack/mpcium/pkg/config" - "github.com/fystack/mpcium/pkg/event" - "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/types" - "github.com/google/uuid" - "github.com/nats-io/nats.go" - "github.com/spf13/viper" -) - -type clientStats struct { - name string - clientID string - requested map[string]struct{} - received map[string]event.KeygenResultEvent - misrouted map[string]event.KeygenResultEvent - untracked map[string]event.KeygenResultEvent -} - -type routingState struct { - mu sync.Mutex - clients map[string]*clientStats - totalWanted int - totalResults int - doneCh chan struct{} - doneOnce sync.Once -} - -func main() { - if err := run(); err != nil { - fmt.Fprintf(os.Stderr, "client routing check failed: %v\n", err) - os.Exit(1) - } -} - -func run() error { - clientAID := flag.String("client-a-id", "svc-a", "Client ID for client A") - clientBID := flag.String("client-b-id", "svc-b", "Client ID for client B") - keyPath := flag.String("key-path", "./event_initiator.key", "Path to the event initiator private key") - natsURLFlag := flag.String("nats-url", "", "NATS URL override (defaults to config nats.url)") - algorithmFlag := flag.String("algorithm", "", "Initiator signing algorithm override (ed25519 or p256)") - walletsPerClient := flag.Int("wallets-per-client", 3, "Number of wallet creation requests per client") - timeout := flag.Duration("timeout", 90*time.Second, "Max time to wait for all results") - listenerWarmup := flag.Duration("listener-warmup", 3*time.Second, "Delay after listener setup before sending requests") - legacyMode := flag.Bool("legacy", false, "Create both clients without client IDs to reproduce the old shared-queue behavior") - flag.Parse() - - if *walletsPerClient <= 0 { - return fmt.Errorf("wallets-per-client must be > 0") - } - - config.InitViperConfig("") - logger.Init("dev", true) - - algorithm := *algorithmFlag - if algorithm == "" { - algorithm = viper.GetString("event_initiator_algorithm") - } - if algorithm == "" { - algorithm = string(types.EventInitiatorKeyTypeEd25519) - } - if !slices.Contains( - []string{ - string(types.EventInitiatorKeyTypeEd25519), - string(types.EventInitiatorKeyTypeP256), - }, - algorithm, - ) { - return fmt.Errorf( - "invalid algorithm %q: must be %s or %s", - algorithm, - types.EventInitiatorKeyTypeEd25519, - types.EventInitiatorKeyTypeP256, - ) - } - - natsURL := *natsURLFlag - if natsURL == "" { - natsURL = viper.GetString("nats.url") - } - if natsURL == "" { - return fmt.Errorf("nats url is required") - } - - natsConn, err := nats.Connect(natsURL) - if err != nil { - return fmt.Errorf("connect nats: %w", err) - } - defer natsConn.Drain() - defer natsConn.Close() - - signer, err := client.NewLocalSigner(types.EventInitiatorKeyType(algorithm), client.LocalSignerOptions{ - KeyPath: *keyPath, - }) - if err != nil { - return fmt.Errorf("create local signer: %w", err) - } - - clientA := newMPCClient(natsConn, signer, *clientAID, *legacyMode) - clientB := newMPCClient(natsConn, signer, *clientBID, *legacyMode) - - effectiveClientAID := *clientAID - effectiveClientBID := *clientBID - if *legacyMode { - effectiveClientAID = "" - effectiveClientBID = "" - } - - state := &routingState{ - clients: map[string]*clientStats{ - "A": { - name: "A", - clientID: effectiveClientAID, - requested: make(map[string]struct{}), - received: make(map[string]event.KeygenResultEvent), - misrouted: make(map[string]event.KeygenResultEvent), - untracked: make(map[string]event.KeygenResultEvent), - }, - "B": { - name: "B", - clientID: effectiveClientBID, - requested: make(map[string]struct{}), - received: make(map[string]event.KeygenResultEvent), - misrouted: make(map[string]event.KeygenResultEvent), - untracked: make(map[string]event.KeygenResultEvent), - }, - }, - totalWanted: *walletsPerClient * 2, - doneCh: make(chan struct{}), - } - - if err := clientA.OnWalletCreationResult(func(result event.KeygenResultEvent) { - state.record("A", result) - }); err != nil { - return fmt.Errorf("subscribe client A: %w", err) - } - if err := clientB.OnWalletCreationResult(func(result event.KeygenResultEvent) { - state.record("B", result) - }); err != nil { - return fmt.Errorf("subscribe client B: %w", err) - } - - fmt.Printf("listeners ready, waiting %s before publishing requests\n", listenerWarmup.String()) - time.Sleep(*listenerWarmup) - - requestsA := make([]string, 0, *walletsPerClient) - requestsB := make([]string, 0, *walletsPerClient) - - for i := 0; i < *walletsPerClient; i++ { - walletID := "route-a-" + uuid.NewString() - state.clients["A"].requested[walletID] = struct{}{} - requestsA = append(requestsA, walletID) - } - for i := 0; i < *walletsPerClient; i++ { - walletID := "route-b-" + uuid.NewString() - state.clients["B"].requested[walletID] = struct{}{} - requestsB = append(requestsB, walletID) - } - - fmt.Printf("mode=%s clientA=%q clientB=%q wallets-per-client=%d\n", - modeName(*legacyMode), effectiveClientAID, effectiveClientBID, *walletsPerClient) - fmt.Printf("client A requested wallets: %v\n", requestsA) - fmt.Printf("client B requested wallets: %v\n", requestsB) - - var publishWG sync.WaitGroup - publishWG.Add(2) - go func() { - defer publishWG.Done() - for _, walletID := range requestsA { - if err := clientA.CreateWallet(walletID); err != nil { - logger.Error("Client A create wallet failed", err, "walletID", walletID) - } - } - }() - go func() { - defer publishWG.Done() - for _, walletID := range requestsB { - if err := clientB.CreateWallet(walletID); err != nil { - logger.Error("Client B create wallet failed", err, "walletID", walletID) - } - } - }() - publishWG.Wait() - - select { - case <-state.doneCh: - case <-time.After(*timeout): - fmt.Printf("timed out after %s waiting for results\n", timeout.String()) - } - - printSummary(state) - - if err := state.validate(); err != nil { - return err - } - - fmt.Println("routing check passed: no client received another client's result") - return nil -} - -func newMPCClient(natsConn *nats.Conn, signer client.Signer, clientID string, legacy bool) client.MPCClient { - opts := client.Options{ - NatsConn: natsConn, - Signer: signer, - } - if legacy { - return client.NewMPCClient(opts) - } - return client.NewMPCClient(opts, client.WithClientID(clientID)) -} - -func (s *routingState) record(clientName string, result event.KeygenResultEvent) { - s.mu.Lock() - defer s.mu.Unlock() - - stats := s.clients[clientName] - if _, exists := stats.received[result.WalletID]; exists { - return - } - - stats.received[result.WalletID] = result - s.totalResults++ - - if _, ok := stats.requested[result.WalletID]; ok { - if s.totalResults >= s.totalWanted { - s.doneOnce.Do(func() { - close(s.doneCh) - }) - } - return - } - - if otherName := otherClientName(clientName); otherName != "" { - if _, ok := s.clients[otherName].requested[result.WalletID]; ok { - stats.misrouted[result.WalletID] = result - } else { - stats.untracked[result.WalletID] = result - } - } - - if s.totalResults >= s.totalWanted { - s.doneOnce.Do(func() { - close(s.doneCh) - }) - } -} - -func (s *routingState) validate() error { - s.mu.Lock() - defer s.mu.Unlock() - - var reasons []string - for _, name := range []string{"A", "B"} { - stats := s.clients[name] - if len(stats.misrouted) > 0 { - reasons = append(reasons, fmt.Sprintf("client %s received %d misrouted result(s)", name, len(stats.misrouted))) - } - if len(stats.untracked) > 0 { - reasons = append(reasons, fmt.Sprintf("client %s received %d unexpected result(s)", name, len(stats.untracked))) - } - if missing := missingWallets(stats); len(missing) > 0 { - reasons = append(reasons, fmt.Sprintf("client %s is missing %d expected result(s): %v", name, len(missing), missing)) - } - } - - if len(reasons) == 0 { - return nil - } - return fmt.Errorf("%v", reasons) -} - -func printSummary(state *routingState) { - state.mu.Lock() - defer state.mu.Unlock() - - fmt.Println("---- routing summary ----") - for _, name := range []string{"A", "B"} { - stats := state.clients[name] - fmt.Printf("client %s (clientID=%q): requested=%d received=%d misrouted=%d unexpected=%d missing=%d\n", - stats.name, - stats.clientID, - len(stats.requested), - len(stats.received), - len(stats.misrouted), - len(stats.untracked), - len(missingWallets(stats)), - ) - if len(stats.misrouted) > 0 { - fmt.Printf(" misrouted wallets: %v\n", sortedEventKeys(stats.misrouted)) - } - if len(stats.untracked) > 0 { - fmt.Printf(" unexpected wallets: %v\n", sortedEventKeys(stats.untracked)) - } - if missing := missingWallets(stats); len(missing) > 0 { - fmt.Printf(" missing wallets: %v\n", missing) - } - } - fmt.Println("-------------------------") -} - -func missingWallets(stats *clientStats) []string { - missing := make([]string, 0) - for walletID := range stats.requested { - if _, ok := stats.received[walletID]; !ok { - missing = append(missing, walletID) - } - } - slices.Sort(missing) - return missing -} - -func sortedEventKeys(events map[string]event.KeygenResultEvent) []string { - keys := make([]string, 0, len(events)) - for walletID := range events { - keys = append(keys, walletID) - } - slices.Sort(keys) - return keys -} - -func otherClientName(name string) string { - switch name { - case "A": - return "B" - case "B": - return "A" - default: - return "" - } -} - -func modeName(legacy bool) string { - if legacy { - return "legacy" - } - return "scoped" -} From 03006a3708c78d73128f993794c87f826f1b5290 Mon Sep 17 00:00:00 2001 From: vietddude Date: Wed, 25 Mar 2026 09:27:08 +0700 Subject: [PATCH 7/9] fix(examples): update KeyType and WalletID in SignTxMessage for consistency --- examples/sign/main.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/sign/main.go b/examples/sign/main.go index 9064e310..cb198aea 100644 --- a/examples/sign/main.go +++ b/examples/sign/main.go @@ -75,8 +75,8 @@ func main() { dummyTx := []byte("deadbeef") // replace with real transaction bytes txMsg := &types.SignTxMessage{ - KeyType: types.KeyTypeSecp256k1, - WalletID: "b8a32a42-b5ea-4c80-a489-d2ec9e873cdf", // Use the generated wallet ID + KeyType: types.KeyTypeEd25519, + WalletID: "ad24f678-b04b-4149-bcf6-bf9c90df8e63", // Use the generated wallet ID NetworkInternalCode: "solana-devnet", TxID: txID, Tx: dummyTx, From f87b0d29c833c8968a9f9b4e2e3613705840ba62 Mon Sep 17 00:00:00 2001 From: anhthii Date: Thu, 26 Mar 2026 06:24:51 +0700 Subject: [PATCH 8/9] Add client id for benchmark cli --- cmd/mpcium-cli/benchmark.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cmd/mpcium-cli/benchmark.go b/cmd/mpcium-cli/benchmark.go index 9754bbd5..91679381 100644 --- a/cmd/mpcium-cli/benchmark.go +++ b/cmd/mpcium-cli/benchmark.go @@ -80,6 +80,11 @@ func benchmarkCommand() *cli.Command { Value: false, Category: "authentication", }, + &cli.StringFlag{ + Name: "client-id", + Usage: "Client ID for result routing (scopes results to this client instance)", + Category: "configuration", + }, &cli.BoolFlag{ Name: "debug", Usage: "Enable debug logging", @@ -244,6 +249,7 @@ func createMPCClient(cmd *cli.Command) (client.MPCClient, error) { opts := client.Options{ NatsConn: nc, Signer: signer, + ClientID: cmd.String("client-id"), } return client.NewMPCClient(opts), nil } From b408230c5a81e70cc4d9314d1f805a4b5d5f0dbe Mon Sep 17 00:00:00 2001 From: anhthii Date: Thu, 26 Mar 2026 06:28:17 +0700 Subject: [PATCH 9/9] Update readme example --- README.md | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 2255a1b1..fb51fa5c 100644 --- a/README.md +++ b/README.md @@ -166,6 +166,14 @@ $ mpcium start -n node2 Mpcium supports flexible client authentication through a signer interface, allowing you to use either local keys or AWS KMS for signing operations. +#### Client ID (Result Routing) + +When multiple client instances connect to the same MPC cluster, each client **must** set a unique `ClientID` to avoid result routing conflicts. Without distinct client IDs, two clients requesting operations concurrently may race for the same result message, causing one client to receive the other's response. + +- `ClientID` scopes the NATS consumer and result subject so each client only receives its own results. +- Allowed characters: alphanumeric, hyphens, and underscores (e.g. `"backend-svc-1"`, `"mobile_api"`). +- If you only run a single client instance, `ClientID` can be omitted (empty string). + #### Local Signer (Ed25519) ```go @@ -193,10 +201,11 @@ func main() { logger.Fatal("Failed to create local signer", err) } - // Create MPC client with signer + // Create MPC client with signer and a unique client ID mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: localSigner, + ClientID: "backend-svc-1", // unique per client instance }) // Handle wallet creation results @@ -253,6 +262,7 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, Signer: kmsSigner, + ClientID: "kms-client-1", // unique per client instance }) // ... rest of the client code }