diff --git a/api/auth_middleware_test.go b/api/auth_middleware_test.go index 98d8d651..d7edebe7 100644 --- a/api/auth_middleware_test.go +++ b/api/auth_middleware_test.go @@ -315,6 +315,50 @@ func TestGetApiSignerWithApiAccessKey(t *testing.T) { "body %s should contain address %s", string(body), parentApiKey) } +func TestRequireWriteScope(t *testing.T) { + // requireWriteScope only reads c.Locals("oauthScope"), so no DB is needed. + app := &ApiServer{} + + // Create a dummy write route that chains requireWriteScope after a scope-setting middleware + testApp := fiber.New() + testApp.Post("/write", func(c *fiber.Ctx) error { + // Simulate what authMiddleware does: set oauthScope if a PKCE token was used + scope := c.Get("X-Test-Oauth-Scope") + if scope != "" { + c.Locals("oauthScope", scope) + } + return c.Next() + }, app.requireWriteScope, func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + // PKCE token with scope=read should be rejected (403) + t.Run("read scope rejected", func(t *testing.T) { + req := httptest.NewRequest("POST", "/write", nil) + req.Header.Set("X-Test-Oauth-Scope", "read") + res, err := testApp.Test(req, -1) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusForbidden, res.StatusCode) + }) + + // PKCE token with scope=write should be allowed (200) + t.Run("write scope allowed", func(t *testing.T) { + req := httptest.NewRequest("POST", "/write", nil) + req.Header.Set("X-Test-Oauth-Scope", "write") + res, err := testApp.Test(req, -1) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusOK, res.StatusCode) + }) + + // Non-OAuth auth (no oauthScope set) should pass through (200) + t.Run("non-oauth auth passes through", func(t *testing.T) { + req := httptest.NewRequest("POST", "/write", nil) + res, err := testApp.Test(req, -1) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusOK, res.StatusCode) + }) +} + // ensureApiKeysTables creates api_keys and api_access_keys if they do not exist. func ensureApiKeysTables(t *testing.T, app *ApiServer, ctx context.Context) { t.Helper()