diff --git a/src/composables/create-dialog.ts b/src/composables/create-dialog.ts index 2793afe7..cc4dacf4 100644 --- a/src/composables/create-dialog.ts +++ b/src/composables/create-dialog.ts @@ -19,6 +19,7 @@ export function useCreateDialog(workspace: Ref) { name: t('createDialog.newDialog'), msgTree: { $root: [messageId], [messageId]: [] }, msgRoute: [], + msgBranchState: {}, assistantId: workspace.value.defaultAssistantId, inputVars: {}, ...props diff --git a/src/utils/types.ts b/src/utils/types.ts index 3f796e50..190b37e7 100644 --- a/src/utils/types.ts +++ b/src/utils/types.ts @@ -424,6 +424,7 @@ interface Dialog { assistantId?: string msgTree: Record msgRoute: number[] + msgBranchState?: Record inputVars: Record modelOverride?: Model } diff --git a/src/views/DialogView.vue b/src/views/DialogView.vue index 876211f9..8040059e 100644 --- a/src/views/DialogView.vue +++ b/src/views/DialogView.vue @@ -496,29 +496,49 @@ const assistant = computed(() => { }) provide('dialog', dialog) -const chain = computed(() => liveData.value.dialog ? getChain('$root', liveData.value.dialog.msgRoute)[0] : []) +const chain = computed(() => liveData.value.dialog ? getChain(liveData.value.dialog.msgTree, '$root', getBranchState())[0] : []) const historyChain = ref([]) +function clampIndex(value: number, min: number, max: number) { + return Math.min(Math.max(value, min), max) +} function switchChain(index, value) { - const route = [...dialog.value.msgRoute.slice(0, index), value] - updateChain(route) + const branchState = { ...getBranchState(), [chain.value[index]]: value } + updateChain(branchState) +} +function getBranchState(route = dialog.value?.msgRoute || []) { + if (!liveData.value.dialog) return {} + if (dialog.value?.msgBranchState) return { ...dialog.value.msgBranchState } + return mergeRouteIntoBranchState(liveData.value.dialog.msgTree, route, {}) } -function updateChain(route) { - const res = getChain('$root', route) +function mergeRouteIntoBranchState(tree: Record, route: number[], baseState: Record) { + const branchState = { ...baseState } + let node = '$root' + for (const rawIndex of route) { + const children = tree[node] + if (!children?.length) break + const index = clampIndex(rawIndex ?? 0, 0, children.length - 1) + branchState[node] = index + node = children[index] + } + return branchState +} +function updateChain(routeOrBranchState: number[] | Record) { + const branchState = Array.isArray(routeOrBranchState) + ? mergeRouteIntoBranchState(liveData.value.dialog.msgTree, routeOrBranchState, getBranchState()) + : routeOrBranchState + const res = getChain(liveData.value.dialog.msgTree, '$root', branchState) historyChain.value = res[0] - db.dialogs.update(dialog.value.id, { msgRoute: res[1] }) + db.dialogs.update(dialog.value.id, { msgRoute: res[1], msgBranchState: branchState }) } watch([() => liveData.value.messages.length, () => liveData.value.dialog?.id], () => { liveData.value.dialog && updateChain(liveData.value.dialog.msgRoute) }) -function getChain(node, route: number[]) { - const children = liveData.value.dialog.msgTree[node] - const r = route.at(0) || 0 - if (children[r]) { - const [restChain, restRoute] = getChain(children[r], route.slice(1)) - return [[node, ...restChain], [r, ...restRoute]] - } else { - return [[node], [r]] - } +function getChain(tree: Record, node: string, branchState: Record) { + const children = tree[node] + if (!children?.length) return [[node], []] + const index = clampIndex(branchState[node] ?? 0, 0, children.length - 1) + const [restChain, restRoute] = getChain(tree, children[index], branchState) + return [[node, ...restChain], [index, ...restRoute]] } const messageInput = ref() @@ -528,15 +548,14 @@ function focusInput() { async function edit(index) { const target = chain.value[index - 1] const { type, contents } = messageMap.value[chain.value[index]] - switchChain(index - 1, dialog.value.msgTree[target].length) - await db.transaction('rw', db.dialogs, db.messages, db.items, () => { - appendMessage(target, { + await db.transaction('rw', db.dialogs, db.messages, db.items, async () => { + await appendMessage(target, { type, contents, status: 'inputing' - }) + }, false, true) const content = contents[0] as UserMessageContent - saveItems(content.items.map(id => itemMap.value[id])) + await saveItems(content.items.map(id => itemMap.value[id])) }) await nextTick() focusInput() @@ -551,7 +570,6 @@ async function regenerate(index) { return } const target = chain.value[index - 1] - switchChain(index - 1, dialog.value.msgTree[target].length) await stream(target, false) } async function deleteBranch(index) { @@ -581,7 +599,7 @@ async function deleteBranch(index) { }) } -async function appendMessage(target, info: Partial, insert = false) { +async function appendMessage(target, info: Partial, insert = false, selectBranch = false) { const id = genId() await db.transaction('rw', db.dialogs, db.messages, async () => { await db.messages.add({ @@ -599,9 +617,17 @@ async function appendMessage(target, info: Partial, insert = false) { [target]: [...children, id], [id]: [] } - await db.dialogs.update(props.id, { - msgTree: { ...d.msgTree, ...changes } - }) + const msgTree = { ...d.msgTree, ...changes } + const dialogChanges: Partial = { msgTree } + if (selectBranch) { + const branchState = d.msgBranchState + ? { ...d.msgBranchState } + : mergeRouteIntoBranchState(d.msgTree, d.msgRoute, {}) + branchState[target] = insert ? 0 : children.length + dialogChanges.msgBranchState = branchState + dialogChanges.msgRoute = getChain(msgTree, '$root', branchState)[1] + } + await db.dialogs.update(props.id, dialogChanges) }) return id } @@ -960,7 +986,7 @@ async function stream(target, insert = false) { status: 'pending', generatingSession: sessions.id, modelName: model.value.name - }, insert) + }, insert, true) !insert && await appendMessage(id, { type: 'user', contents: [{