diff --git a/src/bots/base-bot.ts b/src/bots/base-bot.ts index f4265be..6f0ab42 100644 --- a/src/bots/base-bot.ts +++ b/src/bots/base-bot.ts @@ -13,6 +13,8 @@ export abstract class LLMBot { _chatHistory!: ChatHistory private __storage!: BotStorage + protected static readonly _conversation_key_ = '_Conversation_key_' + constructor( /** bot unique name */ readonly name: string, @@ -53,12 +55,13 @@ export abstract class LLMBot { async sendPrompt(msg: ChatDto, streamCallback?: (msg: ChatDto) => void): Promise { if (!(await this.isAvailable())) { - const msg = new ChatDto('bot.notAvailable', true) + const msg = new ChatDto('bot.notAvailable') streamCallback && streamCallback(msg) return msg } // always store req into storage history + // `lastMsgId` and `_conversationKey` are updated const branched = await this._chatHistory.append(msg) if (this._getServerType() == 'stateless') { @@ -69,8 +72,13 @@ export abstract class LLMBot { if (this._getServerType() == 'threads') { // if thread cut. a new server thread has to be created throw new Error('TODO: create a new server thread for new branch.') + msg.options._conversationKey = '' } // else server side support } + + // create new conversation on llm server + msg.options._conversationKey || (msg.options._conversationKey = await this._getConversation(true)) + await this._setConversation(msg.options._conversationKey) } return this._sendPrompt(msg, streamCallback).then(async resp => { @@ -92,9 +100,29 @@ export abstract class LLMBot { * @returns the LLM server type: * - `stateless`: server does not keep chat history * - `threads`: server keeps multi-threads of chat history - * - `tree`: server keeps tree structured history + * - `trees`: server keeps tree structured history */ abstract _getServerType(): LLMServerType + + abstract createConversation(): Promise + + /** + * stateful llm server has server side conversations + * @param create true: force create; false: never create; otherwise: create if there is no key + * @returns + */ + async _getConversation(create?: boolean): Promise { + let key = create ? '' : await this.__storage.get(LLMBot._conversation_key_) + if (create || (!key && create !== false)) { + key = await this.createConversation() + this._setConversation(key) + } + return key + } + /** stateful llm server has server side conversations */ + async _setConversation(key: string) { + this.__storage.set(LLMBot._conversation_key_, key) + } } export enum LLMServerType { @@ -103,5 +131,5 @@ export enum LLMServerType { /** server keeps multi-threads of chat history */ threads = 'threads', /** server keeps tree structured history */ - tree = 'tree', + trees = 'trees', } diff --git a/src/bots/chat.dto.ts b/src/bots/chat.dto.ts index 1466993..4371b03 100644 --- a/src/bots/chat.dto.ts +++ b/src/bots/chat.dto.ts @@ -7,15 +7,16 @@ export class ChatDto { constructor( public readonly prompt: string | string[], - done = false, + done = true, public readonly options: { /** msg type: true: response, false: request */ resp?: boolean - /** local parent msg id */ + /** parent msg id. if '', means to start a new conversation; if undefined, append to current conversation. */ lastMsgId?: string + /** conversation key from llm server */ + _conversationKey?: string /** approximate max length of new response */ maxNewWords?: number - stream?: boolean } & Record = {}, ) { done && ((this.id = nanoid()), (this.done = true)) diff --git a/src/bots/huggingface/GradioBot.ts b/src/bots/huggingface/GradioBot.ts index 00bf987..c22ed61 100644 --- a/src/bots/huggingface/GradioBot.ts +++ b/src/bots/huggingface/GradioBot.ts @@ -27,7 +27,6 @@ export default abstract class GradioBot extends LLMBot { ////// user state keys protected static readonly _session_config = '_session_config' - protected static readonly _session_hash = '_session_hash' async reloadSession() { let available = false @@ -40,10 +39,8 @@ export default abstract class GradioBot extends LLMBot { config.root = this._loginUrl await this._userStorage.set(GradioBot._session_config, config) - if (!(await this._userStorage.get(GradioBot._session_hash))) { - const session_hash = await this.createConversation() - await this._userStorage.set(GradioBot._session_hash, session_hash) - } + // init a conversation forehand + await this._getConversation() available = true } } catch (err) { @@ -56,7 +53,7 @@ export default abstract class GradioBot extends LLMBot { } async _sendPrompt(prompt: ChatDto, streamCallback?: (msg: ChatDto) => void): Promise { - let result: ChatDto = new ChatDto('') + let result: ChatDto = new ChatDto('', false) for (const key in this._fnIndexes) { const fn_index = this._fnIndexes[key] const resp = await this._sendFnIndex(fn_index, prompt, streamCallback) @@ -67,7 +64,7 @@ export default abstract class GradioBot extends LLMBot { async _sendFnIndex(fn_index: number, prompt: ChatDto, streamCallback?: (msg: ChatDto) => void): Promise { const config = await this._userStorage.get>(GradioBot._session_config) - return new Promise((resolve, reject) => { + return new Promise(async (resolve, reject) => { try { const url = new URL(config.root + config.path + '/queue/join') url.protocol = url.protocol === 'https:' ? 'wss:' : 'ws:' @@ -83,14 +80,12 @@ export default abstract class GradioBot extends LLMBot { }, }) const data = this.makeData(fn_index, prompt) - let session_hash: string - wsp.onUnpackedMessage.addListener(async event => { + const session_hash = await this._getConversation() + wsp.onUnpackedMessage.addListener(event => { if (event.msg === 'send_hash') { - session_hash = session_hash || (await this._userStorage.get(GradioBot._session_hash)) wsp.sendPacked({ fn_index, session_hash }) } else if (event.msg === 'send_data') { // Requested to send data - session_hash = session_hash || (await this._userStorage.get(GradioBot._session_hash)) wsp.sendPacked({ data, event_data: null, @@ -101,12 +96,12 @@ export default abstract class GradioBot extends LLMBot { if (event.rank > 0) { // Waiting in queue event.rank_eta = Math.floor(event.rank_eta) - streamCallback && streamCallback(new ChatDto('gradio.waiting')) + streamCallback && streamCallback(new ChatDto('gradio.waiting', false)) } } else if (event.msg === 'process_generating') { // Generating data if (event.success && event.output.data) { - streamCallback && streamCallback(new ChatDto(this.parseData(fn_index, event.output.data))) + streamCallback && streamCallback(new ChatDto(this.parseData(fn_index, event.output.data), false)) } else { reject(new Error(event.output.error)) } @@ -155,10 +150,6 @@ export default abstract class GradioBot extends LLMBot { abstract makeData(fn_index: number, prompt: ChatDto): unknown abstract parseData(fn_index: number, data: unknown): string - /** - * Should implement this method if the bot supports conversation. - * The conversation structure is defined by the subclass. - */ async createConversation() { return Math.random().toString(36).substring(2) } diff --git a/src/storage/chat-history.ts b/src/storage/chat-history.ts index 4c0769e..2b7c173 100644 --- a/src/storage/chat-history.ts +++ b/src/storage/chat-history.ts @@ -3,30 +3,45 @@ import { BotStorage } from './bot-storage.interface' /** tree structured chat history. each chat message has a property: 'lastMsgId' */ export class ChatHistory { - protected static readonly _storage_key = '__Chat_history_' + protected static readonly _history_key = '__Chat_history_' constructor(private readonly _storage: BotStorage) { - _storage.set(ChatHistory._storage_key, []) + _storage.set(ChatHistory._history_key, []) } /** - * append msg to history. if msg.options.lastMsgId is empty, a new thread is created. - * @param msg - * @returns true if the msg is appended on a new branch. msg with null parent returns false. + * append msg to history. `msg.done` must be true. + * + * @param msg if `lastMsgId` not empty, append to it; if lastMsgId === '', a new thread created; else, allMsgs[-1] as parent. + * @sideefect `lastMsgId` and `_conversationKey` in `msg.options` is updated, if available + * @returns true if the msg is appended on a new branch. msg with no parent returns false. */ // async append(msg: ChatDto & { prompt: string }) { async append(msg: ChatDto) { if (msg.options.compound) throw new Error('chat.history.append.compound.not.allowed') + if (!msg.done) throw new Error('chat.history.append.undone.not.allowed') let branched = false - const allMsgs = await this._storage.get(ChatHistory._storage_key) + const allMsgs = await this._storage.get(ChatHistory._history_key) const pid = msg.options.lastMsgId if (pid) { const parent = this._findLast(allMsgs, m => m.id == pid) if (!parent) throw Error('chat.history.notfound.lastMsgId: ' + pid) + branched = !parent.options.leaf - parent.options.leaf = 0 - } + // branched history, may use different conversation key + msg.options._conversationKey = parent.options._conversationKey + + delete parent.options.leaf + } else if (pid !== '') { + // append to current conversation + const parent = allMsgs.at(-1) + if (parent) { + msg.options.lastMsgId = parent.id + msg.options._conversationKey = parent.options._conversationKey + } + } // else start a new conversation + msg.options.leaf = 1 allMsgs.push(msg) @@ -39,17 +54,22 @@ export class ChatHistory { async getWholeThread(msg: ChatDto): Promise { if (!msg.options?.lastMsgId) return msg - let mid: unknown = msg.id + let mid: unknown = msg.id, + _conversationKey: string | undefined const ret: string[] = [], - allMsgs = await this._storage.get(ChatHistory._storage_key) + allMsgs = await this._storage.get(ChatHistory._history_key) for (let index = allMsgs.length; index > 0; index--) { - const a = allMsgs[index - 1] - if (a.id == mid) { - ret.push(...a.prompt) - if (!(mid = a.options.lastMsgId)) break + const m = allMsgs[index - 1] + if (m.id == mid) { + ret.push(...m.prompt) + // uses the newest conversation key + _conversationKey || (_conversationKey = m.options._conversationKey) + if (!(mid = m.options.lastMsgId)) break } } - return new ChatDto(ret.reverse(), msg.done, { ...msg.options, compound: 1 }) + const retMsg = new ChatDto(ret.reverse(), msg.done, { ...msg.options, compound: 1 }) + _conversationKey && (retMsg.options._conversationKey = _conversationKey) + return retMsg } protected _findLast(array: T[], fn: (a: T) => unknown) { diff --git a/test/llmbots.e2e.spec.ts b/test/llmbots.e2e.spec.ts index 1e4de56..281a954 100644 --- a/test/llmbots.e2e.spec.ts +++ b/test/llmbots.e2e.spec.ts @@ -10,9 +10,13 @@ describe('builtin LLMBots: vicuna-13b (e2e)', () => { const ready = await claudeBot?.reloadSession() expect(ready).toBeTruthy() + const msg = new ChatDto('Who is Gauss. reply 5 words most') // eslint-disable-next-line no-console - const resp = await claudeBot?.sendPrompt(new ChatDto('hi there. 3 words most'), msg => console.log(msg)) - // console.log(resp) + const resp = await claudeBot?.sendPrompt(msg, msg => console.log(msg)) + // eslint-disable-next-line no-console + console.log(resp) expect(resp?.prompt).not.toBeNull() + expect(resp?.options.lastMsgId).toEqual(msg.id) + expect(resp?.options._conversationKey).not.toBeNull() }) })