add sqlite db for twitch chat history per user
Keep user history persistent through restarts.
Isolate history per user.

Better truncation.
Chris Kennedy committed May 9, 2024
1 parent 5fbcc11 commit 7a39308
Showing 9 changed files with 152 additions and 68 deletions.
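
The diffs below show the new rusqlite dependency and the chat-format changes, but the database code itself is not loaded on this page. A minimal sketch of how per-user, restart-persistent chat history might look on top of rusqlite 0.31 follows; the table name, schema, and helper functions here are assumptions for illustration, not the commit's actual code:

    // Hypothetical sketch of per-user Twitch chat history in SQLite.
    // rusqlite is synchronous, so an async caller would run these behind
    // a mutex or a dedicated blocking task.
    use rusqlite::{params, Connection, Result};

    fn open_db(path: &str) -> Result<Connection> {
        let conn = Connection::open(path)?; // e.g. "db/twitch_chat.db"
        conn.execute(
            "CREATE TABLE IF NOT EXISTS chat_history (
                id      INTEGER PRIMARY KEY AUTOINCREMENT,
                user    TEXT NOT NULL,
                role    TEXT NOT NULL,
                content TEXT NOT NULL
            )",
            [],
        )?;
        Ok(conn)
    }

    fn append_message(conn: &Connection, user: &str, role: &str, content: &str) -> Result<()> {
        conn.execute(
            "INSERT INTO chat_history (user, role, content) VALUES (?1, ?2, ?3)",
            params![user, role, content],
        )?;
        Ok(())
    }

    // Fetch the most recent `limit` messages for one user, oldest first.
    // Capping `limit` at something like TWITCH_CHAT_HISTORY from
    // scripts/twitch.sh would give the truncation the commit describes.
    fn recent_history(conn: &Connection, user: &str, limit: usize) -> Result<Vec<(String, String)>> {
        let mut stmt = conn.prepare(
            "SELECT role, content FROM chat_history
             WHERE user = ?1 ORDER BY id DESC LIMIT ?2",
        )?;
        let mut rows: Vec<(String, String)> = stmt
            .query_map(params![user, limit as i64], |row| {
                Ok((row.get(0)?, row.get(1)?))
            })?
            .collect::<Result<_>>()?;
        rows.reverse();
        Ok(rows)
    }

Keying every read and write on the user column is what isolates each user's history, and storing it in a file under db/ (now gitignored) is what survives restarts.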
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
 /target
 /images/*.png
+/db/*.db
 libndi.dylib
 Cargo.lock
 .env
3 changes: 2 additions & 1 deletion Cargo.toml
@@ -9,7 +9,7 @@ license-file = "LICENSE"
 homepage = "https://github.com/groovybits/rsllm/wiki"
 repository = "https://github.com/groovybits/rsllm"
 authors = ["Chris Kennedy"]
-version = "0.6.0"
+version = "0.6.1"
 edition = "2021"

 [lib]
@@ -84,3 +84,4 @@ clap_builder = "4.5.2"
 safetensors = "0.4.2"
 ctrlc = "3.4.4"
 base64 = "0.22.0"
+rusqlite = "0.31.0"
3 changes: 2 additions & 1 deletion bin/llama_start.sh
@@ -1,7 +1,8 @@
 #!/bin/bash
 #
 #MODEL=/Volumes/BrahmaSSD/LLM/models/GGUF/dolphin-2.7-mixtral-8x7b.Q5_K_M.gguf
-MODEL=/Volumes/BrahmaSSD/LLM/models/GGUF/dolphin-2.7-mixtral-8x7b.Q8_0.gguf
+#MODEL=/Volumes/BrahmaSSD/LLM/models/GGUF/dolphin-2.7-mixtral-8x7b.Q8_0.gguf
+MODEL=/Volumes/BrahmaSSD/LLM/models/GGUF/dolphin-2.9-mixtral-8x22b.Q5_K_M.gguf

 server \
 -m $MODEL \
Empty file added db/.keep
8 changes: 5 additions & 3 deletions scripts/twitch.sh
@@ -20,7 +20,9 @@ MODEL=mistral
 MODEL_ID=7b-it
 # Generic settings
 USE_API=1
-CHAT_FORMAT=chatml
+#CHAT_FORMAT=chatml
+#CHAT_FORMAT=llama2
+CHAT_FORMAT=vicuna
 MAX_TOKENS=800
 TEMPERATURE=0.8
 CONTEXT_SIZE=8000
@@ -38,8 +40,8 @@ NDI_TIMEOUT=600
 TWITCH_MODEL=mistral
 TWITCH_LLM_CONCURRENCY=1
 TWITCH_CHAT_HISTORY=16
-TWITCH_MAX_TOKENS_CHAT=120
-TWITCH_MAX_TOKENS_LLM=500
+TWITCH_MAX_TOKENS_CHAT=500
+TWITCH_MAX_TOKENS_LLM=$MAX_TOKENS
 ## Stable Diffusion Settings
 SD_TEXT_MIN=70
 SD_WIDTH=512
3 changes: 3 additions & 0 deletions src/lib.rs
@@ -411,6 +411,9 @@ pub async fn clean_tts_input(input: String) -> String {
         input = input.replace("..", ".");
     }

+    // remove <|im_end|> string from input and replace with ""
+    let input = input.replace("<|im_end|>", "");
+
     // remove all extra spaces besides 1 space between words, if all spaces left then reduce to '"
     let input = input
         .split_whitespace()
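
A quick, hypothetical usage sketch of the new stripping step (clean_tts_input is async per the signature above, so this runs in an async context):

    // The <|im_end|> chat marker is now removed before TTS sees the text.
    let cleaned = clean_tts_input("See you soon!<|im_end|>".to_string()).await;
    // cleaned is now "See you soon!" (modulo the function's other cleanup passes)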
31 changes: 2 additions & 29 deletions src/main.rs
@@ -1113,36 +1113,9 @@ async fn main() {
     // Capture the start time for performance metrics
     let start = Instant::now();

-    let chat_format = if args.candle_llm == "mistral" {
-        // check if model_id includes the string "Instruct" within it
-        if args.model_id.contains("Instruct") {
-            "llama2".to_string()
-        } else {
-            "".to_string()
-        }
-    } else if args.candle_llm == "gemma" {
-        if args.model_id == "7b-it" {
-            "google".to_string()
-        } else if args.model_id == "2b-it" {
-            "google".to_string()
-        } else {
-            "".to_string()
-        }
-    } else if args.use_api {
-        if args.chat_format == "chatml" {
-            "chatml".to_string()
-        } else if args.chat_format == "llama2" {
-            "llama2".to_string()
-        } else {
-            "".to_string()
-        }
-    } else {
-        "".to_string()
-    };
-
-    let prompt = format_messages_for_llm(messages.clone(), chat_format);
+    let prompt = format_messages_for_llm(messages.clone(), args.chat_format.clone());

-    debug!("\nPrompt: {}", prompt);
+    info!("\nPrompt: {}", prompt);

     // Spawn a thread to run the mistral function, to keep the UI responsive
     if args.candle_llm != "mistral" && args.candle_llm != "gemma" {
52 changes: 44 additions & 8 deletions src/openai_api.rs
@@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize};
 use std::time::Instant;
 use tokio::sync::mpsc::{self};

-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, PartialEq)]
 pub struct Message {
     pub role: String,
     pub content: String,
@@ -66,6 +66,8 @@ pub fn format_messages_for_llm(messages: Vec<Message>, chat_format: String) -> String {
         "<start_of_turn>"
     } else if chat_format == "chatml" {
         "<im_start>"
+    } else if chat_format == "vicuna" {
+        ""
     } else {
         ""
     };
@@ -75,6 +77,8 @@ pub fn format_messages_for_llm(messages: Vec<Message>, chat_format: String) -> String {
         "<end_of_turn>"
     } else if chat_format == "chatml" {
         "<im_end>"
+    } else if chat_format == "vicuna" {
+        "\n"
     } else {
         ""
     };
@@ -85,6 +89,8 @@ pub fn format_messages_for_llm(messages: Vec<Message>, chat_format: String) -> String {
         "<start_of_turn>"
     } else if chat_format == "chatml" {
         "<im_start>"
+    } else if chat_format == "vicuna" {
+        ""
     } else {
         ""
     };
@@ -94,6 +100,8 @@ pub fn format_messages_for_llm(messages: Vec<Message>, chat_format: String) -> String {
         "<end_of_turn>"
     } else if chat_format == "chatml" {
         "<im_end>"
+    } else if chat_format == "vicuna" {
+        "\n"
     } else {
         ""
     };
@@ -104,6 +112,8 @@ pub fn format_messages_for_llm(messages: Vec<Message>, chat_format: String) -> String {
         "<start_of_turn>"
     } else if chat_format == "chatml" {
         "<im_start>"
+    } else if chat_format == "vicuna" {
+        ""
     } else {
         ""
     };
@@ -113,6 +123,8 @@ pub fn format_messages_for_llm(messages: Vec<Message>, chat_format: String) -> String {
         "<end_of_turn>"
     } else if chat_format == "chatml" {
         "<im_end>"
+    } else if chat_format == "vicuna" {
+        "\n"
     } else {
         ""
     };
@@ -123,6 +135,8 @@ pub fn format_messages_for_llm(messages: Vec<Message>, chat_format: String) -> String {
         "model"
     } else if chat_format == "chatml" {
         "system"
+    } else if chat_format == "vicuna" {
+        "System: "
     } else {
         ""
     };
@@ -132,6 +146,8 @@ pub fn format_messages_for_llm(messages: Vec<Message>, chat_format: String) -> String {
         "user"
     } else if chat_format == "chatml" {
         "user"
+    } else if chat_format == "vicuna" {
+        "User: "
     } else {
         ""
     };
@@ -141,31 +157,51 @@ pub fn format_messages_for_llm(messages: Vec<Message>, chat_format: String) -> String {
         "model"
     } else if chat_format == "chatml" {
         "assistant"
+    } else if chat_format == "vicuna" {
+        "Assistant: "
     } else {
         ""
     };

-    for message in messages {
+    for (index, message) in messages.iter().enumerate() {
+        // check if last message, safely get if this is the last message
+        let is_last = index == messages.len() - 1;
         match message.role.as_str() {
+            // remove <|im_end|> from anywhere in message
             "system" => {
+                let message_content = message.content.replace("<|im_end|>", "");
                 formatted_history += &format!(
                     "{}{}{} {}{}{}\n",
-                    bos_token, sys_token, sys_name, message.content, sys_end_token, eos_token
+                    bos_token, sys_token, sys_name, message_content, sys_end_token, eos_token
                 );
             }
             "user" => {
                 // Assuming user messages should be formatted as instructions
+                let message_content = message.content.replace("<|im_end|>", "");
                 formatted_history += &format!(
                     "{}{}{} {}{}\n",
-                    bos_token, inst_token, user_name, message.content, inst_end_token
+                    bos_token, inst_token, user_name, message_content, inst_end_token
                 );
             }
             "assistant" => {
                 // Close the instruction tag for user/system messages and add the assistant's response
-                formatted_history += &format!(
-                    "{}{} {}{}{}\n",
-                    assist_token, assist_name, message.content, assist_end_token, eos_token
-                );
+                let message_content = message.content.replace("<|im_end|>", "");
+                if is_last {
+                    formatted_history += &format!(
+                        "{}{}{} {}\n",
+                        bos_token, assist_token, assist_name, message_content
+                    );
+                } else {
+                    formatted_history += &format!(
+                        "{}{}{} {}{}{}\n",
+                        bos_token,
+                        assist_token,
+                        assist_name,
+                        message_content,
+                        assist_end_token,
+                        eos_token
+                    );
+                }
             }
             _ => {}
         }
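
For reference, a hypothetical usage sketch of the new vicuna branch (it assumes Message has only the role and content fields shown above):

    // Hypothetical call into the updated format_messages_for_llm.
    let messages = vec![
        Message { role: "system".to_string(), content: "You are a helpful bot.".to_string() },
        Message { role: "user".to_string(), content: "Hello!".to_string() },
        Message { role: "assistant".to_string(), content: "".to_string() },
    ];
    let prompt = format_messages_for_llm(messages, "vicuna".to_string());
    // Roughly yields (vicuna uses plain role prefixes and newline end tokens):
    //   System:  You are a helpful bot.
    //
    //   User:  Hello!
    //
    //   Assistant:
    // Because the empty assistant message is last, the new is_last branch
    // omits the end-of-turn and EOS tokens, leaving the final turn open
    // for the model to complete.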
[Diff truncated: one more changed file not shown on this page.]
