feat: support conversations
pond918 committed May 28, 2023
1 parent 87792f1 · commit 7c8832d
Showing 5 changed files with 84 additions and 40 deletions.
34 changes: 31 additions & 3 deletions src/bots/base-bot.ts
@@ -13,6 +13,8 @@ export abstract class LLMBot {
_chatHistory!: ChatHistory
private __storage!: BotStorage

protected static readonly _conversation_key_ = '_Conversation_key_'

constructor(
/** bot unique name */
readonly name: string,
@@ -53,12 +55,13 @@ export abstract class LLMBot {

async sendPrompt(msg: ChatDto, streamCallback?: (msg: ChatDto) => void): Promise<ChatDto> {
if (!(await this.isAvailable())) {
const msg = new ChatDto('bot.notAvailable', true)
const msg = new ChatDto('bot.notAvailable')
streamCallback && streamCallback(msg)
return msg
}

// always store req into storage history
// `lastMsgId` and `_conversationKey` are updated
const branched = await this._chatHistory.append(msg)

if (this._getServerType() == 'stateless') {
@@ -69,8 +72,13 @@
if (this._getServerType() == 'threads') {
// if the thread is cut, a new server thread has to be created
throw new Error('TODO: create a new server thread for new branch.')
msg.options._conversationKey = ''
} // else server side support
}

// create new conversation on llm server
msg.options._conversationKey || (msg.options._conversationKey = await this._getConversation(true))
await this._setConversation(msg.options._conversationKey)
}

return this._sendPrompt(msg, streamCallback).then(async resp => {
@@ -92,9 +100,29 @@
* @returns the LLM server type:
* - `stateless`: server does not keep chat history
* - `threads`: server keeps multi-threads of chat history
* - `tree`: server keeps tree structured history
* - `trees`: server keeps tree structured history
*/
abstract _getServerType(): LLMServerType

abstract createConversation(): Promise<string>

/**
* a stateful llm server keeps server-side conversations
* @param create true: force create; false: never create; otherwise: create only if there is no key
* @returns the conversation key, if any
*/
async _getConversation(create?: boolean): Promise<string> {
let key = create ? '' : await this.__storage.get<string>(LLMBot._conversation_key_)
if (create || (!key && create !== false)) {
key = await this.createConversation()
await this._setConversation(key)
}
return key
}
/** a stateful llm server keeps server-side conversations */
async _setConversation(key: string) {
await this.__storage.set(LLMBot._conversation_key_, key)
}
}

export enum LLMServerType {
@@ -103,5 +131,5 @@ export enum LLMServerType {
/** server keeps multi-threads of chat history */
threads = 'threads',
/** server keeps tree structured history */
tree = 'tree',
trees = 'trees',
}
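
For orientation, here is a minimal caller-side sketch (not part of the commit) of how the new conversation flow is intended to be used. It assumes a concrete LLMBot subclass instance whose storage is already wired up; the conversationDemo helper and the relative import paths are illustrative only.

import { LLMBot } from './base-bot'
import { ChatDto } from './chat.dto'

export async function conversationDemo(bot: LLMBot) {
  // first prompt: no history yet, so a new thread starts and, for stateful
  // servers, a fresh server-side conversation key is created and persisted
  await bot.sendPrompt(new ChatDto('Hello!'))

  // leaving lastMsgId undefined appends to the current conversation,
  // so the stored _conversationKey is reused
  await bot.sendPrompt(new ChatDto('Tell me more.'))

  // lastMsgId === '' explicitly starts a new thread, which triggers
  // createConversation() again on a stateful server
  await bot.sendPrompt(new ChatDto('A different topic.', true, { lastMsgId: '' }))
}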
7 changes: 4 additions & 3 deletions src/bots/chat.dto.ts
@@ -7,15 +7,16 @@ export class ChatDto {

constructor(
public readonly prompt: string | string[],
done = false,
done = true,
public readonly options: {
/** msg type: true: response, false: request */
resp?: boolean
/** local parent msg id */
/** parent msg id. if '', means to start a new conversation; if undefined, append to current conversation. */
lastMsgId?: string
/** conversation key from llm server */
_conversationKey?: string
/** approximate max length of new response */
maxNewWords?: number
stream?: boolean
} & Record<string, unknown> = {},
) {
done && ((this.id = nanoid()), (this.done = true))
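
The new ChatDto defaults and the three lastMsgId states are easiest to see side by side. A small sketch follows; only the relative import path and the variable names are assumed.

import { ChatDto } from './chat.dto'

// done now defaults to true, so a plain request is final and gets an id
const req = new ChatDto('What is a monad?')

// streaming chunks are explicitly not done and carry no id
const chunk = new ChatDto('partial text...', false)

// the three lastMsgId states:
const toMsg = new ChatDto('continue that', true, { lastMsgId: req.id }) // append to a specific message
const toCurrent = new ChatDto('follow up', true)                        // undefined: append to the current conversation
const newChat = new ChatDto('new chat', true, { lastMsgId: '' })        // '': start a new conversation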
25 changes: 8 additions & 17 deletions src/bots/huggingface/GradioBot.ts
@@ -27,7 +27,6 @@ export default abstract class GradioBot extends LLMBot {

////// user state keys
protected static readonly _session_config = '_session_config'
protected static readonly _session_hash = '_session_hash'

async reloadSession() {
let available = false
@@ -40,10 +39,8 @@
config.root = this._loginUrl
await this._userStorage.set(GradioBot._session_config, config)

if (!(await this._userStorage.get(GradioBot._session_hash))) {
const session_hash = await this.createConversation()
await this._userStorage.set(GradioBot._session_hash, session_hash)
}
// init a conversation beforehand
await this._getConversation()
available = true
}
} catch (err) {
@@ -56,7 +53,7 @@
}

async _sendPrompt(prompt: ChatDto, streamCallback?: (msg: ChatDto) => void): Promise<ChatDto> {
let result: ChatDto = new ChatDto('')
let result: ChatDto = new ChatDto('', false)
for (const key in this._fnIndexes) {
const fn_index = this._fnIndexes[key]
const resp = await this._sendFnIndex(fn_index, prompt, streamCallback)
@@ -67,7 +64,7 @@

async _sendFnIndex(fn_index: number, prompt: ChatDto, streamCallback?: (msg: ChatDto) => void): Promise<ChatDto> {
const config = await this._userStorage.get<Record<string, string>>(GradioBot._session_config)
return new Promise((resolve, reject) => {
return new Promise(async (resolve, reject) => {
try {
const url = new URL(config.root + config.path + '/queue/join')
url.protocol = url.protocol === 'https:' ? 'wss:' : 'ws:'
@@ -83,14 +80,12 @@
},
})
const data = this.makeData(fn_index, prompt)
let session_hash: string
wsp.onUnpackedMessage.addListener(async event => {
const session_hash = await this._getConversation()
wsp.onUnpackedMessage.addListener(event => {
if (event.msg === 'send_hash') {
session_hash = session_hash || (await this._userStorage.get<string>(GradioBot._session_hash))
wsp.sendPacked({ fn_index, session_hash })
} else if (event.msg === 'send_data') {
// Requested to send data
session_hash = session_hash || (await this._userStorage.get<string>(GradioBot._session_hash))
wsp.sendPacked({
data,
event_data: null,
@@ -101,12 +96,12 @@
if (event.rank > 0) {
// Waiting in queue
event.rank_eta = Math.floor(event.rank_eta)
streamCallback && streamCallback(new ChatDto('gradio.waiting'))
streamCallback && streamCallback(new ChatDto('gradio.waiting', false))
}
} else if (event.msg === 'process_generating') {
// Generating data
if (event.success && event.output.data) {
streamCallback && streamCallback(new ChatDto(this.parseData(fn_index, event.output.data)))
streamCallback && streamCallback(new ChatDto(this.parseData(fn_index, event.output.data), false))
} else {
reject(new Error(event.output.error))
}
@@ -155,10 +150,6 @@ export default abstract class GradioBot extends LLMBot {
abstract makeData(fn_index: number, prompt: ChatDto): unknown
abstract parseData(fn_index: number, data: unknown): string

/**
* Should implement this method if the bot supports conversation.
* The conversation structure is defined by the subclass.
*/
async createConversation() {
return Math.random().toString(36).substring(2)
}
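
The abstract makeData/parseData hooks are where each concrete Gradio bot encodes its app-specific payloads. Purely as an illustration (payload shapes differ per Gradio app and per fn_index; the shapes below are assumptions, not part of this repo), such hooks often look roughly like this:

// hypothetical hooks for a single-textbox chat demo; shapes vary per app
function makeData(fn_index: number, prompt: { prompt: string | string[] }): unknown {
  const text = Array.isArray(prompt.prompt) ? prompt.prompt.join('\n') : prompt.prompt
  return [text] // many chat demos expect [user_text, ...extra state]
}

function parseData(fn_index: number, data: unknown): string {
  // many chat demos return chatbot state shaped like [[[user, assistant], ...]];
  // take the assistant half of the last turn, if present
  const turns = (data as string[][][])[0] ?? []
  const last = turns[turns.length - 1]
  return last ? last[1] : ''
}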
50 changes: 35 additions & 15 deletions src/storage/chat-history.ts
@@ -3,30 +3,45 @@ import { BotStorage } from './bot-storage.interface'

/** tree structured chat history. each chat message has a property: 'lastMsgId' */
export class ChatHistory {
protected static readonly _storage_key = '__Chat_history_'
protected static readonly _history_key = '__Chat_history_'
constructor(private readonly _storage: BotStorage) {
_storage.set(ChatHistory._storage_key, [])
_storage.set(ChatHistory._history_key, [])
}

/**
* append msg to history. if msg.options.lastMsgId is empty, a new thread is created.
* @param msg
* @returns true if the msg is appended on a new branch. msg with null parent returns false.
* append msg to history. `msg.done` must be true.
*
* @param msg if `lastMsgId` is not empty, append to that message; if `lastMsgId === ''`, a new thread is created; otherwise the latest message (`allMsgs[-1]`) is used as the parent.
* @sideeffect `lastMsgId` and `_conversationKey` in `msg.options` are updated, if available
* @returns true if the msg is appended on a new branch. msg with no parent returns false.
*/
// async append(msg: ChatDto & { prompt: string }) {
async append(msg: ChatDto) {
if (msg.options.compound) throw new Error('chat.history.append.compound.not.allowed')
if (!msg.done) throw new Error('chat.history.append.undone.not.allowed')

let branched = false

const allMsgs = await this._storage.get<ChatDto[]>(ChatHistory._storage_key)
const allMsgs = await this._storage.get<ChatDto[]>(ChatHistory._history_key)
const pid = msg.options.lastMsgId
if (pid) {
const parent = this._findLast(allMsgs, m => m.id == pid)
if (!parent) throw Error('chat.history.notfound.lastMsgId: ' + pid)

branched = !parent.options.leaf
parent.options.leaf = 0
}
// branched history, may use different conversation key
msg.options._conversationKey = parent.options._conversationKey

delete parent.options.leaf
} else if (pid !== '') {
// append to current conversation
const parent = allMsgs.at(-1)
if (parent) {
msg.options.lastMsgId = parent.id
msg.options._conversationKey = parent.options._conversationKey
}
} // else start a new conversation

msg.options.leaf = 1
allMsgs.push(msg)

@@ -39,17 +54,22 @@
async getWholeThread(msg: ChatDto): Promise<ChatDto> {
if (!msg.options?.lastMsgId) return msg

let mid: unknown = msg.id
let mid: unknown = msg.id,
_conversationKey: string | undefined
const ret: string[] = [],
allMsgs = await this._storage.get<ChatDto[]>(ChatHistory._storage_key)
allMsgs = await this._storage.get<ChatDto[]>(ChatHistory._history_key)
for (let index = allMsgs.length; index > 0; index--) {
const a = allMsgs[index - 1]
if (a.id == mid) {
ret.push(...a.prompt)
if (!(mid = a.options.lastMsgId)) break
const m = allMsgs[index - 1]
if (m.id == mid) {
ret.push(...m.prompt)
// uses the newest conversation key
_conversationKey || (_conversationKey = m.options._conversationKey)
if (!(mid = m.options.lastMsgId)) break
}
}
return new ChatDto(ret.reverse(), msg.done, { ...msg.options, compound: 1 })
const retMsg = new ChatDto(ret.reverse(), msg.done, { ...msg.options, compound: 1 })
_conversationKey && (retMsg.options._conversationKey = _conversationKey)
return retMsg
}

protected _findLast<T>(array: T[], fn: (a: T) => unknown) {
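
To make the new branching rules concrete, here is a stand-alone sketch of ChatHistory.append (not part of the commit). The in-memory storage stub and the cast are assumptions, since the real BotStorage interface may declare more members than the get/set used here, and the ChatDto import path is inferred.

import { ChatDto } from '../bots/chat.dto'
import { BotStorage } from './bot-storage.interface'
import { ChatHistory } from './chat-history'

// minimal in-memory stand-in for BotStorage (illustrative only)
const mem = new Map<string, unknown>()
const storage = {
  async get<T>(key: string) { return mem.get(key) as T },
  async set(key: string, value: unknown) { mem.set(key, value) },
}

async function branchingDemo() {
  const history = new ChatHistory(storage as unknown as BotStorage)

  const a = new ChatDto('A')                             // no parent: roots a new thread
  await history.append(a)                                // -> false

  const b = new ChatDto('B', true, { lastMsgId: a.id })  // A is still a leaf
  await history.append(b)                                // -> false, linear continuation

  const c = new ChatDto('C', true, { lastMsgId: a.id })  // A already has a child
  await history.append(c)                                // -> true, a new branch
}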
8 changes: 6 additions & 2 deletions test/llmbots.e2e.spec.ts
@@ -10,9 +10,13 @@ describe('builtin LLMBots: vicuna-13b (e2e)', () => {
const ready = await claudeBot?.reloadSession()
expect(ready).toBeTruthy()

const msg = new ChatDto('Who is Gauss. reply 5 words most')
// eslint-disable-next-line no-console
const resp = await claudeBot?.sendPrompt(new ChatDto('hi there. 3 words most'), msg => console.log(msg))
// console.log(resp)
const resp = await claudeBot?.sendPrompt(msg, msg => console.log(msg))
// eslint-disable-next-line no-console
console.log(resp)
expect(resp?.prompt).not.toBeNull()
expect(resp?.options.lastMsgId).toEqual(msg.id)
expect(resp?.options._conversationKey).not.toBeNull()
})
})
