diff --git a/pkg/connector/api_token.go b/pkg/connector/api_token.go index 72be2f4f..6b3d0319 100644 --- a/pkg/connector/api_token.go +++ b/pkg/connector/api_token.go @@ -128,7 +128,7 @@ func (o *apiTokenResourceType) List( }, nil } -func apiTokenBuilder(client *github.Client, hasSAMLEnabled *bool, orgCache *orgNameCache) *apiTokenResourceType { +func apiTokenBuilder(client *github.Client, orgCache *orgNameCache) *apiTokenResourceType { return &apiTokenResourceType{ resourceType: resourceTypeApiToken, client: client, diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index d035cedb..18cc61fe 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -94,24 +94,32 @@ var ( ) type GitHub struct { - orgs []string - client *github.Client - appClient *github.Client - customClient *customclient.Client - instanceURL string - graphqlClient *githubv4.Client - hasSAMLEnabled *bool - orgCache *orgNameCache - syncSecrets bool - omitArchivedRepositories bool - enterprises []string + orgs []string + client *github.Client + appClient *github.Client + customClient *customclient.Client + instanceURL string + graphqlClient *githubv4.Client + orgCache *orgNameCache + syncSecrets bool + omitArchivedRepositories bool + enterprises []string + enterpriseLicensesAvailable bool // set during Validate; guards enterprise role sync and SAML enrichment } func (gh *GitHub) ResourceSyncers(ctx context.Context) []connectorbuilder.ResourceSyncerV2 { + // Only pass enterprises to builders that need them if the consumed-licenses + // API is confirmed accessible. Otherwise, skip enterprise SAML enrichment + // and enterprise role sync gracefully. + var activeEnterprises []string + if gh.enterpriseLicensesAvailable { + activeEnterprises = gh.enterprises + } + resourceSyncers := []connectorbuilder.ResourceSyncerV2{ orgBuilder(gh.client, gh.appClient, gh.orgCache, gh.orgs, gh.syncSecrets), teamBuilder(gh.client, gh.orgCache), - userBuilder(gh.client, gh.hasSAMLEnabled, gh.graphqlClient, gh.orgCache, gh.orgs, gh.customClient, gh.enterprises), + userBuilder(gh.client, gh.graphqlClient, gh.orgCache, gh.orgs, gh.customClient, activeEnterprises), repositoryBuilder(gh.client, gh.orgCache, gh.omitArchivedRepositories), orgRoleBuilder(gh.client, gh.orgCache), invitationBuilder(invitationBuilderParams{ @@ -122,11 +130,11 @@ func (gh *GitHub) ResourceSyncers(ctx context.Context) []connectorbuilder.Resour } if gh.syncSecrets { - resourceSyncers = append(resourceSyncers, apiTokenBuilder(gh.client, gh.hasSAMLEnabled, gh.orgCache)) + resourceSyncers = append(resourceSyncers, apiTokenBuilder(gh.client, gh.orgCache)) } - if len(gh.enterprises) > 0 { - resourceSyncers = append(resourceSyncers, enterpriseRoleBuilder(gh.client, gh.customClient, gh.enterprises)) + if len(activeEnterprises) > 0 { + resourceSyncers = append(resourceSyncers, enterpriseRoleBuilder(gh.client, gh.customClient, activeEnterprises)) } return resourceSyncers } @@ -210,9 +218,13 @@ func (gh *GitHub) Validate(ctx context.Context) (annotations.Annotations, error) } if len(gh.enterprises) > 0 { + l := ctxzap.Extract(ctx) _, _, err := gh.customClient.ListEnterpriseConsumedLicenses(ctx, gh.enterprises[0], 1) if err != nil { - return nil, uhttp.WrapErrors(codes.PermissionDenied, "github-connector: failed to access enterprise licenses", err) + l.Warn("failed to access enterprise consumed licenses — enterprise SAML email enrichment and enterprise role sync will be skipped", + zap.Error(err)) + } else { + gh.enterpriseLicensesAvailable = true } } return nil, nil @@ -228,6 +240,19 @@ func (gh *GitHub) validateAppCredentials(ctx context.Context) (annotations.Annot if err != nil { return nil, err } + + if len(gh.enterprises) > 0 { + l := ctxzap.Extract(ctx) + _, _, err := gh.customClient.ListEnterpriseConsumedLicenses(ctx, gh.enterprises[0], 1) + if err != nil { + l.Warn("failed to access enterprise consumed licenses — enterprise SAML email enrichment and enterprise role sync will be skipped"+ + " (GitHub App installations cannot access this endpoint — use a PAT with enterprise admin scope)", + zap.Error(err)) + } else { + gh.enterpriseLicensesAvailable = true + } + } + return nil, nil } diff --git a/pkg/connector/enterprise_role.go b/pkg/connector/enterprise_role.go index 172af02d..f2c2e64c 100644 --- a/pkg/connector/enterprise_role.go +++ b/pkg/connector/enterprise_role.go @@ -53,7 +53,9 @@ func (o *enterpriseRoleResourceType) getRoleUsersCache(ctx context.Context) (map func (o *enterpriseRoleResourceType) fillCache(ctx context.Context) error { for _, enterprise := range o.enterprises { - page := 0 + // GitHub's consumed-licenses API is 1-indexed; page 0 is undocumented + // and may return the same results as page 1, causing duplicates. + page := 1 continuePagination := true for continuePagination { consumedLicenses, _, err := o.customClient.ListEnterpriseConsumedLicenses(ctx, enterprise, page) diff --git a/pkg/connector/helpers.go b/pkg/connector/helpers.go index 7225cf86..2659668b 100644 --- a/pkg/connector/helpers.go +++ b/pkg/connector/helpers.go @@ -258,7 +258,8 @@ type listUsersQuery struct { type hasSAMLQuery struct { Organization struct { SamlIdentityProvider struct { - Id string + Id string + SsoUrl githubv4.String } } `graphql:"organization(login: $orgLoginName)"` } diff --git a/pkg/connector/user.go b/pkg/connector/user.go index 83464133..71f55070 100644 --- a/pkg/connector/user.go +++ b/pkg/connector/user.go @@ -87,31 +87,43 @@ func userResource(ctx context.Context, user *github.User, userEmail string, extr return ret, nil } -// enterpriseEmailInfo holds SAML identity data from the enterprise consumed licenses API. -// Fields are exported for JSON serialization via session.Store. -type enterpriseEmailInfo struct { - SAMLNameID string `json:"saml_name_id"` -} +type samlState int + +const ( + samlStateUnknown samlState = iota // not yet checked + samlStateOrgEnabled // org-level SAML, use GraphQL + samlStateEnterprise // enterprise SAML, use consumed licenses API + samlStateDisabled // no SAML +) -const enterpriseEmailPrefix = "enterprise-email:" -const enterpriseEmailCacheLoadedKey = "enterprise-email-cache-loaded" +const ( + // enterpriseSAMLKeyPrefix is prepended to each GitHub login to form + // individual session keys, e.g. "enterprise_saml:octocat". + enterpriseSAMLKeyPrefix = "enterprise_saml:" + + // enterpriseSAMLKeysIndex is the session key that stores the list of all + // enterprise_saml:* keys. This allows bulk-reading SAML mappings with + // GetManyJSON without scanning the entire session store. + enterpriseSAMLKeysIndex = "enterprise_saml_keys" +) type userResourceType struct { - resourceType *v2.ResourceType - client *github.Client - graphqlClient *githubv4.Client - hasSAMLEnabled *bool - orgCache *orgNameCache - orgs []string - customClient *customclient.Client - enterprises []string + resourceType *v2.ResourceType + client *github.Client + graphqlClient *githubv4.Client + samlStates map[string]samlState // per-org SAML state, keyed by org name + enterpriseSAMLFetched bool // true after first page fetches and stores SAML data + orgCache *orgNameCache + orgs []string + customClient *customclient.Client + enterprises []string } -func (o *userResourceType) ResourceType(_ context.Context) *v2.ResourceType { - return o.resourceType +func (u *userResourceType) ResourceType(_ context.Context) *v2.ResourceType { + return u.resourceType } -func (o *userResourceType) List(ctx context.Context, parentID *v2.ResourceId, opts resource.SyncOpAttrs) ([]*v2.Resource, *resource.SyncOpResults, error) { +func (u *userResourceType) List(ctx context.Context, parentID *v2.ResourceId, opts resource.SyncOpAttrs) ([]*v2.Resource, *resource.SyncOpResults, error) { l := ctxzap.Extract(ctx) var annotations annotations.Annotations if parentID == nil { @@ -123,15 +135,40 @@ func (o *userResourceType) List(ctx context.Context, parentID *v2.ResourceId, op return nil, nil, err } - orgName, err := o.orgCache.GetOrgName(ctx, opts.Session, parentID) + orgName, err := u.orgCache.GetOrgName(ctx, opts.Session, parentID) if err != nil { return nil, nil, err } - hasSamlBool, err := o.hasSAML(ctx, orgName, opts.Session) + currentSAMLState, err := u.checkOrgSAML(ctx, orgName) if err != nil { return nil, nil, err } + + // For enterprise SAML: on the first page, fetch from the API and store in + // session. On every page, bulk-read the mappings into a local map so the + // user loop can do plain map lookups with no session calls. + var enterpriseSAMLEmails map[string]string + if currentSAMLState == samlStateEnterprise { + if !u.enterpriseSAMLFetched { + if err := u.fetchAndStoreEnterpriseSAML(ctx, opts.Session); err != nil { + l.Warn("failed to fetch enterprise SAML emails, falling back to REST API emails", + zap.Error(err)) + u.enterpriseSAMLFetched = true + u.samlStates[orgName] = samlStateDisabled + currentSAMLState = samlStateDisabled + } else { + u.enterpriseSAMLFetched = true + } + } + if currentSAMLState == samlStateEnterprise { + enterpriseSAMLEmails, err = loadEnterpriseSAMLEmails(ctx, opts.Session) + if err != nil { + return nil, nil, err + } + } + } + var restApiRateLimit *v2.RateLimitDescription listOpts := github.ListMembersOptions{ @@ -141,7 +178,7 @@ func (o *userResourceType) List(ctx context.Context, parentID *v2.ResourceId, op }, } - users, resp, err := o.client.Organizations.ListMembers(ctx, orgName, &listOpts) + users, resp, err := u.client.Organizations.ListMembers(ctx, orgName, &listOpts) if err != nil { return nil, nil, wrapGitHubError(err, resp, "github-connector: failed to list organization members") } @@ -161,47 +198,41 @@ func (o *userResourceType) List(ctx context.Context, parentID *v2.ResourceId, op return nil, nil, err } - q := listUsersQuery{} + var lastGraphQLRateLimit *struct { + Limit int + Remaining int + ResetAt githubv4.DateTime + } rv := make([]*v2.Resource, 0, len(users)) for _, user := range users { - u, res, err := o.client.Users.GetByID(ctx, user.GetID()) + ghUser, res, err := u.client.Users.GetByID(ctx, user.GetID()) if err != nil { // This undocumented API can return 404 for some users. If this fails it means we won't get some of their details like email if isNotFoundError(res) { l.Warn("error fetching user by id", zap.Error(err), zap.Int64("user_id", user.GetID())) - u = user + ghUser = user } else { return nil, nil, wrapGitHubError(err, res, "github-connector: failed to get user by id") } } - userEmail := u.GetEmail() + userEmail := ghUser.GetEmail() var extraEmails []string - if hasSamlBool { + + switch currentSAMLState { + case samlStateUnknown: + return nil, nil, fmt.Errorf("baton-github: unexpected unknown SAML state for org %s", orgName) + case samlStateOrgEnabled: + q := listUsersQuery{} variables := map[string]interface{}{ "orgLoginName": githubv4.String(orgName), - "userName": githubv4.String(u.GetLogin()), + "userName": githubv4.String(ghUser.GetLogin()), } - err = o.graphqlClient.Query(ctx, &q, variables) + err = u.graphqlClient.Query(ctx, &q, variables) + if err != nil { - // When SAML is configured at the Enterprise level (not org level), - // GitHub returns this error. Fall back to using the regular user email - // and disable further SAML queries for this connector instance. - if strings.Contains(err.Error(), "SAML identity provider is disabled when an Enterprise SAML identity provider is available") { - l.Info("org SAML disabled in favor of Enterprise SAML, falling back to enterprise consumed licenses API for email enrichment", - zap.String("org", orgName), - zap.String("user", u.GetLogin())) - samlDisabled := false - o.hasSAMLEnabled = &samlDisabled - hasSamlBool = false - // Load enterprise email data so we can enrich users - if loadErr := o.loadEnterpriseEmailCache(ctx, opts.Session); loadErr != nil { - l.Warn("failed to load enterprise email cache", zap.Error(loadErr)) - } - } else { - return nil, nil, fmt.Errorf("baton-github: GraphQL SAML identity query failed for user %s in org %s: %w", u.GetLogin(), orgName, err) - } + return nil, nil, err } - if err == nil && len(q.Organization.SamlIdentityProvider.ExternalIdentities.Edges) == 1 { + if len(q.Organization.SamlIdentityProvider.ExternalIdentities.Edges) == 1 { samlIdent := q.Organization.SamlIdentityProvider.ExternalIdentities.Edges[0].Node.SamlIdentity userEmail = samlIdent.NameId setUserEmail := false @@ -223,20 +254,27 @@ func (o *userResourceType) List(ctx context.Context, parentID *v2.ResourceId, op } } } - } else if len(o.enterprises) > 0 { - // Org-level SAML is not available but enterprise is configured. - // Defer to the enterprise SAML cache as the source of truth — if - // the user has a SAML identity, use it. If not, leave email blank - // rather than using the REST API public email, which is not a - // corporate identity. - userEmail = o.getEnterpriseSAMLEmail(ctx, opts.Session, u.GetLogin()) - if userEmail != "" { - l.Debug("enriched user email from enterprise consumed licenses", - zap.String("user", u.GetLogin()), - zap.String("email", userEmail)) + lastGraphQLRateLimit = &struct { + Limit int + Remaining int + ResetAt githubv4.DateTime + }{ + Limit: q.RateLimit.Limit, + Remaining: q.RateLimit.Remaining, + ResetAt: q.RateLimit.ResetAt, } + + case samlStateEnterprise: + key := enterpriseSAMLKeyPrefix + strings.ToLower(ghUser.GetLogin()) + if samlEmail, ok := enterpriseSAMLEmails[key]; ok && isEmail(samlEmail) { + userEmail = samlEmail + } + + case samlStateDisabled: + // no SAML enrichment } - ur, err := userResource(ctx, u, userEmail, extraEmails) + + ur, err := userResource(ctx, ghUser, userEmail, extraEmails) if err != nil { return nil, nil, err } @@ -244,11 +282,11 @@ func (o *userResourceType) List(ctx context.Context, parentID *v2.ResourceId, op rv = append(rv, ur) } annotations.WithRateLimiting(restApiRateLimit) - if *o.hasSAMLEnabled && int64(q.RateLimit.Remaining) < restApiRateLimit.Remaining { + if lastGraphQLRateLimit != nil && int64(lastGraphQLRateLimit.Remaining) < restApiRateLimit.Remaining { graphqlRateLimit := &v2.RateLimitDescription{ - Limit: int64(q.RateLimit.Limit), - Remaining: int64(q.RateLimit.Remaining), - ResetAt: timestamppb.New(q.RateLimit.ResetAt.Time), + Limit: int64(lastGraphQLRateLimit.Limit), + Remaining: int64(lastGraphQLRateLimit.Remaining), + ResetAt: timestamppb.New(lastGraphQLRateLimit.ResetAt.Time), } annotations.WithRateLimiting(graphqlRateLimit) } @@ -264,20 +302,20 @@ func isEmail(email string) bool { return err == nil } -func (o *userResourceType) Entitlements(_ context.Context, _ *v2.Resource, _ resource.SyncOpAttrs) ([]*v2.Entitlement, *resource.SyncOpResults, error) { +func (u *userResourceType) Entitlements(_ context.Context, _ *v2.Resource, _ resource.SyncOpAttrs) ([]*v2.Entitlement, *resource.SyncOpResults, error) { return nil, &resource.SyncOpResults{}, nil } -func (o *userResourceType) Grants(_ context.Context, _ *v2.Resource, _ resource.SyncOpAttrs) ([]*v2.Grant, *resource.SyncOpResults, error) { +func (u *userResourceType) Grants(_ context.Context, _ *v2.Resource, _ resource.SyncOpAttrs) ([]*v2.Grant, *resource.SyncOpResults, error) { return nil, &resource.SyncOpResults{}, nil } -func (o *userResourceType) Delete(ctx context.Context, resourceId *v2.ResourceId) (annotations.Annotations, error) { +func (u *userResourceType) Delete(ctx context.Context, resourceId *v2.ResourceId) (annotations.Annotations, error) { if resourceId.ResourceType != resourceTypeUser.Id { return nil, fmt.Errorf("baton-github: non-user resource passed to user delete") } - orgs, err := getOrgs(ctx, o.client, o.orgs) + orgs, err := getOrgs(ctx, u.client, u.orgs) if err != nil { return nil, err } @@ -287,7 +325,7 @@ func (o *userResourceType) Delete(ctx context.Context, resourceId *v2.ResourceId return nil, fmt.Errorf("baton-github: invalid invitation id") } - user, resp, err := o.client.Users.GetByID(ctx, userID) + user, resp, err := u.client.Users.GetByID(ctx, userID) if err != nil { return nil, wrapGitHubError(err, resp, "baton-github: invalid userID") } @@ -296,7 +334,7 @@ func (o *userResourceType) Delete(ctx context.Context, resourceId *v2.ResourceId isRemoved = false ) for _, org := range orgs { - resp, err = o.client.Organizations.RemoveOrgMembership(ctx, user.GetLogin(), org) + resp, err = u.client.Organizations.RemoveOrgMembership(ctx, user.GetLogin(), org) if err == nil { isRemoved = true } @@ -316,156 +354,153 @@ func (o *userResourceType) Delete(ctx context.Context, resourceId *v2.ResourceId return annotations, nil } -func userBuilder( - client *github.Client, - hasSAMLEnabled *bool, - graphqlClient *githubv4.Client, - orgCache *orgNameCache, - orgs []string, - customClient *customclient.Client, - enterprises []string, -) *userResourceType { +func userBuilder(client *github.Client, graphqlClient *githubv4.Client, orgCache *orgNameCache, orgs []string, customClient *customclient.Client, enterprises []string) *userResourceType { return &userResourceType{ - resourceType: resourceTypeUser, - client: client, - graphqlClient: graphqlClient, - hasSAMLEnabled: hasSAMLEnabled, - orgCache: orgCache, - orgs: orgs, - customClient: customClient, - enterprises: enterprises, + resourceType: resourceTypeUser, + client: client, + graphqlClient: graphqlClient, + samlStates: make(map[string]samlState), + orgCache: orgCache, + orgs: orgs, + customClient: customClient, + enterprises: enterprises, } } -// loadEnterpriseEmailCache fetches enterprise consumed licenses and stores -// SAML NameID data in the session store, keyed by lowercase GitHub login. -func (o *userResourceType) loadEnterpriseEmailCache(ctx context.Context, ss sessions.SessionStore) error { - l := ctxzap.Extract(ctx) +// checkOrgSAML queries GitHub to determine the SAML configuration. +// Returns one of: samlStateOrgEnabled, samlStateEnterprise, samlStateDisabled. +// The result is cached on the struct so subsequent calls skip the GraphQL query. +func (u *userResourceType) checkOrgSAML(ctx context.Context, orgName string) (samlState, error) { + if state, ok := u.samlStates[orgName]; ok { + return state, nil + } - // Check if cache has already been loaded this sync. - _, found, err := session.GetJSON[bool](ctx, ss, enterpriseEmailCacheLoadedKey) + l := ctxzap.Extract(ctx) + q := hasSAMLQuery{} + variables := map[string]interface{}{ + "orgLoginName": githubv4.String(orgName), + } + err := u.graphqlClient.Query(ctx, &q, variables) if err != nil { - return err + // When SAML is configured at the Enterprise level (not org level), + // GitHub returns this error. + if strings.Contains(err.Error(), "SAML identity provider is disabled when an Enterprise SAML identity provider is available") { + if len(u.enterprises) == 0 { + l.Debug("enterprise SAML detected but no enterprises configured, skipping SAML enrichment", + zap.String("org", orgName)) + u.samlStates[orgName] = samlStateDisabled + return u.samlStates[orgName], nil + } + l.Debug("org SAML disabled in favor of Enterprise SAML, will use consumed licenses API", + zap.String("org", orgName)) + u.samlStates[orgName] = samlStateEnterprise + return u.samlStates[orgName], nil + } + return samlStateUnknown, err + } + if q.Organization.SamlIdentityProvider.Id == "" { + if len(u.enterprises) > 0 { + l.Debug("no org-level SAML provider found but enterprises configured, will try consumed licenses API", + zap.String("org", orgName)) + u.samlStates[orgName] = samlStateEnterprise + return u.samlStates[orgName], nil + } + l.Debug("no SAML identity provider found for org, disabling SAML enrichment", + zap.String("org", orgName)) + u.samlStates[orgName] = samlStateDisabled + return u.samlStates[orgName], nil } - if found { - return nil + + ssoUrl := string(q.Organization.SamlIdentityProvider.SsoUrl) + if strings.Contains(ssoUrl, "/enterprises/") && len(u.enterprises) > 0 { + l.Debug("SAML provider SSO URL points to enterprise, will use consumed licenses API", + zap.String("org", orgName), + zap.String("sso_url", ssoUrl)) + u.samlStates[orgName] = samlStateEnterprise + return u.samlStates[orgName], nil } - if o.customClient == nil || len(o.enterprises) == 0 { - _ = session.SetJSON(ctx, ss, enterpriseEmailCacheLoadedKey, true) - return nil + if strings.Contains(ssoUrl, "/enterprises/") && len(u.enterprises) == 0 { + l.Debug("SAML provider SSO URL points to enterprise but no enterprises configured, skipping SAML enrichment", + zap.String("org", orgName), + zap.String("sso_url", ssoUrl)) + u.samlStates[orgName] = samlStateDisabled + return u.samlStates[orgName], nil } - userCount := 0 - for _, enterprise := range o.enterprises { + l.Debug("org-level SAML provider found, will use GraphQL for SAML identity lookups", + zap.String("org", orgName), + zap.String("sso_url", ssoUrl)) + u.samlStates[orgName] = samlStateOrgEnabled + return u.samlStates[orgName], nil +} + +// fetchAndStoreEnterpriseSAML pages through the consumed licenses API for all +// configured enterprises, aggregates the login-to-SAML-email mappings, and +// writes them to the session store in a single batch. It also stores the list +// of keys under enterpriseSAMLKeysIndex so that loadEnterpriseSAMLEmails can +// bulk-read them back on subsequent List pages. +func (u *userResourceType) fetchAndStoreEnterpriseSAML(ctx context.Context, ss sessions.SessionStore) error { + l := ctxzap.Extract(ctx) + samlByLogin := make(map[string]string) + + for _, enterprise := range u.enterprises { + // GitHub's consumed-licenses API is 1-indexed; page 0 is undocumented + // and may return the same results as page 1, causing duplicates. page := 1 for { - consumedLicenses, _, err := o.customClient.ListEnterpriseConsumedLicenses(ctx, enterprise, page) + consumedLicenses, _, err := u.customClient.ListEnterpriseConsumedLicenses(ctx, enterprise, page) if err != nil { - l.Warn("failed to fetch enterprise consumed licenses", - zap.String("enterprise", enterprise), - zap.Int("page", page), - zap.String("endpoint", fmt.Sprintf("GET /enterprises/%s/consumed-licenses", enterprise)), - zap.Error(err), - ) - // Mark as loaded so we don't retry; partial data is still available. - _ = session.SetJSON(ctx, ss, enterpriseEmailCacheLoadedKey, true) - return fmt.Errorf("baton-github: failed to fetch enterprise consumed licenses for %s (page %d): %w", enterprise, page, err) + return fmt.Errorf("baton-github: error fetching enterprise consumed licenses for %s: %w", enterprise, err) } - if len(consumedLicenses.Users) == 0 { break } - batch := make(map[string]*enterpriseEmailInfo, len(consumedLicenses.Users)) for _, user := range consumedLicenses.Users { - if user.GitHubComLogin == "" { - continue + if user.GitHubComSAMLNameID != nil && *user.GitHubComSAMLNameID != "" && user.GitHubComLogin != "" { + key := enterpriseSAMLKeyPrefix + strings.ToLower(user.GitHubComLogin) + samlByLogin[key] = *user.GitHubComSAMLNameID } - info := &enterpriseEmailInfo{} - if user.GitHubComSAMLNameID != nil { - info.SAMLNameID = *user.GitHubComSAMLNameID - } - key := enterpriseEmailPrefix + strings.ToLower(user.GitHubComLogin) - batch[key] = info - } - if err := session.SetManyJSON(ctx, ss, batch); err != nil { - _ = session.SetJSON(ctx, ss, enterpriseEmailCacheLoadedKey, true) - return fmt.Errorf("baton-github: failed to store enterprise email batch for %s (page %d): %w", enterprise, page, err) } - userCount += len(batch) page++ } } - l.Info("loaded enterprise email cache", - zap.Int("user_count", userCount)) - _ = session.SetJSON(ctx, ss, enterpriseEmailCacheLoadedKey, true) - return nil -} - -// getEnterpriseSAMLEmail looks up the SAML NameID for a user from the -// enterprise consumed-licenses data stored in the session store. -// Returns empty string if no SAML email is available. -func (o *userResourceType) getEnterpriseSAMLEmail(ctx context.Context, ss sessions.SessionStore, login string) string { - key := enterpriseEmailPrefix + strings.ToLower(login) - info, found, err := session.GetJSON[enterpriseEmailInfo](ctx, ss, key) - if err != nil || !found { - return "" - } + if len(samlByLogin) > 0 { + if err := session.SetManyJSON(ctx, ss, samlByLogin); err != nil { + return fmt.Errorf("baton-github: error storing enterprise SAML mappings: %w", err) + } - if info.SAMLNameID != "" && isEmail(info.SAMLNameID) { - return info.SAMLNameID + keys := make([]string, 0, len(samlByLogin)) + for k := range samlByLogin { + keys = append(keys, k) + } + if err := session.SetJSON(ctx, ss, enterpriseSAMLKeysIndex, keys); err != nil { + return fmt.Errorf("baton-github: error storing enterprise SAML key index: %w", err) + } } - return "" + l.Debug("stored enterprise SAML mappings in session", zap.Int("count", len(samlByLogin))) + return nil } -func (o *userResourceType) hasSAML(ctx context.Context, orgName string, ss sessions.SessionStore) (bool, error) { - if o.hasSAMLEnabled != nil { - return *o.hasSAMLEnabled, nil - } - - l := ctxzap.Extract(ctx) - samlBool := false - q := hasSAMLQuery{} - variables := map[string]interface{}{ - "orgLoginName": githubv4.String(orgName), - } - err := o.graphqlClient.Query(ctx, &q, variables) +// loadEnterpriseSAMLEmails bulk-reads all enterprise SAML mappings from the +// session store in two calls: one to get the key index, one to get the values. +// Returns a map of "enterprise_saml:" -> SAML email for use as a local +// lookup table in the List loop (no session calls needed per user). +func loadEnterpriseSAMLEmails(ctx context.Context, ss sessions.SessionStore) (map[string]string, error) { + keys, found, err := session.GetJSON[[]string](ctx, ss, enterpriseSAMLKeysIndex) if err != nil { - // When SAML is configured at the Enterprise level (not org level), - // GitHub returns this error. Fall back to treating SAML as disabled. - if strings.Contains(err.Error(), "SAML identity provider is disabled when an Enterprise SAML identity provider is available") { - l.Info("org SAML disabled in favor of Enterprise SAML, will use enterprise consumed licenses API for email enrichment", - zap.String("org", orgName)) - o.hasSAMLEnabled = &samlBool - // Proactively load enterprise email data - if loadErr := o.loadEnterpriseEmailCache(ctx, ss); loadErr != nil { - l.Warn("failed to load enterprise email cache", zap.Error(loadErr)) - } - return false, nil - } - l.Warn("GraphQL SAML provider query failed", - zap.String("org", orgName), - zap.Error(err), - ) - return false, fmt.Errorf("baton-github: GraphQL SAML provider query failed for org %s: %w", orgName, err) + return nil, fmt.Errorf("baton-github: error reading enterprise SAML key index: %w", err) } - if q.Organization.SamlIdentityProvider.Id != "" { - samlBool = true + if !found || len(keys) == 0 { + return nil, nil } - o.hasSAMLEnabled = &samlBool - // If org has no SAML but we have enterprises configured, proactively - // load the enterprise email cache for email enrichment. - if !samlBool && len(o.enterprises) > 0 { - l.Info("org has no SAML provider, will use enterprise consumed licenses API for email enrichment", - zap.String("org", orgName)) - if loadErr := o.loadEnterpriseEmailCache(ctx, ss); loadErr != nil { - l.Warn("failed to load enterprise email cache", zap.Error(loadErr)) - } + samlByLogin, err := session.GetManyJSON[string](ctx, ss, keys) + if err != nil { + return nil, fmt.Errorf("baton-github: error reading enterprise SAML mappings: %w", err) } - - return *o.hasSAMLEnabled, nil + return samlByLogin, nil } diff --git a/pkg/connector/user_test.go b/pkg/connector/user_test.go index 16cb5447..6b21d2f3 100644 --- a/pkg/connector/user_test.go +++ b/pkg/connector/user_test.go @@ -2,7 +2,6 @@ package connector import ( "context" - "fmt" "testing" "github.com/conductorone/baton-github/test" @@ -16,58 +15,45 @@ import ( func TestUsersList(t *testing.T) { ctx := context.Background() - trueBool, falseBool := true, false - - testCases := []struct { - hasSamlEnabled *bool - message string - }{ - {&trueBool, "true"}, - {&falseBool, "false"}, - {nil, "nil"}, - } - for _, testCase := range testCases { - t.Run(fmt.Sprintf("should get a list of users (SAML:%s)", testCase.message), func(t *testing.T) { - mgh := mocks.NewMockGitHub() - - githubOrganization, _, _, githubUser, _, _ := mgh.Seed() - - organization, err := organizationResource( - ctx, - githubOrganization, - nil, - false, - ) - if err != nil { - t.Error(err) - } - - githubClient := github.NewClient(mgh.Server()) - graphQLClient := mocks.MockGraphQL() - cache := newOrgNameCache(githubClient) - client := userBuilder( - githubClient, - testCase.hasSamlEnabled, - graphQLClient, - cache, - []string{organization.DisplayName}, - nil, - nil, - ) - - users, results, err := client.List( - ctx, - organization.Id, - resourceSdk.SyncOpAttrs{ - PageToken: pagination.Token{}, - Session: &noOpSessionStore{}, - }, - ) - require.Nil(t, err) - test.AssertHasRatelimitAnnotations(t, results.Annotations) - require.Equal(t, "", results.NextPageToken) - require.Len(t, users, 1) - require.Equal(t, *githubUser.Login, users[0].Id.Resource) - }) - } + t.Run("should get a list of users", func(t *testing.T) { + mgh := mocks.NewMockGitHub() + + githubOrganization, _, _, githubUser, _, _ := mgh.Seed() + + organization, err := organizationResource( + ctx, + githubOrganization, + nil, + false, + ) + if err != nil { + t.Error(err) + } + + githubClient := github.NewClient(mgh.Server()) + graphQLClient := mocks.MockGraphQL() + cache := newOrgNameCache(githubClient) + client := userBuilder( + githubClient, + graphQLClient, + cache, + []string{organization.DisplayName}, + nil, + nil, + ) + + users, results, err := client.List( + ctx, + organization.Id, + resourceSdk.SyncOpAttrs{ + PageToken: pagination.Token{}, + Session: &noOpSessionStore{}, + }, + ) + require.Nil(t, err) + test.AssertHasRatelimitAnnotations(t, results.Annotations) + require.Equal(t, "", results.NextPageToken) + require.Len(t, users, 1) + require.Equal(t, *githubUser.Login, users[0].Id.Resource) + }) } diff --git a/pkg/customclient/client.go b/pkg/customclient/client.go index 3aa40e02..85485447 100644 --- a/pkg/customclient/client.go +++ b/pkg/customclient/client.go @@ -32,7 +32,6 @@ func (c *Client) ListEnterpriseConsumedLicenses(ctx context.Context, enterprise q := req.URL.Query() q.Add("page", fmt.Sprintf("%d", page)) // GitHub REST API max per_page is 100, default is 30. - // https://docs.github.com/en/enterprise-cloud@latest/rest/enterprise-admin/license#list-enterprise-consumed-licenses q.Add("per_page", "100") req.URL.RawQuery = q.Encode()