Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
### Documentation

### Internal Changes
* Move cloud-based credential filtering from individual credential providers into `DefaultCredentialsProvider`. Azure strategies are skipped on GCP/AWS hosts in auto-detect mode; GCP strategies are skipped on Azure/AWS hosts. When `authType` is explicitly set, cloud filtering is bypassed so the named strategy is always attempted regardless of host cloud.

### API Changes
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,6 @@ private Optional<String> getSubscription(DatabricksConfig config) {

@Override
public OAuthHeaderFactory configure(DatabricksConfig config) {
if (!config.isAzure()) {
return null;
}

try {
AzureUtils.ensureHostPresent(config, mapper, this::tokenSourceFor);
String resource = config.getEffectiveAzureLoginAppId();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package com.databricks.sdk.core;

import com.databricks.sdk.core.oauth.*;
import com.databricks.sdk.core.utils.Cloud;
import com.databricks.sdk.support.InternalApi;
import com.google.common.base.Strings;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -18,6 +22,21 @@
public class DefaultCredentialsProvider implements CredentialsProvider {
private static final Logger LOG = LoggerFactory.getLogger(DefaultCredentialsProvider.class);

// cloudRequirements declares the cloud each strategy requires. DefaultCredentialsProvider uses
// this to skip cloud-specific strategies in auto-detect mode when the host cloud does not match.
// Cloud filtering is bypassed when authType is explicitly set.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this something we already document in our SDK docs?

private static final Map<String, Cloud> CLOUD_REQUIREMENTS;

static {
Map<String, Cloud> m = new HashMap<>();
m.put("github-oidc-azure", Cloud.AZURE);
m.put("azure-client-secret", Cloud.AZURE);
m.put("azure-cli", Cloud.AZURE);
m.put("google-credentials", Cloud.GCP);
m.put("google-id", Cloud.GCP);
CLOUD_REQUIREMENTS = Collections.unmodifiableMap(m);
}

/* List of credential providers that will be tried in sequence */
private List<CredentialsProvider> providers = new ArrayList<>();

Expand All @@ -40,6 +59,11 @@ public NamedIDTokenSource(String name, IDTokenSource idTokenSource) {

public DefaultCredentialsProvider() {}

/** For testing: creates a provider with a fixed list of credential providers. */
DefaultCredentialsProvider(List<CredentialsProvider> providers) {
this.providers = new ArrayList<>(providers);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is only used in testing, is it possible to move it to tests rather than define here?


/**
* Returns the current authentication type being used
*
Expand All @@ -60,14 +84,25 @@ public String authType() {
@Override
public synchronized HeaderFactory configure(DatabricksConfig config) {
addDefaultCredentialsProviders(config);
boolean explicitAuthType = config.getAuthType() != null && !config.getAuthType().isEmpty();
for (CredentialsProvider provider : providers) {
if (config.getAuthType() != null
&& !config.getAuthType().isEmpty()
&& !provider.authType().equals(config.getAuthType())) {
if (explicitAuthType && !provider.authType().equals(config.getAuthType())) {
LOG.info(
"Ignoring {} auth, because {} is preferred", provider.authType(), config.getAuthType());
continue;
}
// In auto-detect mode, skip cloud-specific strategies whose required cloud does not match
// the detected host cloud. When authType is explicitly set, cloud filtering is bypassed so
// that users can request any strategy regardless of detected cloud (e.g. "azure-cli" on GCP).
if (!explicitAuthType) {
Cloud requiredCloud = CLOUD_REQUIREMENTS.get(provider.authType());
if (requiredCloud != null
&& config.getDatabricksEnvironment().getCloud() != requiredCloud) {
LOG.debug(
"Skipping \"{}\" auth: not configured for {}", provider.authType(), requiredCloud);
continue;
}
}
try {
LOG.info("Trying {} auth", provider.authType());
HeaderFactory headerFactory = provider.configure(config);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public String authType() {
public HeaderFactory configure(DatabricksConfig config) {
String host = config.getHost();
String googleCredentials = config.getGoogleCredentials();
if (host == null || googleCredentials == null || !config.isGcp()) {
if (host == null || googleCredentials == null) {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public String authType() {
public HeaderFactory configure(DatabricksConfig config) {
String host = config.getHost();
String googleServiceAccount = config.getGoogleServiceAccount();
if (host == null || googleServiceAccount == null || !config.isGcp()) {
if (host == null || googleServiceAccount == null) {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ public String authType() {

@Override
public OAuthHeaderFactory configure(DatabricksConfig config) {
if (!config.isAzure()
|| config.getAzureClientId() == null
if (config.getAzureClientId() == null
|| config.getAzureTenantId() == null
|| config.getHost() == null) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ public String authType() {

@Override
public OAuthHeaderFactory configure(DatabricksConfig config) {
if (!config.isAzure()
|| config.getAzureClientId() == null
|| config.getAzureClientSecret() == null) {
if (config.getAzureClientId() == null || config.getAzureClientSecret() == null) {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package com.databricks.sdk.core;

import static org.junit.jupiter.api.Assertions.*;

import java.util.Collections;
import org.junit.jupiter.api.Test;

class DefaultCredentialsProviderTest {

/** A credentials provider that records whether configure() was called. */
private static class RecordingCredentialsProvider implements CredentialsProvider {
private final String name;
private boolean called = false;

RecordingCredentialsProvider(String name) {
this.name = name;
}

@Override
public String authType() {
return name;
}

@Override
public HeaderFactory configure(DatabricksConfig config) {
called = true;
return null;
}

public boolean wasCalled() {
return called;
}
}

/**
* In auto-detect mode (authType not set), azure-cli must be skipped on a GCP host because the
* detected cloud does not match the strategy's required cloud (AZURE).
*/
@Test
void testCloudFiltering_SkipsOnCloudMismatch() {
RecordingCredentialsProvider azureCli = new RecordingCredentialsProvider("azure-cli");
DefaultCredentialsProvider provider =
new DefaultCredentialsProvider(Collections.singletonList(azureCli));

DatabricksConfig config = new DatabricksConfig().setHost("https://xyz.gcp.databricks.com/");
// configure() throws because no provider succeeds; that's expected
assertThrows(DatabricksException.class, () -> provider.configure(config));

assertFalse(
azureCli.wasCalled(), "azure-cli should be skipped on a GCP host in auto-detect mode");
}

/**
* When authType is explicitly set, cloud filtering is bypassed so that the named strategy is
* always attempted regardless of the detected host cloud.
*/
@Test
void testCloudFiltering_BypassesOnExplicitAuthType() {
RecordingCredentialsProvider azureCli = new RecordingCredentialsProvider("azure-cli");
DefaultCredentialsProvider provider =
new DefaultCredentialsProvider(Collections.singletonList(azureCli));

DatabricksConfig config =
new DatabricksConfig().setHost("https://xyz.gcp.databricks.com/").setAuthType("azure-cli");
// configure() throws because azure-cli returns null; that's expected
assertThrows(DatabricksException.class, () -> provider.configure(config));

assertTrue(
azureCli.wasCalled(),
"azure-cli should be called when authType is explicitly set, even on a GCP host");
}
}
Loading