Skip to content

Commit

Permalink
add chatml format for mixtral dolphin
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Kennedy committed Mar 27, 2024
1 parent fa5d739 commit bd6ed59
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 1 deletion.
2 changes: 2 additions & 0 deletions scripts/twitch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ GREETING="Hi I'm Alice, ask me a question by typing '!message Alice <message>' o
USE_API=1
MODEL=mistral
MODEL_ID=7b-it
CHAT_FORMAT=chatml
MAX_TOKENS=1500
TEMPERATURE=0.8
CONTEXT_SIZE=16000
Expand Down Expand Up @@ -114,6 +115,7 @@ DYLD_LIBRARY_PATH=`pwd`:/usr/local/lib:$DYLD_LIBRARY_PATH \
$USE_API_CMD \
--candle-llm $MODEL \
--llm-history-size $CONTEXT_SIZE \
--chat-format $CHAT_FORMAT \
--model-id $MODEL_ID \
--temperature $TEMPERATURE \
--pipeline-concurrency $PIPELINE_CONCURRENCY \
Expand Down
9 changes: 9 additions & 0 deletions src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ pub struct Args {
)]
pub query: String,

/// Chat Format - LLM chat format to use, llama2, chatml, gemma, ""
#[clap(
long,
env = "CHAT_FORMAT",
default_value = "",
help = "Chat Format - LLM chat format to use, llama2, chatml, gemma, \"\""
)]
pub chat_format: String,

/// Temperature
#[clap(
long,
Expand Down
8 changes: 7 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1116,7 +1116,13 @@ async fn main() {
"".to_string()
}
} else if args.use_api {
"llama2".to_string()
if args.chat_format == "chatml" {
"chatml".to_string()
} else if args.chat_format == "llama2" {
"llama2".to_string()
} else {
"".to_string()
}
} else {
"".to_string()
};
Expand Down
18 changes: 18 additions & 0 deletions src/openai_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,17 @@ pub fn format_messages_for_llm(messages: Vec<Message>, chat_format: String) -> S
"[INST]"
} else if chat_format == "google" {
"<start_of_turn>"
} else if chat_format == "chatml" {
"<|im_start|>"
} else {
""
};
let inst_end_token = if chat_format == "llama2" {
"[/INST]"
} else if chat_format == "google" {
"<end_of_turn>"
} else if chat_format == "chatml" {
"<|im_end|>"
} else {
""
};
Expand All @@ -79,13 +83,17 @@ pub fn format_messages_for_llm(messages: Vec<Message>, chat_format: String) -> S
""
} else if chat_format == "google" {
"<start_of_turn>"
} else if chat_format == "chatml" {
"<|im_start|>"
} else {
""
};
let assist_end_token = if chat_format == "llama2" {
""
} else if chat_format == "google" {
"<end_of_turn>"
} else if chat_format == "chatml" {
"<|im_end|>"
} else {
""
};
Expand All @@ -94,13 +102,17 @@ pub fn format_messages_for_llm(messages: Vec<Message>, chat_format: String) -> S
"<<SYS>>"
} else if chat_format == "google" {
"<start_of_turn>"
} else if chat_format == "chatml" {
"<|im_start|>"
} else {
""
};
let sys_end_token = if chat_format == "llama2" {
"<</SYS>>"
} else if chat_format == "google" {
"<end_of_turn>"
} else if chat_format == "chatml" {
"<|im_end|>"
} else {
""
};
Expand All @@ -109,20 +121,26 @@ pub fn format_messages_for_llm(messages: Vec<Message>, chat_format: String) -> S
""
} else if chat_format == "google" {
"model"
} else if chat_format == "chatml" {
"system"
} else {
""
};
let user_name = if chat_format == "llama2" {
""
} else if chat_format == "google" {
"user"
} else if chat_format == "chatml" {
"user"
} else {
""
};
let assist_name = if chat_format == "llama2" {
""
} else if chat_format == "google" {
"model"
} else if chat_format == "chatml" {
"assistant"
} else {
""
};
Expand Down

0 comments on commit bd6ed59

Please sign in to comment.