Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/composables/create-dialog.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export function useCreateDialog(workspace: Ref<Workspace>) {
name: t('createDialog.newDialog'),
msgTree: { $root: [messageId], [messageId]: [] },
msgRoute: [],
msgBranchState: {},
assistantId: workspace.value.defaultAssistantId,
inputVars: {},
...props
Expand Down
1 change: 1 addition & 0 deletions src/utils/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ interface Dialog {
assistantId?: string
msgTree: Record<string, string[]>
msgRoute: number[]
msgBranchState?: Record<string, number>
inputVars: Record<string, PromptVarValue>
modelOverride?: Model
}
Expand Down
78 changes: 52 additions & 26 deletions src/views/DialogView.vue
Original file line number Diff line number Diff line change
Expand Up @@ -496,29 +496,49 @@ const assistant = computed(() => {
})
provide('dialog', dialog)

const chain = computed<string[]>(() => liveData.value.dialog ? getChain('$root', liveData.value.dialog.msgRoute)[0] : [])
const chain = computed<string[]>(() => liveData.value.dialog ? getChain(liveData.value.dialog.msgTree, '$root', getBranchState())[0] : [])
const historyChain = ref<string[]>([])
function clampIndex(value: number, min: number, max: number) {
return Math.min(Math.max(value, min), max)
}
function switchChain(index, value) {
const route = [...dialog.value.msgRoute.slice(0, index), value]
updateChain(route)
const branchState = { ...getBranchState(), [chain.value[index]]: value }
updateChain(branchState)
}
function getBranchState(route = dialog.value?.msgRoute || []) {
if (!liveData.value.dialog) return {}
if (dialog.value?.msgBranchState) return { ...dialog.value.msgBranchState }
return mergeRouteIntoBranchState(liveData.value.dialog.msgTree, route, {})
}
function updateChain(route) {
const res = getChain('$root', route)
function mergeRouteIntoBranchState(tree: Record<string, string[]>, route: number[], baseState: Record<string, number>) {
const branchState = { ...baseState }
let node = '$root'
for (const rawIndex of route) {
const children = tree[node]
if (!children?.length) break
const index = clampIndex(rawIndex ?? 0, 0, children.length - 1)
branchState[node] = index
node = children[index]
}
return branchState
}
function updateChain(routeOrBranchState: number[] | Record<string, number>) {
const branchState = Array.isArray(routeOrBranchState)
? mergeRouteIntoBranchState(liveData.value.dialog.msgTree, routeOrBranchState, getBranchState())
: routeOrBranchState
const res = getChain(liveData.value.dialog.msgTree, '$root', branchState)
historyChain.value = res[0]
db.dialogs.update(dialog.value.id, { msgRoute: res[1] })
db.dialogs.update(dialog.value.id, { msgRoute: res[1], msgBranchState: branchState })
}
watch([() => liveData.value.messages.length, () => liveData.value.dialog?.id], () => {
liveData.value.dialog && updateChain(liveData.value.dialog.msgRoute)
})
function getChain(node, route: number[]) {
const children = liveData.value.dialog.msgTree[node]
const r = route.at(0) || 0
if (children[r]) {
const [restChain, restRoute] = getChain(children[r], route.slice(1))
return [[node, ...restChain], [r, ...restRoute]]
} else {
return [[node], [r]]
}
function getChain(tree: Record<string, string[]>, node: string, branchState: Record<string, number>) {
const children = tree[node]
if (!children?.length) return [[node], []]
const index = clampIndex(branchState[node] ?? 0, 0, children.length - 1)
const [restChain, restRoute] = getChain(tree, children[index], branchState)
return [[node, ...restChain], [index, ...restRoute]]
}

const messageInput = ref()
Expand All @@ -528,15 +548,14 @@ function focusInput() {
async function edit(index) {
const target = chain.value[index - 1]
const { type, contents } = messageMap.value[chain.value[index]]
switchChain(index - 1, dialog.value.msgTree[target].length)
await db.transaction('rw', db.dialogs, db.messages, db.items, () => {
appendMessage(target, {
await db.transaction('rw', db.dialogs, db.messages, db.items, async () => {
await appendMessage(target, {
type,
contents,
status: 'inputing'
})
}, false, true)
const content = contents[0] as UserMessageContent
saveItems(content.items.map(id => itemMap.value[id]))
await saveItems(content.items.map(id => itemMap.value[id]))
})
await nextTick()
focusInput()
Expand All @@ -551,7 +570,6 @@ async function regenerate(index) {
return
}
const target = chain.value[index - 1]
switchChain(index - 1, dialog.value.msgTree[target].length)
await stream(target, false)
}
async function deleteBranch(index) {
Expand Down Expand Up @@ -581,7 +599,7 @@ async function deleteBranch(index) {
})
}

async function appendMessage(target, info: Partial<Message>, insert = false) {
async function appendMessage(target, info: Partial<Message>, insert = false, selectBranch = false) {
const id = genId()
await db.transaction('rw', db.dialogs, db.messages, async () => {
await db.messages.add({
Expand All @@ -599,9 +617,17 @@ async function appendMessage(target, info: Partial<Message>, insert = false) {
[target]: [...children, id],
[id]: []
}
await db.dialogs.update(props.id, {
msgTree: { ...d.msgTree, ...changes }
})
const msgTree = { ...d.msgTree, ...changes }
const dialogChanges: Partial<Dialog> = { msgTree }
if (selectBranch) {
const branchState = d.msgBranchState
? { ...d.msgBranchState }
: mergeRouteIntoBranchState(d.msgTree, d.msgRoute, {})
branchState[target] = insert ? 0 : children.length
dialogChanges.msgBranchState = branchState
dialogChanges.msgRoute = getChain(msgTree, '$root', branchState)[1]
}
await db.dialogs.update(props.id, dialogChanges)
})
return id
}
Expand Down Expand Up @@ -960,7 +986,7 @@ async function stream(target, insert = false) {
status: 'pending',
generatingSession: sessions.id,
modelName: model.value.name
}, insert)
}, insert, true)
!insert && await appendMessage(id, {
type: 'user',
contents: [{
Expand Down