diff --git a/cmd/root/models.go b/cmd/root/models.go index 4c3f028f9..9240d8722 100644 --- a/cmd/root/models.go +++ b/cmd/root/models.go @@ -95,8 +95,11 @@ func (f *modelsListFlags) runModelsListCommand(cmd *cobra.Command, args []string availableProviders[p] = true } - // Determine which model auto-selection would pick. - autoModel := config.AutoModelConfig(ctx, f.runConfig.ModelsGateway, env, f.runConfig.DefaultModel) + // Determine which model auto-selection would pick. DMR discovery is left + // out here (nil lister) so listing models stays a pure, side-effect-free + // operation; the default marker therefore reflects the static per-provider + // default rather than a locally-pulled DMR model. + autoModel := config.AutoModelConfig(ctx, f.runConfig.ModelsGateway, env, f.runConfig.DefaultModel, nil) rows := f.collectModels(ctx, availableProviders, autoModel) diff --git a/cmd/root/models_test.go b/cmd/root/models_test.go index 4b8cae72a..4839b3500 100644 --- a/cmd/root/models_test.go +++ b/cmd/root/models_test.go @@ -128,7 +128,7 @@ func TestModelsListCommand_DefaultMarker(t *testing.T) { // The auto-selected model should be marked as default rc := config.RuntimeConfig{} - autoModel := config.AutoModelConfig(t.Context(), "", rc.EnvProvider(), nil) + autoModel := config.AutoModelConfig(t.Context(), "", rc.EnvProvider(), nil, nil) for _, r := range rows { if r.Provider == autoModel.Provider && r.Model == autoModel.Model { assert.True(t, r.Default, "auto-selected model %s/%s should be marked as default", r.Provider, r.Model) diff --git a/pkg/config/auto.go b/pkg/config/auto.go index c59e4e14c..f324817b5 100644 --- a/pkg/config/auto.go +++ b/pkg/config/auto.go @@ -3,12 +3,21 @@ package config import ( "context" "fmt" + "log/slog" + "slices" "strings" "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/environment" ) +// DMRModelLister returns the IDs of the models currently available to Docker +// Model Runner (i.e. pulled locally). 
It is injected so DMR discovery can be +// stubbed in tests and disabled by callers that must stay side-effect-free: +// `docker agent models` passes nil to avoid shelling out to `docker model`, +// while the agent run path passes dmr.ListModels. +type DMRModelLister func(ctx context.Context) ([]string, error) + // providerConfig defines a cloud provider and how to detect/describe its API keys. type providerConfig struct { name string // provider name (e.g., "anthropic") @@ -37,8 +46,15 @@ var cloudProviders = []providerConfig{ } // AutoModelFallbackError is returned when auto model selection fails because -// no providers are available (no API keys configured and DMR not installed). -type AutoModelFallbackError struct{} +// no model could be initialized (no API keys configured and no usable Docker +// Model Runner model, e.g. DMR not installed or the pull was declined). +type AutoModelFallbackError struct { + // Cause is the underlying provider-initialization error, when available + // (for example "model pull declined by user"). It is surfaced in the + // message so the user understands why selection fell through, and exposed + // via Unwrap for errors.Is/As callers. + Cause error +} func (e *AutoModelFallbackError) Error() string { var hints []string @@ -46,12 +62,22 @@ func (e *AutoModelFallbackError) Error() string { hints = append(hints, fmt.Sprintf(" - %s: %s", p.name, p.hint)) } - return "No model providers available.\n\nTo fix this, you can:\n" + - " - Install Docker Model Runner: https://docs.docker.com/ai/model-runner/get-started/\n" + - " - Configure an API key for a cloud provider:\n" + - strings.Join(hints, "\n") + var b strings.Builder + if e.Cause != nil { + fmt.Fprintf(&b, "Could not initialize the auto-selected model: %v\n\n", e.Cause) + } + b.WriteString("No model is currently available.\n\nTo fix this, you can:\n") + b.WriteString(" - Pull a Docker Model Runner model, e.g. 
`docker model pull ai/qwen3`\n") + b.WriteString(" - Install Docker Model Runner: https://docs.docker.com/ai/model-runner/get-started/\n") + b.WriteString(" - Configure an API key for a cloud provider:\n") + b.WriteString(strings.Join(hints, "\n")) + return b.String() } +// Unwrap exposes the underlying initialization error so callers can inspect it +// with errors.Is/errors.As. +func (e *AutoModelFallbackError) Unwrap() error { return e.Cause } + var DefaultModels = map[string]string{ "openai": "gpt-5", "anthropic": "claude-sonnet-4-6", @@ -84,7 +110,7 @@ func AvailableProviders(ctx context.Context, modelsGateway string, env environme return providers } -func AutoModelConfig(ctx context.Context, modelsGateway string, env environment.Provider, defaultModel *latest.ModelConfig) latest.ModelConfig { +func AutoModelConfig(ctx context.Context, modelsGateway string, env environment.Provider, defaultModel *latest.ModelConfig, dmrLister DMRModelLister) latest.ModelConfig { // If user specified a default model config, use it (with defaults for unset fields) if defaultModel != nil && defaultModel.Provider != "" && defaultModel.Model != "" { result := *defaultModel @@ -97,13 +123,89 @@ func AutoModelConfig(ctx context.Context, modelsGateway string, env environment. availableProviders := AvailableProviders(ctx, modelsGateway, env) firstAvailable := availableProviders[0] + model := DefaultModels[firstAvailable] + if firstAvailable == "dmr" { + // Prefer a model the user already pulled so that, when DMR is set up + // with models other than ai/qwen3:latest, auto-selection doesn't force + // a pull prompt and then fail when it's declined. + model = pickDMRAutoModel(ctx, model, dmrLister) + } + return latest.ModelConfig{ Provider: firstAvailable, - Model: DefaultModels[firstAvailable], + Model: model, MaxTokens: PreferredMaxTokens(firstAvailable), } } +// pickDMRAutoModel chooses which Docker Model Runner model auto-selection +// should use. 
It prefers the configured default when it is already pulled +// locally; otherwise it falls back to the first locally-available +// (non-embedding) model. When discovery fails, finds nothing, or no lister is +// provided, it returns defaultModel unchanged, preserving the previous +// behavior of pulling the default on demand. +func pickDMRAutoModel(ctx context.Context, defaultModel string, lister DMRModelLister) string { + if lister == nil { + return defaultModel + } + + installed, err := lister(ctx) + if err != nil { + slog.DebugContext(ctx, "DMR model discovery failed during auto-selection, using default", "error", err, "default", defaultModel) + return defaultModel + } + if len(installed) == 0 { + return defaultModel + } + + // The default is already pulled: use it so behavior is unchanged for users + // who do have ai/qwen3:latest. + if slices.Contains(installed, defaultModel) { + return defaultModel + } + + // The default model pulled under a different tag (e.g. ai/qwen3:Q4_K_M) + // still satisfies "prefer the default", so match on the repository. + defaultRepo := dmrModelRepo(defaultModel) + for _, m := range installed { + if dmrModelRepo(m) == defaultRepo { + slog.DebugContext(ctx, "DMR auto-selection using default model under a non-default tag", "model", m, "default", defaultModel) + return m + } + } + + // installed is sorted by the lister; pick the first chat-capable model so + // the choice is deterministic and never lands on an embedding model. + for _, m := range installed { + if !looksLikeEmbeddingModel(m) { + slog.DebugContext(ctx, "DMR auto-selection using locally-available model", "model", m, "default_not_installed", defaultModel) + return m + } + } + + return defaultModel +} + +// dmrModelRepo returns the repository portion of a DMR model ID, dropping a +// trailing ":<tag>" suffix (e.g. both "ai/qwen3:latest" and "ai/qwen3:Q4_K_M" +// yield "ai/qwen3"). 
A trailing colon is only treated as a tag separator when +// the suffix has no slash, so a registry host:port like "registry:5000/ai/x" +// is preserved. +func dmrModelRepo(id string) string { + if i := strings.LastIndex(id, ":"); i >= 0 && !strings.Contains(id[i+1:], "/") { + return id[:i] + } + return id +} + +// looksLikeEmbeddingModel reports whether a DMR model ID names an embedding +// model, which should never be chosen as an agent's chat model. It is a simple +// name-substring heuristic (e.g. "ai/embeddinggemma"); the model picker layer +// applies a richer models.dev-backed check for display purposes. +func looksLikeEmbeddingModel(modelID string) bool { + return strings.Contains(strings.ToLower(modelID), "embed") +} + func PreferredMaxTokens(provider string) *int64 { var mt int64 = 32000 if provider == "dmr" { diff --git a/pkg/config/auto_test.go b/pkg/config/auto_test.go index 545d58f19..b89db35fc 100644 --- a/pkg/config/auto_test.go +++ b/pkg/config/auto_test.go @@ -1,6 +1,8 @@ package config import ( + "context" + "errors" "testing" "github.com/stretchr/testify/assert" @@ -219,7 +221,7 @@ func TestAutoModelConfig(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - modelConfig := AutoModelConfig(t.Context(), tt.gateway, environment.NewMapEnvProvider(tt.envVars), nil) + modelConfig := AutoModelConfig(t.Context(), tt.gateway, environment.NewMapEnvProvider(tt.envVars), nil, nil) assert.Equal(t, tt.expectedProvider, modelConfig.Provider) assert.Equal(t, tt.expectedModel, modelConfig.Model) @@ -319,7 +321,7 @@ func TestAutoModelConfig_IntegrationWithDefaultModels(t *testing.T) { envVars["MISTRAL_API_KEY"] = "test-key" } - modelConfig := AutoModelConfig(t.Context(), "", environment.NewMapEnvProvider(envVars), nil) + modelConfig := AutoModelConfig(t.Context(), "", environment.NewMapEnvProvider(envVars), nil, nil) // Verify the returned model matches the DefaultModels entry expectedModel := DefaultModels[provider] @@ -332,7 +334,7 @@ func 
TestAutoModelConfig_IntegrationWithDefaultModels(t *testing.T) { t.Run("dmr", func(t *testing.T) { t.Parallel() - modelConfig := AutoModelConfig(t.Context(), "", environment.NewNoEnvProvider(), nil) + modelConfig := AutoModelConfig(t.Context(), "", environment.NewNoEnvProvider(), nil, nil) assert.Equal(t, "dmr", modelConfig.Provider) assert.Equal(t, DefaultModels["dmr"], modelConfig.Model) @@ -448,7 +450,7 @@ func TestAutoModelConfig_UserDefaultModel(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - modelConfig := AutoModelConfig(t.Context(), "", environment.NewMapEnvProvider(tt.envVars), tt.defaultModel) + modelConfig := AutoModelConfig(t.Context(), "", environment.NewMapEnvProvider(tt.envVars), tt.defaultModel, nil) assert.Equal(t, tt.expectedProvider, modelConfig.Provider) assert.Equal(t, tt.expectedModel, modelConfig.Model) @@ -471,7 +473,7 @@ func TestAutoModelConfig_UserDefaultModelWithOptions(t *testing.T) { ThinkingBudget: thinkingBudget, } - modelConfig := AutoModelConfig(t.Context(), "", environment.NewNoEnvProvider(), defaultModel) + modelConfig := AutoModelConfig(t.Context(), "", environment.NewNoEnvProvider(), defaultModel, nil) assert.Equal(t, "anthropic", modelConfig.Provider) assert.Equal(t, "claude-sonnet-4-5", modelConfig.Model) @@ -479,3 +481,116 @@ func TestAutoModelConfig_UserDefaultModelWithOptions(t *testing.T) { assert.NotNil(t, modelConfig.ThinkingBudget) assert.Equal(t, 10000, modelConfig.ThinkingBudget.Tokens) } + +func TestAutoModelConfig_DMRLocalModels(t *testing.T) { + t.Parallel() + + lister := func(models []string, err error) DMRModelLister { + return func(context.Context) ([]string, error) { return models, err } + } + + tests := []struct { + name string + lister DMRModelLister + expectedModel string + }{ + { + name: "nil lister keeps the static default", + lister: nil, + expectedModel: DefaultModels["dmr"], + }, + { + name: "default model already pulled is used", + lister: lister([]string{"ai/gemma3:latest", 
"ai/qwen3:latest"}, nil), + expectedModel: "ai/qwen3:latest", + }, + { + name: "default not pulled falls back to first installed model", + lister: lister([]string{"ai/llama3.2:latest", "ai/smollm2:latest"}, nil), + expectedModel: "ai/llama3.2:latest", + }, + { + name: "default model under a different tag is preferred over other models", + lister: lister([]string{"ai/gemma3:latest", "ai/qwen3:Q4_K_M"}, nil), + expectedModel: "ai/qwen3:Q4_K_M", + }, + { + name: "embedding-only models are skipped, default retained", + lister: lister([]string{"ai/embeddinggemma", "ai/nomic-embed-text"}, nil), + expectedModel: DefaultModels["dmr"], + }, + { + name: "embedding models are skipped when a chat model exists", + lister: lister([]string{"ai/embeddinggemma", "ai/mistral:latest"}, nil), + expectedModel: "ai/mistral:latest", + }, + { + name: "discovery error keeps the static default", + lister: lister(nil, errors.New("dmr unreachable")), + expectedModel: DefaultModels["dmr"], + }, + { + name: "empty list keeps the static default", + lister: lister([]string{}, nil), + expectedModel: DefaultModels["dmr"], + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + modelConfig := AutoModelConfig(t.Context(), "", environment.NewNoEnvProvider(), nil, tt.lister) + + assert.Equal(t, "dmr", modelConfig.Provider) + assert.Equal(t, tt.expectedModel, modelConfig.Model) + assert.Equal(t, int64(16000), *modelConfig.MaxTokens) + }) + } +} + +func TestAutoModelConfig_DMRListerNotConsultedForCloudProvider(t *testing.T) { + t.Parallel() + + called := false + lister := func(context.Context) ([]string, error) { + called = true + return []string{"ai/qwen3:latest"}, nil + } + + // A cloud provider is available, so the DMR lister must never run. 
+ modelConfig := AutoModelConfig( + t.Context(), + "", + environment.NewMapEnvProvider(map[string]string{"ANTHROPIC_API_KEY": "test-key"}), + nil, + lister, + ) + + assert.Equal(t, "anthropic", modelConfig.Provider) + assert.False(t, called, "DMR lister should not be consulted when a cloud provider is selected") +} + +func TestAutoModelFallbackError(t *testing.T) { + t.Parallel() + + t.Run("without cause", func(t *testing.T) { + t.Parallel() + + err := &AutoModelFallbackError{} + msg := err.Error() + assert.Contains(t, msg, "No model is currently available") + assert.Contains(t, msg, "docker model pull") + assert.Contains(t, msg, "ANTHROPIC_API_KEY") + assert.NotContains(t, msg, "Could not initialize") + }) + + t.Run("with cause is surfaced and unwrappable", func(t *testing.T) { + t.Parallel() + + cause := errors.New("model pull declined by user") + err := &AutoModelFallbackError{Cause: cause} + assert.Contains(t, err.Error(), "model pull declined by user") + assert.ErrorIs(t, err, cause) + }) +} diff --git a/pkg/config/latest/model_ref.go b/pkg/config/latest/model_ref.go index 2ab6c3c4c..2cab13352 100644 --- a/pkg/config/latest/model_ref.go +++ b/pkg/config/latest/model_ref.go @@ -6,8 +6,10 @@ import ( ) // ParseModelRef parses an inline "provider/model" reference into a -// ModelConfig. It returns an error when the string does not contain -// exactly one "/" separator or when either part is empty. +// ModelConfig. It splits on the first "/", so the model portion may itself +// contain slashes (e.g. "dmr/ai/qwen3:latest" yields provider "dmr" and model +// "ai/qwen3:latest"). It returns an error when there is no "/" or when either +// part is empty. 
// // cfg, err := ParseModelRef("openai/gpt-4o") // // cfg.Provider == "openai", cfg.Model == "gpt-4o" diff --git a/pkg/model/provider/dmr/list.go b/pkg/model/provider/dmr/list.go new file mode 100644 index 000000000..7ac652f2b --- /dev/null +++ b/pkg/model/provider/dmr/list.go @@ -0,0 +1,102 @@ +package dmr + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "os" + "slices" + "strings" + "time" + + "github.com/docker/docker-agent/pkg/config/latest" +) + +// listModelsTimeout bounds the /models request so a slow or wedged Docker +// Model Runner endpoint can't stall model discovery (the model picker and +// auto-selection both call ListModels synchronously). +const listModelsTimeout = 5 * time.Second + +// ListModels returns the IDs of the models available to Docker Model Runner +// (i.e. pulled locally), as reported by its OpenAI-compatible /models +// endpoint. IDs keep their full DMR form, e.g. "ai/qwen3:latest". The result +// is sorted for deterministic ordering. +// +// It returns ErrNotInstalled when Docker Model Runner is not installed, and a +// wrapped error when the endpoint is unreachable or returns an unparseable +// body. A nil error with an empty slice means DMR is reachable but has no +// models pulled. +func ListModels(ctx context.Context) ([]string, error) { + var endpoint string + if os.Getenv("MODEL_RUNNER_HOST") == "" { + ep, _, err := getDockerModelEndpointAndEngine(ctx) + if err != nil { + // Mirror NewClient: the unknown "--json" flag is the signal that + // the Docker installation predates Model Runner, i.e. DMR is not + // installed at all. + if strings.Contains(err.Error(), "unknown flag: --json") { + return nil, ErrNotInstalled + } + // Otherwise the docker CLI plugin may simply be unavailable while + // the engine still serves DMR on a default endpoint, so fall + // through and let resolveDMRBaseURL probe the defaults. 
+ slog.DebugContext(ctx, "docker model status query failed while listing models", "error", err) + } + endpoint = ep + } + + baseURL, _, httpClient := resolveDMRBaseURL(ctx, &latest.ModelConfig{}, endpoint) + if httpClient == nil { + httpClient = &http.Client{} + } + + return listModelsAt(ctx, httpClient, baseURL) +} + +// listModelsAt fetches and parses the OpenAI-compatible /models response from +// the given DMR base URL. It is split out from ListModels so the HTTP handling +// can be unit-tested with an httptest server. +func listModelsAt(ctx context.Context, httpClient *http.Client, baseURL string) ([]string, error) { + modelsURL := strings.TrimSuffix(baseURL, "/") + "/models" + + ctx, cancel := context.WithTimeout(ctx, listModelsTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, http.NoBody) + if err != nil { + return nil, fmt.Errorf("creating DMR models request: %w", err) + } + + resp, err := httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("querying DMR models endpoint: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("DMR models endpoint returned status %d", resp.StatusCode) + } + + var body struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + return nil, fmt.Errorf("decoding DMR models response: %w", err) + } + + models := make([]string, 0, len(body.Data)) + for _, m := range body.Data { + if id := strings.TrimSpace(m.ID); id != "" { + models = append(models, id) + } + } + slices.Sort(models) + models = slices.Compact(models) + + slog.DebugContext(ctx, "Listed DMR models", "count", len(models), "base_url", baseURL) + return models, nil +} diff --git a/pkg/model/provider/dmr/list_test.go b/pkg/model/provider/dmr/list_test.go new file mode 100644 index 000000000..71e6f2795 --- /dev/null +++ b/pkg/model/provider/dmr/list_test.go @@ -0,0 +1,128 @@ 
+package dmr + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestListModelsAt(t *testing.T) { + t.Parallel() + + t.Run("parses and sorts model ids", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/models", r.URL.Path) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[ + {"id":"ai/qwen3:latest"}, + {"id":"ai/gemma3:latest"}, + {"id":"ai/embeddinggemma"} + ]}`)) + })) + defer server.Close() + + models, err := listModelsAt(t.Context(), server.Client(), server.URL+"/") + require.NoError(t, err) + // Sorted, embedding models are NOT filtered here (callers do that). + assert.Equal(t, []string{"ai/embeddinggemma", "ai/gemma3:latest", "ai/qwen3:latest"}, models) + }) + + t.Run("empty list is not an error", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[]}`)) + })) + defer server.Close() + + models, err := listModelsAt(t.Context(), server.Client(), server.URL) + require.NoError(t, err) + assert.Empty(t, models) + }) + + t.Run("blank ids are skipped and duplicates compacted", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[ + {"id":"ai/qwen3:latest"}, + {"id":" "}, + {"id":""}, + {"id":"ai/qwen3:latest"} + ]}`)) + })) + defer server.Close() + + models, err := listModelsAt(t.Context(), server.Client(), server.URL) + require.NoError(t, err) + assert.Equal(t, []string{"ai/qwen3:latest"}, models) + }) + + t.Run("non-200 status is an error", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ 
*http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer server.Close() + + _, err := listModelsAt(t.Context(), server.Client(), server.URL) + require.Error(t, err) + }) + + t.Run("malformed body is an error", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`not json`)) + })) + defer server.Close() + + _, err := listModelsAt(t.Context(), server.Client(), server.URL) + require.Error(t, err) + }) + + t.Run("unreachable endpoint is an error", func(t *testing.T) { + t.Parallel() + + _, err := listModelsAt(t.Context(), &http.Client{}, "http://127.0.0.1:59998/") + require.Error(t, err) + }) +} + +// TestListModels exercises the exported entry point through MODEL_RUNNER_HOST, +// which makes resolveDMRBaseURL bypass the `docker model` CLI and return a +// nil http client (so ListModels falls back to its default client). It is not +// parallel because it mutates the environment. 
+func TestListModels(t *testing.T) { + t.Run("resolves via MODEL_RUNNER_HOST", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // MODEL_RUNNER_HOST + /engines/v1/ + models + assert.Equal(t, "/engines/v1/models", r.URL.Path) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[{"id":"ai/qwen3:latest"},{"id":"ai/gemma3:latest"}]}`)) + })) + defer server.Close() + + t.Setenv("MODEL_RUNNER_HOST", server.URL) + + models, err := ListModels(t.Context()) + require.NoError(t, err) + assert.Equal(t, []string{"ai/gemma3:latest", "ai/qwen3:latest"}, models) + }) + + t.Run("unreachable MODEL_RUNNER_HOST returns an error, not a panic", func(t *testing.T) { + t.Setenv("MODEL_RUNNER_HOST", "http://127.0.0.1:59997") + + _, err := ListModels(t.Context()) + require.Error(t, err) + }) +} diff --git a/pkg/runtime/dmr_models.go b/pkg/runtime/dmr_models.go new file mode 100644 index 000000000..440d82b71 --- /dev/null +++ b/pkg/runtime/dmr_models.go @@ -0,0 +1,154 @@ +package runtime + +import ( + "context" + "log/slog" + "sync" + "time" + + "golang.org/x/sync/singleflight" + + "github.com/docker/docker-agent/pkg/modelsdev" +) + +// dmrModelsTTL is how long a Docker Model Runner /models response (or failure) +// is reused before DMR is queried again. It keeps the model picker snappy on +// repeated opens while still picking up newly-pulled models eventually. +const dmrModelsTTL = 1 * time.Minute + +// dmrModelsCache memoizes the result of DMR model discovery, including +// failures, so an unreachable or absent Model Runner is not re-queried on +// every picker open. The mutex only guards the cached fields; the lookup +// itself runs outside the lock, coalesced by the singleflight group so +// concurrent callers share one in-flight request. 
+type dmrModelsCache struct { + mu sync.Mutex + sf singleflight.Group + ids []string + err error + fetchedAt time.Time +} + +// listDMRModels returns the model IDs available to Docker Model Runner, using +// the runtime's cache when fresh. It returns (nil, nil) when no lister is +// configured (e.g. runtimes built directly in tests), so DMR discovery is +// opt-in via NewLocalRuntime. +func (r *LocalRuntime) listDMRModels(ctx context.Context) ([]string, error) { + if r.dmrModelLister == nil { + return nil, nil + } + + now := time.Now + if r.now != nil { + now = r.now + } + + c := &r.dmrModels + + readFresh := func() (ids []string, ok bool, err error) { + c.mu.Lock() + defer c.mu.Unlock() + if !c.fetchedAt.IsZero() && now().Sub(c.fetchedAt) < dmrModelsTTL { + return c.ids, true, c.err + } + return nil, false, nil + } + + if ids, ok, err := readFresh(); ok { + return ids, err + } + + v, err, _ := c.sf.Do("models", func() (any, error) { + // Double-check the cache now that we hold the in-flight slot: a caller + // that read a stale cache right before a concurrent singleflight + // completed would otherwise trigger a redundant fetch. + if ids, ok, err := readFresh(); ok { + return ids, err + } + + ids, err := r.dmrModelLister(ctx) + c.mu.Lock() + c.ids, c.err, c.fetchedAt = ids, err, now() + c.mu.Unlock() + return ids, err + }) + if err != nil { + return nil, err + } + return v.([]string), nil +} + +// buildDMRChoices builds ModelChoice entries for the models currently pulled +// in Docker Model Runner, deduplicated against the explicitly configured +// models. DMR models aren't part of the models.dev catalog, so without this +// the picker shows nothing for a working local Model Runner. When DMR is not +// installed or unreachable it returns nil. 
+func (r *LocalRuntime) buildDMRChoices(ctx context.Context) []ModelChoice { + ids, err := r.listDMRModels(ctx) + if err != nil { + slog.DebugContext(ctx, "DMR model discovery failed, skipping DMR picker entries", "error", err) + return nil + } + if len(ids) == 0 { + return nil + } + + existingRefs := make(map[string]bool, len(r.modelSwitcherCfg.Models)*2) + for name, cfg := range r.modelSwitcherCfg.Models { + existingRefs[name] = true + if cfg.Provider != "" && cfg.Model != "" { + existingRefs[cfg.Provider+"/"+cfg.Model] = true + } + } + + choices := make([]ModelChoice, 0, len(ids)) + for _, id := range ids { + // DMR model IDs (e.g. "ai/qwen3:latest") contain slashes; the ref is + // "dmr/<id>" and ParseModelRef cuts on the first slash, so it + // round-trips back to provider="dmr", model="<id>". + ref := "dmr/" + id + + // Resolve catalog metadata before the embedding filter so a model + // whose models.dev Family is "text-embedding" is filtered even when + // its ID doesn't contain "embed". + var meta *modelsdev.Model + if r.modelsStore != nil { + if m, err := r.modelsStore.GetModel(ctx, modelsdev.NewID("dmr", id)); err == nil { + meta = m + } + } + family := "" + if meta != nil { + family = meta.Family + } + if isEmbeddingModel(family, id) { + continue + } + + if existingRefs[ref] { + continue + } + existingRefs[ref] = true + + choice := ModelChoice{ + Name: id, + Ref: ref, + Provider: "dmr", + Model: id, + // Discovered (not explicitly configured), so it groups under the + // picker's "Other models" separator alongside gateway/catalog + // entries rather than intermixing with the configured models. 
+ IsCatalog: true, + } + if meta != nil { + if meta.Name != "" { + choice.Name = meta.Name + } + applyCatalogMetadata(&choice, meta) + } + choices = append(choices, choice) + } + + slog.DebugContext(ctx, "Built DMR model choices", "count", len(choices)) + return choices +} diff --git a/pkg/runtime/dmr_models_test.go b/pkg/runtime/dmr_models_test.go new file mode 100644 index 000000000..7545f5aac --- /dev/null +++ b/pkg/runtime/dmr_models_test.go @@ -0,0 +1,219 @@ +package runtime + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/environment" + "github.com/docker/docker-agent/pkg/modelsdev" +) + +// dmrRuntime builds a LocalRuntime whose DMR discovery is stubbed with the +// given lister, no models gateway, and no configured models unless provided. +func dmrRuntime(lister func(context.Context) ([]string, error), store ModelStore, models map[string]latest.ModelConfig) *LocalRuntime { + return &LocalRuntime{ + modelsStore: store, + dmrModelLister: lister, + now: time.Now, + modelSwitcherCfg: &ModelSwitcherConfig{ + EnvProvider: environment.NewNoEnvProvider(), + Models: models, + }, + } +} + +func refsOf(choices []ModelChoice) []string { + out := make([]string, 0, len(choices)) + for _, c := range choices { + out = append(out, c.Ref) + } + return out +} + +func TestBuildDMRChoices(t *testing.T) { + t.Parallel() + + t.Run("nil lister yields no DMR entries", func(t *testing.T) { + t.Parallel() + + r := dmrRuntime(nil, nil, nil) + assert.Empty(t, r.buildDMRChoices(t.Context())) + }) + + t.Run("discovery error yields no DMR entries", func(t *testing.T) { + t.Parallel() + + r := dmrRuntime(func(context.Context) ([]string, error) { + return nil, errors.New("dmr not installed") + }, nil, nil) + assert.Empty(t, r.buildDMRChoices(t.Context())) + }) + + t.Run("installed models become 
dmr-prefixed choices", func(t *testing.T) { + t.Parallel() + + r := dmrRuntime(func(context.Context) ([]string, error) { + return []string{"ai/qwen3:latest", "ai/gemma3:latest"}, nil + }, nil, nil) + + choices := r.buildDMRChoices(t.Context()) + require.Len(t, choices, 2) + + byRef := map[string]ModelChoice{} + for _, c := range choices { + byRef[c.Ref] = c + } + + qwen, ok := byRef["dmr/ai/qwen3:latest"] + require.True(t, ok, "expected dmr/ai/qwen3:latest, got %v", refsOf(choices)) + assert.Equal(t, "dmr", qwen.Provider) + assert.Equal(t, "ai/qwen3:latest", qwen.Model) + assert.Equal(t, "ai/qwen3:latest", qwen.Name) + // Discovered (not configured) DMR models group with catalog entries. + assert.True(t, qwen.IsCatalog) + + // The ref round-trips back to provider="dmr" + the full model id. + parsed, err := latest.ParseModelRef(qwen.Ref) + require.NoError(t, err) + assert.Equal(t, "dmr", parsed.Provider) + assert.Equal(t, "ai/qwen3:latest", parsed.Model) + }) + + t.Run("embedding models are filtered out by name", func(t *testing.T) { + t.Parallel() + + r := dmrRuntime(func(context.Context) ([]string, error) { + return []string{"ai/qwen3:latest", "ai/embeddinggemma"}, nil + }, nil, nil) + + assert.Equal(t, []string{"dmr/ai/qwen3:latest"}, refsOf(r.buildDMRChoices(t.Context()))) + }) + + t.Run("embedding models are filtered out by catalog family even without an 'embed' id substring", func(t *testing.T) { + t.Parallel() + + // The vector model's ID contains no "embed" substring; only the + // models.dev Family marks it, exercising the metadata-before-filter + // ordering in buildDMRChoices. 
+ store := stubModelStore{models: map[string]*modelsdev.Model{ + "dmr/ai/nomic-text-v1.5": {Name: "Nomic Text", Family: "text-embedding"}, + }} + + r := dmrRuntime(func(context.Context) ([]string, error) { + return []string{"ai/qwen3:latest", "ai/nomic-text-v1.5"}, nil + }, store, nil) + + assert.Equal(t, []string{"dmr/ai/qwen3:latest"}, refsOf(r.buildDMRChoices(t.Context()))) + }) + + t.Run("models already in config are deduplicated", func(t *testing.T) { + t.Parallel() + + r := dmrRuntime(func(context.Context) ([]string, error) { + return []string{"ai/qwen3:latest", "ai/gemma3:latest"}, nil + }, nil, map[string]latest.ModelConfig{ + "local": {Provider: "dmr", Model: "ai/qwen3:latest"}, + }) + + assert.Equal(t, []string{"dmr/ai/gemma3:latest"}, refsOf(r.buildDMRChoices(t.Context()))) + }) + + t.Run("catalog metadata is applied when available", func(t *testing.T) { + t.Parallel() + + store := stubModelStore{models: map[string]*modelsdev.Model{ + "dmr/ai/qwen3:latest": { + Name: "Qwen 3", + Family: "qwen", + Limit: modelsdev.Limit{Context: 32768}, + }, + }} + + r := dmrRuntime(func(context.Context) ([]string, error) { + return []string{"ai/qwen3:latest"}, nil + }, store, nil) + + choices := r.buildDMRChoices(t.Context()) + require.Len(t, choices, 1) + assert.Equal(t, "Qwen 3", choices[0].Name) + assert.Equal(t, "qwen", choices[0].Family) + assert.Equal(t, 32768, choices[0].ContextLimit) + }) +} + +func TestListDMRModelsCachesResult(t *testing.T) { + t.Parallel() + + var calls atomic.Int32 + r := dmrRuntime(func(context.Context) ([]string, error) { + calls.Add(1) + return []string{"ai/qwen3:latest"}, nil + }, nil, nil) + + for range 3 { + ids, err := r.listDMRModels(t.Context()) + require.NoError(t, err) + assert.Equal(t, []string{"ai/qwen3:latest"}, ids) + } + + assert.Equal(t, int32(1), calls.Load(), "lister should be called once within the TTL window") +} + +func TestListDMRModels_CachesFailure(t *testing.T) { + t.Parallel() + + var calls atomic.Int32 + r := 
dmrRuntime(func(context.Context) ([]string, error) { + calls.Add(1) + return nil, errors.New("dmr unreachable") + }, nil, nil) + + _, err := r.listDMRModels(t.Context()) + require.Error(t, err) + _, err = r.listDMRModels(t.Context()) + require.Error(t, err) + + assert.Equal(t, int32(1), calls.Load(), "failures must be cached to avoid re-probing DMR on every picker open") +} + +func TestListDMRModels_CacheExpires(t *testing.T) { + t.Parallel() + + var calls atomic.Int32 + r := dmrRuntime(func(context.Context) ([]string, error) { + calls.Add(1) + return []string{"ai/qwen3:latest"}, nil + }, nil, nil) + + now := time.Now() + r.now = func() time.Time { return now } + + _, err := r.listDMRModels(t.Context()) + require.NoError(t, err) + + now = now.Add(dmrModelsTTL + time.Second) + _, err = r.listDMRModels(t.Context()) + require.NoError(t, err) + + assert.Equal(t, int32(2), calls.Load(), "DMR must be re-queried after the cache TTL") +} + +func TestAvailableModelsIncludesDMR(t *testing.T) { + t.Parallel() + + // A stub store with no database makes buildCatalogChoices a no-op, so the + // only entries come from DMR discovery. + r := dmrRuntime(func(context.Context) ([]string, error) { + return []string{"ai/qwen3:latest"}, nil + }, stubModelStore{}, nil) + + got := refsOf(r.AvailableModels(t.Context())) + assert.Contains(t, got, "dmr/ai/qwen3:latest") +} diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go index df8127739..2e0ef5a88 100644 --- a/pkg/runtime/model_switcher.go +++ b/pkg/runtime/model_switcher.go @@ -502,6 +502,11 @@ func (r *LocalRuntime) AvailableModels(ctx context.Context) []ModelChoice { } } + // Surface models pulled locally in Docker Model Runner. They are not part + // of the models.dev catalog, so without this a working local DMR setup + // would show nothing selectable in the picker. + choices = append(choices, r.buildDMRChoices(ctx)...) 
+ // Append models.dev catalog entries filtered by available credentials catalogChoices := r.buildCatalogChoices(ctx) choices = append(choices, catalogChoices...) diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index ce5c38b7e..bb7d70d6e 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -23,6 +23,7 @@ import ( "github.com/docker/docker-agent/pkg/hooks/builtins" "github.com/docker/docker-agent/pkg/httpclient" "github.com/docker/docker-agent/pkg/model/provider" + "github.com/docker/docker-agent/pkg/model/provider/dmr" "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/sessiontitle" @@ -212,6 +213,7 @@ type LocalRuntime struct { modelSwitcherCfg *ModelSwitcherConfig providerRegistry *provider.Registry gatewayModels gatewayModelsCache + dmrModels dmrModelsCache // hooksRegistry is the runtime-private hooks.Registry used to build // every Executor. It carries the runtime-owned builtin hooks @@ -268,6 +270,12 @@ type LocalRuntime struct { bgAgents *agenttool.Handler + // dmrModelLister lists the models pulled locally in Docker Model Runner, + // used to populate DMR entries in the model picker. Defaults to + // dmr.ListModels in NewLocalRuntime; left nil by runtimes built directly + // (e.g. tests) so DMR discovery stays opt-in. Tests inject a stub here. + dmrModelLister func(ctx context.Context) ([]string, error) + // now is the runtime's clock. Defaults to time.Now and can be replaced // in tests via WithClock to make timestamps and cooldown windows // deterministic. 
Every time-dependent call inside the runtime (message @@ -549,6 +557,7 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { providerRegistry: provider.DefaultRegistry(), maxOverflowCompactions: defaultMaxOverflowCompactions, toolListTimeout: defaultToolListTimeout, + dmrModelLister: dmr.ListModels, } r.bgAgents = agenttool.NewHandler(r) diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 34ad661d2..614be3628 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -188,7 +188,7 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c agentsByName := make(map[string]*agent.Agent) autoModel := sync.OnceValue(func() latest.ModelConfig { - return config.AutoModelConfig(ctx, runConfig.ModelsGateway, env, runConfig.DefaultModel) + return config.AutoModelConfig(ctx, runConfig.ModelsGateway, env, runConfig.DefaultModel, dmr.ListModels) }) expander := js.NewJsExpander(env) @@ -417,9 +417,11 @@ func getModelsForAgent(ctx context.Context, cfg *latest.Config, a *latest.AgentC opts..., ) if err != nil { - // Return a cleaner error message for auto model selection failures + // Return a cleaner error message for auto model selection failures, + // keeping the underlying cause (e.g. a declined DMR pull) so the + // message can explain why selection fell through. if isAutoModel { - return nil, &config.AutoModelFallbackError{} + return nil, &config.AutoModelFallbackError{Cause: err} } return nil, err }