Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion apps/client/src/translations/en/translation.json
Original file line number Diff line number Diff line change
Expand Up @@ -1660,7 +1660,8 @@
"note_context_enabled": "Click to disable note context: {{title}}",
"note_context_disabled": "Click to include current note in context",
"no_provider_message": "No AI provider configured. Add one to start chatting.",
"add_provider": "Add AI Provider"
"add_provider": "Add AI Provider",
"free": "Free"
},
"sidebar_chat": {
"title": "AI Chat",
Expand Down Expand Up @@ -2340,6 +2341,7 @@
"delete_provider_confirmation": "Are you sure you want to delete the provider \"{{name}}\"?",
"api_key": "API Key",
"api_key_placeholder": "Enter your API key",
"base_url": "Base URL",
"cancel": "Cancel",
"mcp_title": "MCP (Model Context Protocol)",
"mcp_enabled": "MCP server",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ export default function ChatInputBar({
onClick={() => handleModelSelect(model.id)}
checked={chat.selectedModel === model.id}
>
{model.name} <small>({model.costDescription})</small>
{model.name}{model.costDescription && <> <small>({model.costDescription})</small></>}
</FormListItem>
))}
{legacyModels.length > 0 && (
Expand All @@ -169,7 +169,7 @@ export default function ChatInputBar({
onClick={() => handleModelSelect(model.id)}
checked={chat.selectedModel === model.id}
>
{model.name} <small>({model.costDescription})</small>
{model.name}{model.costDescription && <> <small>({model.costDescription})</small></>}
</FormListItem>
))}
</FormDropdownSubmenu>
Expand Down
7 changes: 6 additions & 1 deletion apps/client/src/widgets/type_widgets/llm_chat/useLlmChat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { RefObject } from "preact";
import { useCallback, useEffect, useRef, useState } from "preact/hooks";

import { getAvailableModels, streamChatCompletion } from "../../../services/llm_chat.js";
import { t } from "../../../services/i18n.js";
import { randomString } from "../../../services/utils.js";
import type { ContentBlock, LlmChatContent, StoredMessage } from "./llm_chat_types.js";

Expand Down Expand Up @@ -122,7 +123,11 @@ export function useLlmChat(
getAvailableModels().then(models => {
const modelsWithDescription = models.map(m => ({
...m,
costDescription: m.costMultiplier ? `${m.costMultiplier}x` : undefined
costDescription: m.costMultiplier
? `${m.costMultiplier}x`
: m.pricing.input === 0 && m.pricing.output === 0
? t("llm_chat.free")
: undefined
Comment on lines +126 to +130
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too complicated. Extract to a function with simple ifs and returns.

}));
setAvailableModels(modelsWithDescription);
setHasProvider(models.length > 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,26 @@ export interface LlmProviderConfig {
name: string;
provider: string;
apiKey: string;
/** Base URL for self-hosted providers (e.g. Ollama). */
baseUrl?: string;
}

/** Describes a selectable LLM provider type and what configuration it requires. */
export interface ProviderType {
    /** Stable identifier stored in provider configs (e.g. "openai", "ollama"). */
    id: string;
    /** Human-readable name shown in the provider dropdown. */
    name: string;
    /** Whether this provider needs an API key (defaults to true when omitted). */
    needsApiKey?: boolean;
    /** Whether this provider needs a base URL (e.g. self-hosted Ollama). */
    needsBaseUrl?: boolean;
    /** Default base URL pre-filled into the form when the provider is selected. */
    defaultBaseUrl?: string;
}

export const PROVIDER_TYPES: ProviderType[] = [
{ id: "anthropic", name: "Anthropic" },
{ id: "openai", name: "OpenAI" },
{ id: "google", name: "Google Gemini" }
{ id: "google", name: "Google Gemini" },
{ id: "ollama", name: "Ollama", needsApiKey: false, needsBaseUrl: true, defaultBaseUrl: "http://localhost:11434" }
];

interface AddProviderModalProps {
Expand All @@ -33,19 +42,34 @@ interface AddProviderModalProps {
export default function AddProviderModal({ show, onHidden, onSave }: AddProviderModalProps) {
const [selectedProvider, setSelectedProvider] = useState(PROVIDER_TYPES[0].id);
const [apiKey, setApiKey] = useState("");
const [baseUrl, setBaseUrl] = useState("");
const formRef = useRef<HTMLFormElement>(null);

const providerType = PROVIDER_TYPES.find(p => p.id === selectedProvider);
const needsApiKey = providerType?.needsApiKey !== false;
const needsBaseUrl = providerType?.needsBaseUrl === true;

function handleProviderChange(value: string) {
    setSelectedProvider(value);
    // Pre-fill the base URL with the provider's default, or clear it when
    // the chosen provider has none.
    const chosen = PROVIDER_TYPES.find(p => p.id === value);
    setBaseUrl(chosen?.defaultBaseUrl ?? "");
}

function handleSubmit() {
if (!apiKey.trim()) {
if (needsApiKey && !apiKey.trim()) {
return;
}

const providerType = PROVIDER_TYPES.find(p => p.id === selectedProvider);
const newProvider: LlmProviderConfig = {
id: `${selectedProvider}_${Date.now()}`,
name: providerType?.name || selectedProvider,
provider: selectedProvider,
apiKey: apiKey.trim()
apiKey: apiKey.trim(),
...(needsBaseUrl && baseUrl.trim() ? { baseUrl: baseUrl.trim() } : {})
};

onSave(newProvider);
Expand All @@ -56,13 +80,16 @@ export default function AddProviderModal({ show, onHidden, onSave }: AddProvider
function resetForm() {
    // Restore every field to its initial state so the next open starts clean.
    setApiKey("");
    setBaseUrl("");
    setSelectedProvider(PROVIDER_TYPES[0].id);
}

// Discard any user input, then close the modal.
function handleCancel() {
    resetForm();
    onHidden();
}

const isSubmitDisabled = needsApiKey ? !apiKey.trim() : false;

return createPortal(
<Modal
show={show}
Expand All @@ -77,7 +104,7 @@ export default function AddProviderModal({ show, onHidden, onSave }: AddProvider
<button type="button" className="btn btn-secondary" onClick={handleCancel}>
{t("llm.cancel")}
</button>
<button type="submit" className="btn btn-primary" disabled={!apiKey.trim()}>
<button type="submit" className="btn btn-primary" disabled={isSubmitDisabled}>
{t("llm.add_provider")}
</button>
</>
Expand All @@ -89,19 +116,31 @@ export default function AddProviderModal({ show, onHidden, onSave }: AddProvider
keyProperty="id"
titleProperty="name"
currentValue={selectedProvider}
onChange={setSelectedProvider}
onChange={handleProviderChange}
/>
</FormGroup>

<FormGroup name="api-key" label={t("llm.api_key")}>
<FormTextBox
type="password"
currentValue={apiKey}
onChange={setApiKey}
placeholder={t("llm.api_key_placeholder")}
autoFocus
/>
</FormGroup>
{needsApiKey && (
<FormGroup name="api-key" label={t("llm.api_key")}>
<FormTextBox
type="password"
currentValue={apiKey}
onChange={setApiKey}
placeholder={t("llm.api_key_placeholder")}
autoFocus
/>
</FormGroup>
)}

{needsBaseUrl && (
<FormGroup name="base-url" label={t("llm.base_url")}>
<FormTextBox
currentValue={baseUrl}
onChange={setBaseUrl}
placeholder={providerType?.defaultBaseUrl || "http://localhost:11434"}
/>
</FormGroup>
)}
</Modal>,
document.body
);
Expand Down
11 changes: 9 additions & 2 deletions apps/server/src/routes/api/llm_chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import type { Request, Response } from "express";

import { generateChatTitle } from "../../services/llm/chat_title.js";
import { getAllModels, getProviderByType, hasConfiguredProviders, type LlmProviderConfig } from "../../services/llm/index.js";
import { OllamaProvider } from "../../services/llm/providers/ollama.js";
import { streamToChunks } from "../../services/llm/stream.js";
import log from "../../services/log.js";
import { safeExtractMessageAndStackFromError } from "../../services/utils.js";
Expand Down Expand Up @@ -51,6 +52,12 @@ async function streamChat(req: Request, res: Response) {
}

const provider = getProviderByType(config.provider || "anthropic");

// Ensure Ollama models are loaded so defaultModel/titleModel are set
if (provider instanceof OllamaProvider) {
await provider.loadModels();
}

const result = provider.chat(messages, config);

// Get pricing and display name for the model
Expand Down Expand Up @@ -90,12 +97,12 @@ async function streamChat(req: Request, res: Response) {
/**
* Get available models from all configured providers.
*/
function getModels(_req: Request, _res: Response) {
async function getModels(_req: Request, _res: Response) {
if (!hasConfiguredProviders()) {
return { models: [] };
}

return { models: getAllModels() };
return { models: await getAllModels() };
}

export default {
Expand Down
2 changes: 1 addition & 1 deletion apps/server/src/routes/routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ function register(app: express.Application) {

// LLM chat endpoints
asyncRoute(PST, "/api/llm-chat/stream", [auth.checkApiAuthOrElectron, csrfMiddleware], llmChatRoute.streamChat, null);
apiRoute(GET, "/api/llm-chat/models", llmChatRoute.getModels);
asyncApiRoute(GET, "/api/llm-chat/models", llmChatRoute.getModels);

// no CSRF since this is called from android app
route(PST, "/api/sender/login", [loginRateLimiter], loginApiRoute.token, apiResultHandler);
Expand Down
7 changes: 7 additions & 0 deletions apps/server/src/services/llm/chat_title.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import becca from "../../becca/becca.js";
import { getProvider } from "./index.js";
import { OllamaProvider } from "./providers/ollama.js";
import log from "../log.js";
import { t } from "i18next";

Expand Down Expand Up @@ -28,6 +29,12 @@ export async function generateChatTitle(chatNoteId: string, firstMessage: string
}

const provider = getProvider();

// Ensure Ollama models are loaded so titleModel is set
if (provider instanceof OllamaProvider) {
await provider.loadModels();
}

const title = await provider.generateTitle(firstMessage);
if (title) {
note.title = title;
Expand Down
18 changes: 14 additions & 4 deletions apps/server/src/services/llm/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type { LlmProvider, ModelInfo } from "./types.js";
import { AnthropicProvider } from "./providers/anthropic.js";
import { GoogleProvider } from "./providers/google.js";
import { OllamaProvider } from "./providers/ollama.js";
import { OpenAiProvider } from "./providers/openai.js";
import optionService from "../options.js";
import log from "../log.js";
Expand All @@ -14,13 +15,16 @@ export interface LlmProviderSetup {
name: string;
provider: string;
apiKey: string;
/** Base URL for self-hosted providers (e.g. Ollama). */
baseUrl?: string;
}

/** Factory functions for creating provider instances */
const providerFactories: Record<string, (apiKey: string) => LlmProvider> = {
const providerFactories: Record<string, (apiKey: string, baseUrl?: string) => LlmProvider> = {
anthropic: (apiKey) => new AnthropicProvider(apiKey),
openai: (apiKey) => new OpenAiProvider(apiKey),
google: (apiKey) => new GoogleProvider(apiKey)
google: (apiKey) => new GoogleProvider(apiKey),
ollama: (_apiKey, baseUrl) => new OllamaProvider(baseUrl)
};

/** Cache of instantiated providers by their config ID */
Expand Down Expand Up @@ -73,7 +77,7 @@ export function getProvider(providerId?: string): LlmProvider {
throw new Error(`Unknown LLM provider type: ${config.provider}. Available: ${Object.keys(providerFactories).join(", ")}`);
}

const provider = factory(config.apiKey);
const provider = factory(config.apiKey, config.baseUrl);
cachedProviders[config.id] = provider;
return provider;
}
Expand Down Expand Up @@ -102,7 +106,7 @@ export function hasConfiguredProviders(): boolean {
/**
* Get all models from all configured providers, tagged with their provider type.
*/
export function getAllModels(): ModelInfo[] {
export async function getAllModels(): Promise<ModelInfo[]> {
const configs = getConfiguredProviders();
const seenProviderTypes = new Set<string>();
const allModels: ModelInfo[] = [];
Expand All @@ -116,6 +120,12 @@ export function getAllModels(): ModelInfo[] {

try {
const provider = getProvider(config.id);

// Ollama needs to fetch models from the running instance
if (provider instanceof OllamaProvider) {
await provider.loadModels();
}

const models = provider.getAvailableModels();
for (const model of models) {
allModels.push({ ...model, provider: config.provider });
Expand Down
Loading
Loading