Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions cmd/root/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,11 @@ func (f *modelsListFlags) runModelsListCommand(cmd *cobra.Command, args []string
availableProviders[p] = true
}

// Determine which model auto-selection would pick.
autoModel := config.AutoModelConfig(ctx, f.runConfig.ModelsGateway, env, f.runConfig.DefaultModel)
// Determine which model auto-selection would pick. DMR discovery is left
// out here (nil lister) so listing models stays a pure, side-effect-free
// operation; the default marker therefore reflects the static per-provider
// default rather than a locally-pulled DMR model.
autoModel := config.AutoModelConfig(ctx, f.runConfig.ModelsGateway, env, f.runConfig.DefaultModel, nil)

rows := f.collectModels(ctx, availableProviders, autoModel)

Expand Down
2 changes: 1 addition & 1 deletion cmd/root/models_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func TestModelsListCommand_DefaultMarker(t *testing.T) {

// The auto-selected model should be marked as default
rc := config.RuntimeConfig{}
autoModel := config.AutoModelConfig(t.Context(), "", rc.EnvProvider(), nil)
autoModel := config.AutoModelConfig(t.Context(), "", rc.EnvProvider(), nil, nil)
for _, r := range rows {
if r.Provider == autoModel.Provider && r.Model == autoModel.Model {
assert.True(t, r.Default, "auto-selected model %s/%s should be marked as default", r.Provider, r.Model)
Expand Down
118 changes: 110 additions & 8 deletions pkg/config/auto.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,21 @@ package config
import (
"context"
"fmt"
"log/slog"
"slices"
"strings"

"github.com/docker/docker-agent/pkg/config/latest"
"github.com/docker/docker-agent/pkg/environment"
)

// DMRModelLister returns the IDs of the models currently available to Docker
// Model Runner (i.e. pulled locally). It is injected so DMR discovery can be
// stubbed in tests and disabled by callers that must stay side-effect-free:
// `docker agent models` passes nil to avoid shelling out to `docker model`,
// while the agent run path passes dmr.ListModels.
//
// A nil DMRModelLister is valid and means "skip discovery": pickDMRAutoModel
// then returns the static default model unchanged. A non-nil error or an empty
// result is handled the same way. pickDMRAutoModel takes the first suitable
// entry of the returned slice, so implementations should return the IDs in a
// stable (sorted) order to keep auto-selection deterministic.
type DMRModelLister func(ctx context.Context) ([]string, error)

// providerConfig defines a cloud provider and how to detect/describe its API keys.
type providerConfig struct {
name string // provider name (e.g., "anthropic")
Expand Down Expand Up @@ -37,21 +46,38 @@ var cloudProviders = []providerConfig{
}

// AutoModelFallbackError is returned when auto model selection fails because
// no providers are available (no API keys configured and DMR not installed).
type AutoModelFallbackError struct{}
// no model could be initialized (no API keys configured and no usable Docker
// Model Runner model, e.g. DMR not installed or the pull was declined).
type AutoModelFallbackError struct {
// Cause is the underlying provider-initialization error, when available
// (for example "model pull declined by user"). It is surfaced in the
// message so the user understands why selection fell through, and exposed
// via Unwrap for errors.Is/As callers.
Cause error
}

func (e *AutoModelFallbackError) Error() string {
var hints []string
for _, p := range cloudProviders {
hints = append(hints, fmt.Sprintf(" - %s: %s", p.name, p.hint))
}

return "No model providers available.\n\nTo fix this, you can:\n" +
" - Install Docker Model Runner: https://docs.docker.com/ai/model-runner/get-started/\n" +
" - Configure an API key for a cloud provider:\n" +
strings.Join(hints, "\n")
var b strings.Builder
if e.Cause != nil {
fmt.Fprintf(&b, "Could not initialize the auto-selected model: %v\n\n", e.Cause)
}
b.WriteString("No model is currently available.\n\nTo fix this, you can:\n")
b.WriteString(" - Pull a Docker Model Runner model, e.g. `docker model pull ai/qwen3`\n")
b.WriteString(" - Install Docker Model Runner: https://docs.docker.com/ai/model-runner/get-started/\n")
b.WriteString(" - Configure an API key for a cloud provider:\n")
b.WriteString(strings.Join(hints, "\n"))
return b.String()
}

// Unwrap returns the wrapped Cause (nil when none was recorded), which lets
// errors.Is and errors.As see through an AutoModelFallbackError to the
// underlying initialization failure.
func (e *AutoModelFallbackError) Unwrap() error {
	return e.Cause
}

var DefaultModels = map[string]string{
"openai": "gpt-5",
"anthropic": "claude-sonnet-4-6",
Expand Down Expand Up @@ -84,7 +110,7 @@ func AvailableProviders(ctx context.Context, modelsGateway string, env environme
return providers
}

func AutoModelConfig(ctx context.Context, modelsGateway string, env environment.Provider, defaultModel *latest.ModelConfig) latest.ModelConfig {
func AutoModelConfig(ctx context.Context, modelsGateway string, env environment.Provider, defaultModel *latest.ModelConfig, dmrLister DMRModelLister) latest.ModelConfig {
// If user specified a default model config, use it (with defaults for unset fields)
if defaultModel != nil && defaultModel.Provider != "" && defaultModel.Model != "" {
result := *defaultModel
Expand All @@ -97,13 +123,89 @@ func AutoModelConfig(ctx context.Context, modelsGateway string, env environment.
availableProviders := AvailableProviders(ctx, modelsGateway, env)
firstAvailable := availableProviders[0]

model := DefaultModels[firstAvailable]
if firstAvailable == "dmr" {
// Prefer a model the user already pulled so that, when DMR is set up
// with models other than ai/qwen3:latest, auto-selection doesn't force
// a pull prompt and then fail when it's declined.
model = pickDMRAutoModel(ctx, model, dmrLister)
}

return latest.ModelConfig{
Provider: firstAvailable,
Model: DefaultModels[firstAvailable],
Model: model,
MaxTokens: PreferredMaxTokens(firstAvailable),
}
}

// pickDMRAutoModel chooses which Docker Model Runner model auto-selection
// should use. It prefers the configured default when it is already pulled
// locally; otherwise it falls back to the first locally-available
// (non-embedding) model. When discovery fails, finds nothing, or no lister is
// provided, it returns defaultModel unchanged, preserving the previous
// behavior of pulling the default on demand.
func pickDMRAutoModel(ctx context.Context, defaultModel string, lister DMRModelLister) string {
if lister == nil {
return defaultModel
}

installed, err := lister(ctx)
if err != nil {
slog.DebugContext(ctx, "DMR model discovery failed during auto-selection, using default", "error", err, "default", defaultModel)
return defaultModel
}
if len(installed) == 0 {
return defaultModel
}

// The default is already pulled: use it so behavior is unchanged for users
// who do have ai/qwen3:latest.
if slices.Contains(installed, defaultModel) {
return defaultModel
}

// The default model pulled under a different tag (e.g. ai/qwen3:Q4_K_M)
// still satisfies "prefer the default", so match on the repository.
defaultRepo := dmrModelRepo(defaultModel)
for _, m := range installed {
if dmrModelRepo(m) == defaultRepo {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[LOW] Repo-prefix fallback in pickDMRAutoModel skips the embedding-model filter

The general fallback loop (used when no exact tag match is found) correctly calls looksLikeEmbeddingModel:

for _, m := range installed {
    if !looksLikeEmbeddingModel(m) {
        return m   // safe: chat models only
    }
}

But the repo-prefix loop that runs first does not apply the same guard:

defaultRepo := dmrModelRepo(defaultModel)
for _, m := range installed {
    if dmrModelRepo(m) == defaultRepo {
        return m   // no embedding check here
    }
}

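If an embedding-only variant of the default model is installed under the same repository but a different tag (e.g. a hypothetical ai/qwen3:embed-Q4_K_M), auto-selection could return it as the chat model. Note that the loop compares the exact repository returned by dmrModelRepo rather than a prefix, so a separately named repository such as ai/qwen3-embed:latest does not match ai/qwen3 and is not affected. In the affected case the agent would start and then fail when it tries to use a text-embedding model for conversation.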

The fix is a one-liner:

for _, m := range installed {
    if dmrModelRepo(m) == defaultRepo && !looksLikeEmbeddingModel(m) {
        return m
    }
}

slog.DebugContext(ctx, "DMR auto-selection using default model under a non-default tag", "model", m, "default", defaultModel)
return m
}
}

// installed is sorted by the lister; pick the first chat-capable model so
// the choice is deterministic and never lands on an embedding model.
for _, m := range installed {
if !looksLikeEmbeddingModel(m) {
slog.DebugContext(ctx, "DMR auto-selection using locally-available model", "model", m, "default_not_installed", defaultModel)
return m
}
}

return defaultModel
}

// dmrModelRepo strips the ":<tag>" suffix from a DMR model ID and returns the
// remaining repository part, so "ai/qwen3:latest" and "ai/qwen3:Q4_K_M" both
// map to "ai/qwen3". Only the last colon is considered, and it counts as a tag
// separator only when no slash follows it; that keeps a registry host:port
// prefix such as "registry:5000/ai/x" intact. An ID without a colon is
// returned unchanged.
func dmrModelRepo(id string) string {
	colon := strings.LastIndexByte(id, ':')
	if colon < 0 {
		return id
	}
	if tag := id[colon+1:]; strings.ContainsRune(tag, '/') {
		// The colon belongs to a host:port prefix, not to a tag.
		return id
	}
	return id[:colon]
}

// looksLikeEmbeddingModel is a name-based heuristic: it reports true when the
// DMR model ID contains "embed" in any letter case (e.g. "ai/embeddinggemma").
// Such models should never be chosen as an agent's chat model. The model
// picker layer applies a richer models.dev-backed check for display purposes.
func looksLikeEmbeddingModel(modelID string) bool {
	normalized := strings.ToLower(modelID)
	return strings.Contains(normalized, "embed")
}

func PreferredMaxTokens(provider string) *int64 {
var mt int64 = 32000
if provider == "dmr" {
Expand Down
125 changes: 120 additions & 5 deletions pkg/config/auto_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package config

import (
"context"
"errors"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -219,7 +221,7 @@ func TestAutoModelConfig(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

modelConfig := AutoModelConfig(t.Context(), tt.gateway, environment.NewMapEnvProvider(tt.envVars), nil)
modelConfig := AutoModelConfig(t.Context(), tt.gateway, environment.NewMapEnvProvider(tt.envVars), nil, nil)

assert.Equal(t, tt.expectedProvider, modelConfig.Provider)
assert.Equal(t, tt.expectedModel, modelConfig.Model)
Expand Down Expand Up @@ -319,7 +321,7 @@ func TestAutoModelConfig_IntegrationWithDefaultModels(t *testing.T) {
envVars["MISTRAL_API_KEY"] = "test-key"
}

modelConfig := AutoModelConfig(t.Context(), "", environment.NewMapEnvProvider(envVars), nil)
modelConfig := AutoModelConfig(t.Context(), "", environment.NewMapEnvProvider(envVars), nil, nil)

// Verify the returned model matches the DefaultModels entry
expectedModel := DefaultModels[provider]
Expand All @@ -332,7 +334,7 @@ func TestAutoModelConfig_IntegrationWithDefaultModels(t *testing.T) {
t.Run("dmr", func(t *testing.T) {
t.Parallel()

modelConfig := AutoModelConfig(t.Context(), "", environment.NewNoEnvProvider(), nil)
modelConfig := AutoModelConfig(t.Context(), "", environment.NewNoEnvProvider(), nil, nil)

assert.Equal(t, "dmr", modelConfig.Provider)
assert.Equal(t, DefaultModels["dmr"], modelConfig.Model)
Expand Down Expand Up @@ -448,7 +450,7 @@ func TestAutoModelConfig_UserDefaultModel(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

modelConfig := AutoModelConfig(t.Context(), "", environment.NewMapEnvProvider(tt.envVars), tt.defaultModel)
modelConfig := AutoModelConfig(t.Context(), "", environment.NewMapEnvProvider(tt.envVars), tt.defaultModel, nil)

assert.Equal(t, tt.expectedProvider, modelConfig.Provider)
assert.Equal(t, tt.expectedModel, modelConfig.Model)
Expand All @@ -471,11 +473,124 @@ func TestAutoModelConfig_UserDefaultModelWithOptions(t *testing.T) {
ThinkingBudget: thinkingBudget,
}

modelConfig := AutoModelConfig(t.Context(), "", environment.NewNoEnvProvider(), defaultModel)
modelConfig := AutoModelConfig(t.Context(), "", environment.NewNoEnvProvider(), defaultModel, nil)

assert.Equal(t, "anthropic", modelConfig.Provider)
assert.Equal(t, "claude-sonnet-4-5", modelConfig.Model)
assert.Equal(t, int64(64000), *modelConfig.MaxTokens)
assert.NotNil(t, modelConfig.ThinkingBudget)
assert.Equal(t, 10000, modelConfig.ThinkingBudget.Tokens)
}

// TestAutoModelConfig_DMRLocalModels checks which DMR model auto-selection
// settles on for a range of discovery results, using an empty environment
// provider and no user-supplied default model.
func TestAutoModelConfig_DMRLocalModels(t *testing.T) {
	t.Parallel()

	// stub builds a DMRModelLister that always yields the given result.
	stub := func(ids []string, listErr error) DMRModelLister {
		return func(context.Context) ([]string, error) {
			return ids, listErr
		}
	}

	cases := []struct {
		name   string
		lister DMRModelLister
		want   string
	}{
		{
			name:   "nil lister keeps the static default",
			lister: nil,
			want:   DefaultModels["dmr"],
		},
		{
			name:   "default model already pulled is used",
			lister: stub([]string{"ai/gemma3:latest", "ai/qwen3:latest"}, nil),
			want:   "ai/qwen3:latest",
		},
		{
			name:   "default not pulled falls back to first installed model",
			lister: stub([]string{"ai/llama3.2:latest", "ai/smollm2:latest"}, nil),
			want:   "ai/llama3.2:latest",
		},
		{
			name:   "default model under a different tag is preferred over other models",
			lister: stub([]string{"ai/gemma3:latest", "ai/qwen3:Q4_K_M"}, nil),
			want:   "ai/qwen3:Q4_K_M",
		},
		{
			name:   "embedding-only models are skipped, default retained",
			lister: stub([]string{"ai/embeddinggemma", "ai/nomic-embed-text"}, nil),
			want:   DefaultModels["dmr"],
		},
		{
			name:   "embedding models are skipped when a chat model exists",
			lister: stub([]string{"ai/embeddinggemma", "ai/mistral:latest"}, nil),
			want:   "ai/mistral:latest",
		},
		{
			name:   "discovery error keeps the static default",
			lister: stub(nil, errors.New("dmr unreachable")),
			want:   DefaultModels["dmr"],
		},
		{
			name:   "empty list keeps the static default",
			lister: stub([]string{}, nil),
			want:   DefaultModels["dmr"],
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			got := AutoModelConfig(t.Context(), "", environment.NewNoEnvProvider(), nil, tc.lister)

			assert.Equal(t, "dmr", got.Provider)
			assert.Equal(t, tc.want, got.Model)
			assert.Equal(t, int64(16000), *got.MaxTokens)
		})
	}
}

func TestAutoModelConfig_DMRListerNotConsultedForCloudProvider(t *testing.T) {
t.Parallel()

called := false
lister := func(context.Context) ([]string, error) {
called = true
return []string{"ai/qwen3:latest"}, nil
}

// A cloud provider is available, so the DMR lister must never run.
modelConfig := AutoModelConfig(
t.Context(),
"",
environment.NewMapEnvProvider(map[string]string{"ANTHROPIC_API_KEY": "test-key"}),
nil,
lister,
)

assert.Equal(t, "anthropic", modelConfig.Provider)
assert.False(t, called, "DMR lister should not be consulted when a cloud provider is selected")
}

// TestAutoModelFallbackError covers the user-facing message of
// AutoModelFallbackError with and without a Cause, and that a Cause is
// reachable through errors.Is.
func TestAutoModelFallbackError(t *testing.T) {
	t.Parallel()

	t.Run("without cause", func(t *testing.T) {
		t.Parallel()

		text := (&AutoModelFallbackError{}).Error()

		// The generic guidance is always present...
		for _, fragment := range []string{
			"No model is currently available",
			"docker model pull",
			"ANTHROPIC_API_KEY",
		} {
			assert.Contains(t, text, fragment)
		}
		// ...but the cause preamble is only emitted when a Cause is set.
		assert.NotContains(t, text, "Could not initialize")
	})

	t.Run("with cause is surfaced and unwrappable", func(t *testing.T) {
		t.Parallel()

		underlying := errors.New("model pull declined by user")
		wrapped := &AutoModelFallbackError{Cause: underlying}

		assert.Contains(t, wrapped.Error(), "model pull declined by user")
		assert.ErrorIs(t, wrapped, underlying)
	})
}
6 changes: 4 additions & 2 deletions pkg/config/latest/model_ref.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import (
)

// ParseModelRef parses an inline "provider/model" reference into a
// ModelConfig. It returns an error when the string does not contain
// exactly one "/" separator or when either part is empty.
// ModelConfig. It splits on the first "/", so the model portion may itself
// contain slashes (e.g. "dmr/ai/qwen3:latest" yields provider "dmr" and model
// "ai/qwen3:latest"). It returns an error when there is no "/" or when either
// part is empty.
//
// cfg, err := ParseModelRef("openai/gpt-4o")
// // cfg.Provider == "openai", cfg.Model == "gpt-4o"
Expand Down
Loading
Loading