Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/auth/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ func setHostAndAccountId(ctx context.Context, existingProfile *profile.Profile,
}
}

authArguments.Host = strings.TrimSuffix(authArguments.Host, "/")

// Determine the host type and handle account ID / workspace ID accordingly
cfg := &config.Config{
Host: authArguments.Host,
Expand Down
12 changes: 12 additions & 0 deletions cmd/auth/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,24 @@ func TestSetHost(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "val from --host", authArguments.Host)

// Test setting host from flag with trailing slash is stripped
authArguments.Host = "https://www.host1.com/"
err = setHostAndAccountId(ctx, profile1, &authArguments, []string{})
assert.NoError(t, err)
assert.Equal(t, "https://www.host1.com", authArguments.Host)

// Test setting host from argument
authArguments.Host = ""
err = setHostAndAccountId(ctx, profile1, &authArguments, []string{"val from [HOST]"})
assert.NoError(t, err)
assert.Equal(t, "val from [HOST]", authArguments.Host)

// Test setting host from argument with trailing slash is stripped
authArguments.Host = ""
err = setHostAndAccountId(ctx, profile1, &authArguments, []string{"https://www.host1.com/"})
assert.NoError(t, err)
assert.Equal(t, "https://www.host1.com", authArguments.Host)

// Test setting host from profile
authArguments.Host = ""
err = setHostAndAccountId(ctx, profile1, &authArguments, []string{})
Expand Down
36 changes: 36 additions & 0 deletions cmd/auth/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,22 @@ func TestToken_loadToken(t *testing.T) {
},
validateToken: validateToken,
},
{
name: "host with trailing slash is stripped",
args: loadTokenArgs{
authArguments: &auth.AuthArguments{Host: "https://accounts.cloud.databricks.com/", AccountID: "active"},
profileName: "",
args: []string{},
tokenTimeout: 1 * time.Hour,
profiler: profiler,
persistentAuthOpts: []u2m.PersistentAuthOption{
u2m.WithTokenCache(tokenCache),
u2m.WithOAuthEndpointSupplier(&MockApiClient{}),
u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}),
},
},
validateToken: validateToken,
},
{
name: "positional arg resolved as profile name",
args: loadTokenArgs{
Expand Down Expand Up @@ -596,6 +612,26 @@ func TestToken_loadToken(t *testing.T) {
},
validateToken: validateToken,
},
{
name: "no args, DATABRICKS_HOST env with trailing slash resolves",
setupCtx: func(ctx context.Context) context.Context {
ctx = env.Set(ctx, "DATABRICKS_HOST", "https://workspace-a.cloud.databricks.com/")
return ctx
},
args: loadTokenArgs{
authArguments: &auth.AuthArguments{},
profileName: "",
args: []string{},
tokenTimeout: 1 * time.Hour,
profiler: profiler,
persistentAuthOpts: []u2m.PersistentAuthOption{
u2m.WithTokenCache(tokenCache),
u2m.WithOAuthEndpointSupplier(&MockApiClient{}),
u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}),
},
},
validateToken: validateToken,
},
{
name: "no args, DATABRICKS_CONFIG_PROFILE env resolves",
setupCtx: func(ctx context.Context) context.Context {
Expand Down
21 changes: 12 additions & 9 deletions cmd/configure/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,27 @@ import (
"strings"
)

// normalizeHost normalizes host input to prevent double https:// prefixes.
// If the input already starts with https://, it returns it as-is.
// If the input doesn't start with https://, it prepends https://.
// normalizeHost ensures a https:// scheme is present and returns only scheme
// and host, consistent with the normalizeHost in libs/databrickscfg/host.go.
func normalizeHost(input string) string {
input = strings.TrimSpace(input)
u, err := url.Parse(input)
// If the input is not a valid URL, return it as-is
if err != nil {
return input
}

// If it already starts with https:// or http://, return as-is
if u.Scheme == "https" || u.Scheme == "http" {
if u.Scheme != "https" && u.Scheme != "http" {
u, err = url.Parse("https://" + input)
if err != nil {
return input
}
}

if u.Host == "" {
return input
}

// Otherwise, prepend https://
return "https://" + input
return (&url.URL{Scheme: u.Scheme, Host: u.Host}).String()
}

func validateHost(s string) error {
Expand All @@ -34,7 +37,7 @@ func validateHost(s string) error {
if u.Host == "" || u.Scheme != "https" {
return errors.New("must start with https://")
}
if u.Path != "" && u.Path != "/" {
if u.Path != "" {
return errors.New("must use empty path")
}
return nil
Expand Down
25 changes: 14 additions & 11 deletions cmd/configure/host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,30 @@ func TestNormalizeHost(t *testing.T) {
expected string
}{
// Empty input
{"", "https://"},
{" ", "https://"},
{"", ""},
{" ", ""},

// Already has https://
{"https://example.databricks.com", "https://example.databricks.com"},
{"HTTPS://EXAMPLE.DATABRICKS.COM", "HTTPS://EXAMPLE.DATABRICKS.COM"},
{"https://example.databricks.com/", "https://example.databricks.com/"},
{"https://example.databricks.com/", "https://example.databricks.com"},

// Missing protocol (should add https://)
{"example.databricks.com", "https://example.databricks.com"},
{" example.databricks.com ", "https://example.databricks.com"},
{"subdomain.example.databricks.com", "https://subdomain.example.databricks.com"},
{"example.databricks.com/", "https://example.databricks.com"},

// Paths, query strings, and anchors are stripped
{"https://example.databricks.com/path", "https://example.databricks.com"},
{"https://example.databricks.com/path/", "https://example.databricks.com"},
{"https://example.databricks.com?query", "https://example.databricks.com"},
{"https://example.databricks.com#anchor", "https://example.databricks.com"},

// Edge cases
{"https://", "https://"},
{"example.com", "https://example.com"},
{"https://example.databricks.com/path", "https://example.databricks.com/path"},
{"https://example.databricks.com/path/", "https://example.databricks.com/path/"},
{"http://localhost:8080", "http://localhost:8080"},
{"http://localhost:8080/", "http://localhost:8080"},
}

for _, test := range tests {
Expand All @@ -50,15 +55,13 @@ func TestValidateHost(t *testing.T) {
err = validateHost("http://host")
assert.ErrorContains(t, err, "must start with https://")
err = validateHost("ftp://host")
assert.ErrorContains(t, err, "must start with https://")

// Must use empty path
assert.ErrorContains(t, err, "must start with https://")
err = validateHost("https://host/path")
assert.ErrorContains(t, err, "must use empty path")

// Ignore query params
err = validateHost("https://host/?query")
assert.NoError(t, err)
err = validateHost("https://host/")
// Valid hosts
err = validateHost("https://host")
assert.NoError(t, err)
}