diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index f881abc99..bbd4d3abc 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -11,5 +11,6 @@ ### Documentation ### Internal Changes +* Move cloud-based credential filtering from individual credential providers into `DefaultCredentialsProvider`. Azure strategies are skipped on GCP/AWS hosts in auto-detect mode; GCP strategies are skipped on Azure/AWS hosts. When `authType` is explicitly set, cloud filtering is bypassed so the named strategy is always attempted regardless of host cloud. ### API Changes diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java index 7e67ca72e..645b25211 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java @@ -77,10 +77,6 @@ private Optional getSubscription(DatabricksConfig config) { @Override public OAuthHeaderFactory configure(DatabricksConfig config) { - if (!config.isAzure()) { - return null; - } - try { AzureUtils.ensureHostPresent(config, mapper, this::tokenSourceFor); String resource = config.getEffectiveAzureLoginAppId(); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java index 99716890f..ff452156c 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java @@ -1,10 +1,14 @@ package com.databricks.sdk.core; import com.databricks.sdk.core.oauth.*; +import com.databricks.sdk.core.utils.Cloud; import com.databricks.sdk.support.InternalApi; import com.google.common.base.Strings; import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -18,6 +22,21 @@ public class DefaultCredentialsProvider implements CredentialsProvider { private static final Logger LOG = LoggerFactory.getLogger(DefaultCredentialsProvider.class); + // cloudRequirements declares the cloud each strategy requires. DefaultCredentialsProvider uses + // this to skip cloud-specific strategies in auto-detect mode when the host cloud does not match. + // Cloud filtering is bypassed when authType is explicitly set. + private static final Map CLOUD_REQUIREMENTS; + + static { + Map m = new HashMap<>(); + m.put("github-oidc-azure", Cloud.AZURE); + m.put("azure-client-secret", Cloud.AZURE); + m.put("azure-cli", Cloud.AZURE); + m.put("google-credentials", Cloud.GCP); + m.put("google-id", Cloud.GCP); + CLOUD_REQUIREMENTS = Collections.unmodifiableMap(m); + } + /* List of credential providers that will be tried in sequence */ private List providers = new ArrayList<>(); @@ -40,6 +59,11 @@ public NamedIDTokenSource(String name, IDTokenSource idTokenSource) { public DefaultCredentialsProvider() {} + /** For testing: creates a provider with a fixed list of credential providers. */ + DefaultCredentialsProvider(List providers) { + this.providers = new ArrayList<>(providers); + } + /** * Returns the current authentication type being used * @@ -60,14 +84,25 @@ public String authType() { @Override public synchronized HeaderFactory configure(DatabricksConfig config) { addDefaultCredentialsProviders(config); + boolean explicitAuthType = config.getAuthType() != null && !config.getAuthType().isEmpty(); for (CredentialsProvider provider : providers) { - if (config.getAuthType() != null - && !config.getAuthType().isEmpty() - && !provider.authType().equals(config.getAuthType())) { + if (explicitAuthType && !provider.authType().equals(config.getAuthType())) { LOG.info( "Ignoring {} auth, because {} is preferred", provider.authType(), config.getAuthType()); continue; } + // In auto-detect mode, skip cloud-specific strategies whose required cloud does not match + // the detected host cloud. When authType is explicitly set, cloud filtering is bypassed so + // that users can request any strategy regardless of detected cloud (e.g. "azure-cli" on GCP). + if (!explicitAuthType) { + Cloud requiredCloud = CLOUD_REQUIREMENTS.get(provider.authType()); + if (requiredCloud != null + && config.getDatabricksEnvironment().getCloud() != requiredCloud) { + LOG.debug( + "Skipping \"{}\" auth: not configured for {}", provider.authType(), requiredCloud); + continue; + } + } try { LOG.info("Trying {} auth", provider.authType()); HeaderFactory headerFactory = provider.configure(config); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/GoogleCredentialsCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/GoogleCredentialsCredentialsProvider.java index 463d2bab9..fc259ccef 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/GoogleCredentialsCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/GoogleCredentialsCredentialsProvider.java @@ -30,7 +30,7 @@ public String authType() { public HeaderFactory configure(DatabricksConfig config) { String host = config.getHost(); String googleCredentials = config.getGoogleCredentials(); - if (host == null || googleCredentials == null || !config.isGcp()) { + if (host == null || googleCredentials == null) { return null; } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/GoogleIdCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/GoogleIdCredentialsProvider.java index 376d691c5..fb39d0809 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/GoogleIdCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/GoogleIdCredentialsProvider.java @@ -27,7 +27,7 @@ public String authType() { public HeaderFactory configure(DatabricksConfig config) { String host = config.getHost(); String googleServiceAccount = config.getGoogleServiceAccount(); - if (host == null || googleServiceAccount == null || !config.isGcp()) { + if (host == null || googleServiceAccount == null) { return null; } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureGithubOidcCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureGithubOidcCredentialsProvider.java index 6a5a8c278..37ac37e44 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureGithubOidcCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureGithubOidcCredentialsProvider.java @@ -26,8 +26,7 @@ public String authType() { @Override public OAuthHeaderFactory configure(DatabricksConfig config) { - if (!config.isAzure() - || config.getAzureClientId() == null + if (config.getAzureClientId() == null || config.getAzureTenantId() == null || config.getHost() == null) { return null; diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java index eca52808d..f932524ea 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java @@ -27,9 +27,7 @@ public String authType() { @Override public OAuthHeaderFactory configure(DatabricksConfig config) { - if (!config.isAzure() - || config.getAzureClientId() == null - || config.getAzureClientSecret() == null) { + if (config.getAzureClientId() == null || config.getAzureClientSecret() == null) { return null; } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DefaultCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DefaultCredentialsProviderTest.java new file mode 100644 index 000000000..4a1a26bcc --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DefaultCredentialsProviderTest.java @@ -0,0 +1,72 @@ +package com.databricks.sdk.core; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.Collections; +import org.junit.jupiter.api.Test; + +class DefaultCredentialsProviderTest { + + /** A credentials provider that records whether configure() was called. */ + private static class RecordingCredentialsProvider implements CredentialsProvider { + private final String name; + private boolean called = false; + + RecordingCredentialsProvider(String name) { + this.name = name; + } + + @Override + public String authType() { + return name; + } + + @Override + public HeaderFactory configure(DatabricksConfig config) { + called = true; + return null; + } + + public boolean wasCalled() { + return called; + } + } + + /** + * In auto-detect mode (authType not set), azure-cli must be skipped on a GCP host because the + * detected cloud does not match the strategy's required cloud (AZURE). + */ + @Test + void testCloudFiltering_SkipsOnCloudMismatch() { + RecordingCredentialsProvider azureCli = new RecordingCredentialsProvider("azure-cli"); + DefaultCredentialsProvider provider = + new DefaultCredentialsProvider(Collections.singletonList(azureCli)); + + DatabricksConfig config = new DatabricksConfig().setHost("https://xyz.gcp.databricks.com/"); + // configure() throws because no provider succeeds; that's expected + assertThrows(DatabricksException.class, () -> provider.configure(config)); + + assertFalse( + azureCli.wasCalled(), "azure-cli should be skipped on a GCP host in auto-detect mode"); + } + + /** + * When authType is explicitly set, cloud filtering is bypassed so that the named strategy is + * always attempted regardless of the detected host cloud. + */ + @Test + void testCloudFiltering_BypassesOnExplicitAuthType() { + RecordingCredentialsProvider azureCli = new RecordingCredentialsProvider("azure-cli"); + DefaultCredentialsProvider provider = + new DefaultCredentialsProvider(Collections.singletonList(azureCli)); + + DatabricksConfig config = + new DatabricksConfig().setHost("https://xyz.gcp.databricks.com/").setAuthType("azure-cli"); + // configure() throws because azure-cli returns null; that's expected + assertThrows(DatabricksException.class, () -> provider.configure(config)); + + assertTrue( + azureCli.wasCalled(), + "azure-cli should be called when authType is explicitly set, even on a GCP host"); + } +}