diff --git a/cmd/auth/token.go b/cmd/auth/token.go index 106c6b9f0b9..df588dd7976 100644 --- a/cmd/auth/token.go +++ b/cmd/auth/token.go @@ -5,12 +5,16 @@ import ( "encoding/json" "errors" "fmt" + "strings" "time" "github.com/databricks/cli/libs/auth" + "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/databrickscfg/profile" + "github.com/databricks/databricks-sdk-go/config" "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" + "github.com/manifoldco/promptui" "github.com/spf13/cobra" "golang.org/x/oauth2" ) @@ -22,7 +26,7 @@ func helpfulError(ctx context.Context, profile string, persistentAuth u2m.OAuthA func newTokenCommand(authArguments *auth.AuthArguments) *cobra.Command { cmd := &cobra.Command{ - Use: "token [HOST]", + Use: "token [HOST_OR_PROFILE]", Short: "Get authentication token", Long: `Get authentication token from the local cache in ~/.databricks/token-cache.json. Refresh the access token if it is expired. Note: This command only works with @@ -93,6 +97,19 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { return nil, errors.New("providing both a profile and host is not supported") } + // If no --profile flag, try resolving the positional arg as a profile name. + // If it matches, use it. If not, fall through to host treatment. + if args.profileName == "" && len(args.args) == 1 { + candidateProfile, err := loadProfileByName(ctx, args.args[0], args.profiler) + if err != nil { + return nil, err + } + if candidateProfile != nil { + args.profileName = args.args[0] + args.args = nil + } + } + existingProfile, err := loadProfileByName(ctx, args.profileName, args.profiler) if err != nil { return nil, err @@ -113,6 +130,47 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { return nil, err } + // When no profile was specified, check if multiple profiles match the + // effective cache key for this host. + if args.profileName == "" && args.authArguments.Host != "" { + cfg := &config.Config{ + Host: args.authArguments.Host, + AccountID: args.authArguments.AccountID, + Experimental_IsUnifiedHost: args.authArguments.IsUnifiedHost, + } + // Canonicalize first so HostType() can correctly identify account hosts + // even when the host string lacks a scheme (e.g. "accounts.cloud.databricks.com"). + cfg.CanonicalHostName() + var matchFn profile.ProfileMatchFunction + switch cfg.HostType() { + case config.AccountHost, config.UnifiedHost: + matchFn = profile.WithHostAndAccountID(args.authArguments.Host, args.authArguments.AccountID) + default: + matchFn = profile.WithHost(args.authArguments.Host) + } + + matchingProfiles, err := args.profiler.LoadProfiles(ctx, matchFn) + if err != nil && !errors.Is(err, profile.ErrNoConfiguration) { + return nil, err + } + if len(matchingProfiles) > 1 { + configPath, _ := args.profiler.GetPath(ctx) + if configPath == "" { + panic("configPath is empty but LoadProfiles returned multiple profiles") + } + if !cmdio.IsPromptSupported(ctx) { + names := strings.Join(matchingProfiles.Names(), " and ") + return nil, fmt.Errorf("%s match %s in %s. Use --profile to specify which profile to use", + names, args.authArguments.Host, configPath) + } + selected, err := askForMatchingProfile(ctx, matchingProfiles, args.authArguments.Host) + if err != nil { + return nil, err + } + args.profileName = selected + } + } + args.authArguments.Profile = args.profileName ctx, cancel := context.WithTimeout(ctx, args.tokenTimeout) @@ -149,3 +207,22 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { } return t, nil } + +func askForMatchingProfile(ctx context.Context, profiles profile.Profiles, host string) (string, error) { + i, _, err := cmdio.RunSelect(ctx, &promptui.Select{ + Label: "Multiple profiles match " + host, + Items: profiles, + Searcher: profiles.SearchCaseInsensitive, + StartInSearchMode: true, + Templates: &promptui.SelectTemplates{ + Label: "{{ . | faint }}", + Active: `{{.Name | bold}} ({{.Host|faint}})`, + Inactive: `{{.Name}}`, + Selected: `{{ "Using profile" | faint }}: {{ .Name | bold }}`, + }, + }) + if err != nil { + return "", err + } + return profiles[i].Name, nil +} diff --git a/cmd/auth/token_test.go b/cmd/auth/token_test.go index 01c6e83ea3d..19acc989152 100644 --- a/cmd/auth/token_test.go +++ b/cmd/auth/token_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/databricks/cli/libs/auth" + "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" @@ -89,6 +90,32 @@ func TestToken_loadToken(t *testing.T) { Host: "https://accounts.cloud.databricks.com", AccountID: "active", }, + { + Name: "workspace-a", + Host: "https://workspace-a.cloud.databricks.com", + }, + { + Name: "dup1", + Host: "https://shared.cloud.databricks.com", + }, + { + Name: "dup2", + Host: "https://shared.cloud.databricks.com", + }, + { + Name: "acct-dup1", + Host: "https://accounts.cloud.databricks.com", + AccountID: "same-account", + }, + { + Name: "acct-dup2", + Host: "https://accounts.cloud.databricks.com", + AccountID: "same-account", + }, + { + Name: "default.dev", + Host: "https://dev.cloud.databricks.com", + }, }, } tokenCache := &inMemoryTokenCache{ @@ -107,6 +134,18 @@ func TestToken_loadToken(t *testing.T) { RefreshToken: "active", Expiry: time.Now().Add(1 * time.Hour), }, + "workspace-a": { + RefreshToken: "workspace-a", + Expiry: time.Now().Add(1 * time.Hour), + }, + "https://workspace-a.cloud.databricks.com": { + RefreshToken: "workspace-a", + Expiry: time.Now().Add(1 * time.Hour), + }, + "default.dev": { + RefreshToken: "default.dev", + Expiry: time.Now().Add(1 * time.Hour), + }, }, } validateToken := func(resp *oauth2.Token) { @@ -116,6 +155,7 @@ func TestToken_loadToken(t *testing.T) { cases := []struct { name string + ctx context.Context args loadTokenArgs validateToken func(*oauth2.Token) wantErr string @@ -223,10 +263,169 @@ func TestToken_loadToken(t *testing.T) { }, validateToken: validateToken, }, + { + name: "positional arg resolved as profile name", + args: loadTokenArgs{ + authArguments: &auth.AuthArguments{}, + profileName: "", + args: []string{"workspace-a"}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}), + }, + }, + validateToken: validateToken, + }, + { + name: "positional arg with dot treated as host when no profile matches", + args: loadTokenArgs{ + authArguments: &auth.AuthArguments{}, + profileName: "", + args: []string{"workspace-a.cloud.databricks.com"}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}), + }, + }, + validateToken: validateToken, + }, + { + name: "dotted profile name resolved as profile not host", + args: loadTokenArgs{ + authArguments: &auth.AuthArguments{}, + profileName: "", + args: []string{"default.dev"}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}), + }, + }, + validateToken: validateToken, + }, + { + name: "positional arg not a profile falls through to host", + args: loadTokenArgs{ + authArguments: &auth.AuthArguments{}, + profileName: "", + args: []string{"nonexistent"}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + }, + }, + wantErr: "cache: databricks OAuth is not configured for this host. " + + "Try logging in again with `databricks auth login --host https://nonexistent` before retrying. " + + "If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", + }, + { + name: "scheme-less account host ambiguity detected correctly", + ctx: cmdio.MockDiscard(context.Background()), + args: loadTokenArgs{ + authArguments: &auth.AuthArguments{ + Host: "accounts.cloud.databricks.com", + AccountID: "same-account", + }, + profileName: "", + args: []string{}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + }, + }, + wantErr: "acct-dup1 and acct-dup2 match accounts.cloud.databricks.com in . Use --profile to specify which profile to use", + }, + { + name: "workspace host ambiguity — multiple profiles, non-interactive", + ctx: cmdio.MockDiscard(context.Background()), + args: loadTokenArgs{ + authArguments: &auth.AuthArguments{ + Host: "https://shared.cloud.databricks.com", + }, + profileName: "", + args: []string{}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + }, + }, + wantErr: "dup1 and dup2 match https://shared.cloud.databricks.com in . Use --profile to specify which profile to use", + }, + { + name: "account host — same host, different account IDs — no ambiguity", + args: loadTokenArgs{ + authArguments: &auth.AuthArguments{ + Host: "https://accounts.cloud.databricks.com", + AccountID: "active", + }, + profileName: "", + args: []string{}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}), + }, + }, + validateToken: validateToken, + }, + { + name: "account host — same host AND same account ID — ambiguity", + ctx: cmdio.MockDiscard(context.Background()), + args: loadTokenArgs{ + authArguments: &auth.AuthArguments{ + Host: "https://accounts.cloud.databricks.com", + AccountID: "same-account", + }, + profileName: "", + args: []string{}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + }, + }, + wantErr: "acct-dup1 and acct-dup2 match https://accounts.cloud.databricks.com in . Use --profile to specify which profile to use", + }, + { + name: "profile flag + positional non-host arg still errors", + args: loadTokenArgs{ + authArguments: &auth.AuthArguments{}, + profileName: "active", + args: []string{"workspace-a"}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + }, + }, + wantErr: "providing both a profile and host is not supported", + }, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - got, err := loadToken(context.Background(), c.args) + ctx := c.ctx + if ctx == nil { + ctx = context.Background() + } + got, err := loadToken(ctx, c.args) if c.wantErr != "" { assert.Equal(t, c.wantErr, err.Error()) } else { diff --git a/libs/databrickscfg/profile/profiler.go b/libs/databrickscfg/profile/profiler.go index 5d1ea0e72f0..53ff7b305d2 100644 --- a/libs/databrickscfg/profile/profiler.go +++ b/libs/databrickscfg/profile/profiler.go @@ -2,6 +2,8 @@ package profile import ( "context" + + "github.com/databricks/databricks-sdk-go/config" ) type ProfileMatchFunction func(Profile) bool @@ -30,6 +32,29 @@ func WithName(name string) ProfileMatchFunction { } } +// WithHost returns a ProfileMatchFunction that matches profiles whose +// canonical host equals the given host. +func WithHost(host string) ProfileMatchFunction { + target := canonicalizeHost(host) + return func(p Profile) bool { + return p.Host != "" && canonicalizeHost(p.Host) == target + } +} + +// WithHostAndAccountID returns a ProfileMatchFunction that matches profiles +// by both canonical host and account ID. +func WithHostAndAccountID(host, accountID string) ProfileMatchFunction { + target := canonicalizeHost(host) + return func(p Profile) bool { + return p.Host != "" && canonicalizeHost(p.Host) == target && p.AccountID == accountID + } +} + +// canonicalizeHost normalizes a host using the SDK's canonical host logic. +func canonicalizeHost(host string) string { + return (&config.Config{Host: host}).CanonicalHostName() +} + type Profiler interface { LoadProfiles(context.Context, ProfileMatchFunction) (Profiles, error) GetPath(context.Context) (string, error) diff --git a/libs/databrickscfg/profile/profiler_test.go b/libs/databrickscfg/profile/profiler_test.go new file mode 100644 index 00000000000..75f4fa57d5b --- /dev/null +++ b/libs/databrickscfg/profile/profiler_test.go @@ -0,0 +1,111 @@ +package profile + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWithHost(t *testing.T) { + cases := []struct { + name string + inputHost string + profileHost string + want bool + }{ + { + name: "exact match with scheme", + inputHost: "https://myworkspace.cloud.databricks.com", + profileHost: "https://myworkspace.cloud.databricks.com", + want: true, + }, + { + name: "match without scheme on input", + inputHost: "myworkspace.cloud.databricks.com", + profileHost: "https://myworkspace.cloud.databricks.com", + want: true, + }, + { + name: "match stripping trailing slash", + inputHost: "https://myworkspace.cloud.databricks.com/", + profileHost: "https://myworkspace.cloud.databricks.com", + want: true, + }, + { + name: "match stripping path", + inputHost: "https://myworkspace.cloud.databricks.com/some/path?query=1", + profileHost: "https://myworkspace.cloud.databricks.com", + want: true, + }, + { + name: "no match different host", + inputHost: "https://other.cloud.databricks.com", + profileHost: "https://myworkspace.cloud.databricks.com", + want: false, + }, + { + name: "empty host on profile skipped", + inputHost: "https://myworkspace.cloud.databricks.com", + profileHost: "", + want: false, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + p := Profile{Host: c.profileHost} + fn := WithHost(c.inputHost) + assert.Equal(t, c.want, fn(p)) + }) + } +} + +func TestWithHostAndAccountID(t *testing.T) { + cases := []struct { + name string + inputHost string + inputAccountID string + profileHost string + profileAccountID string + want bool + }{ + { + name: "same host same account ID", + inputHost: "https://accounts.cloud.databricks.com", + inputAccountID: "abc123", + profileHost: "https://accounts.cloud.databricks.com", + profileAccountID: "abc123", + want: true, + }, + { + name: "same host different account ID", + inputHost: "https://accounts.cloud.databricks.com", + inputAccountID: "abc123", + profileHost: "https://accounts.cloud.databricks.com", + profileAccountID: "xyz789", + want: false, + }, + { + name: "different host same account ID", + inputHost: "https://other.cloud.databricks.com", + inputAccountID: "abc123", + profileHost: "https://accounts.cloud.databricks.com", + profileAccountID: "abc123", + want: false, + }, + { + name: "empty host on profile skipped", + inputHost: "https://accounts.cloud.databricks.com", + inputAccountID: "abc123", + profileHost: "", + profileAccountID: "abc123", + want: false, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + p := Profile{Host: c.profileHost, AccountID: c.profileAccountID} + fn := WithHostAndAccountID(c.inputHost, c.inputAccountID) + assert.Equal(t, c.want, fn(p)) + }) + } +}