Server: use llama_chat_apply_template #5593
Changes from 2 commits
@@ -37,7 +37,7 @@ struct server_params
    std::string hostname = "127.0.0.1";
    std::vector<std::string> api_keys;
    std::string public_path = "examples/server/public";
-   std::string chat_template = "chatml";
+   std::string chat_template = "";
    int32_t port = 8080;
    int32_t read_timeout = 600;
    int32_t write_timeout = 600;
@@ -1937,8 +1937,9 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
    printf("  types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
    printf("  -gan N, --grp-attn-n N  set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
    printf("  -gaw N, --grp-attn-w N  set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
-   printf("  --chat-template FORMAT_NAME");
-   printf("  set chat template, possible value is: llama2, chatml (default %s)", sparams.chat_template.c_str());
+   printf("  --chat-template JINJA_TEMPLATE");
+   printf("  set custom jinja chat template (default: template taken from model's metadata)");
+   printf("  Note: only commonly used templates are accepted, since we don't have jinja parser");
    printf("\n");
 }
@@ -2390,12 +2391,13 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                break;
            }
            std::string value(argv[i]);
Review reply: Yeah, it's unused, I forgot to remove that. It's now removed.
-           if (value != "chatml" && value != "llama2") {
-               fprintf(stderr, "error: chat template can be \"llama2\" or \"chatml\", but got: %s\n", value.c_str());
+           if (!verify_custom_template(argv[i])) {
+               fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]);
+               fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n");
                invalid_param = true;
                break;
            }
-           sparams.chat_template = value;
+           sparams.chat_template = argv[i];
        }
        else if (arg == "--override-kv")
        {
@@ -2913,7 +2915,7 @@ int main(int argc, char **argv)
            if (!validate_api_key(req, res)) {
                return;
            }
-           json data = oaicompat_completion_params_parse(json::parse(req.body), sparams.chat_template);
+           json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);

            const int task_id = llama.queue_tasks.get_new_id();
            llama.queue_results.add_waiting_task_id(task_id);
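For orientation, a rough sketch of how the updated oaicompat_completion_params_parse might use its new model argument together with format_chat (introduced in the second file below). The body of that function is not part of this excerpt, so the internals shown here are illustrative assumptions, not the PR's actual code.

// Illustrative sketch only; the real oaicompat_completion_params_parse is not shown in this excerpt.
// The model pointer lets format_chat() fall back to the chat template stored in the model's metadata
// when sparams.chat_template is empty.
static json oaicompat_completion_params_parse(
    const struct llama_model * model,
    const json & body,                  // OpenAI-style /v1/chat/completions request
    const std::string & chat_template)  // sparams.chat_template; empty means "use model metadata"
{
    json llama_params;
    // apply the chat template to the request's "messages" array to build the prompt
    llama_params["prompt"] = format_chat(model, chat_template, body.at("messages").get<std::vector<json>>());
    // ... mapping of the remaining OpenAI parameters would follow here ...
    return llama_params;
}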
@@ -167,50 +167,47 @@ static T json_value(const json &body, const std::string &key, const T &default_v
        : default_value;
}

-inline std::string format_llama2(std::vector<json> messages)
-{
-    std::ostringstream output;
-    bool is_inside_turn = false;
+// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
+inline bool verify_custom_template(std::string tmpl) {
+    llama_chat_message chat[] = {{"user", "test"}};
+    std::vector<char> buf(1);
+    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, buf.data(), buf.size());
+    return res >= 0;
+}
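As a usage sketch (not part of the diff): verify_custom_template probes llama_chat_apply_template with a dummy message and a null model, so only template strings whose format the library can detect are accepted. Assuming detection keys off well-known markers such as the ChatML <|im_start|> / <|im_end|> tokens, one would expect roughly:

// Illustrative only; the exact set of accepted templates depends on llama_chat_apply_template.
std::string chatml_like =
    "{% for message in messages %}<|im_start|>{{ message['role'] }}\n"
    "{{ message['content'] }}<|im_end|>\n{% endfor %}";
bool ok  = verify_custom_template(chatml_like);         // expected true: ChatML-style markers are recognized
bool bad = verify_custom_template("no known markers");  // expected false: res < 0, template rejected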

-    for (auto it = messages.begin(); it != messages.end(); ++it) {
-        if (!is_inside_turn) {
-            output << "[INST] ";
-        }
-        std::string role    = json_value(*it, "role", std::string("user"));
-        std::string content = json_value(*it, "content", std::string(""));
-        if (role == "system") {
-            output << "<<SYS>>\n" << content << "\n<<SYS>>\n\n";
-            is_inside_turn = true;
-        } else if (role == "user") {
-            output << content << " [/INST]";
-            is_inside_turn = true;
-        } else {
-            output << " " << content << " </s>";
-            is_inside_turn = false;
-        }
+// Format given chat. If tmpl is empty, we take the template from model metadata
+inline std::string format_chat(const struct llama_model * model, const std::string tmpl, std::vector<json> messages)
+{
+    size_t alloc_size = 0;
+    // vector holding all allocated string to be passed to llama_chat_apply_template
+    std::vector<std::string> str(messages.size() * 2);
+    std::vector<llama_chat_message> chat(messages.size());
+    for (size_t i = 0; i < messages.size(); ++i) {
+        auto &curr_msg = messages[i];
+        str[i] = json_value(curr_msg, "role", std::string(""));
+        str[i + 1] = json_value(curr_msg, "content", std::string(""));
+        alloc_size += str[i + 1].length();
+        chat[i].role = str[i].c_str();
+        chat[i].content = str[i + 1].c_str();
    }

Review comment: There seems to be a bug here.
Author reply: Thanks for noticing that. That's why the bot's response was quite weird when I tested this PR yesterday. Fixed in c53b34d. Looking at the debug log, I can confirm that the formatted chat is correct.
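For clarity, a sketch of what the corrected indexing could look like; the actual fix in c53b34d is not shown in this excerpt, so treat this as illustrative. The idea is that each message gets its own pair of slots in str, so a later iteration no longer overwrites a string that chat[] already points to:

// Illustrative sketch of the fix discussed above (not the verbatim commit):
for (size_t i = 0; i < messages.size(); ++i) {
    auto & curr_msg = messages[i];
    str[i*2 + 0]    = json_value(curr_msg, "role",    std::string(""));
    str[i*2 + 1]    = json_value(curr_msg, "content", std::string(""));
    alloc_size     += str[i*2 + 1].length();
    chat[i].role    = str[i*2 + 0].c_str();
    chat[i].content = str[i*2 + 1].c_str();
}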

-    LOG_VERBOSE("format_llama2", {{"text", output.str()}});
+    const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+    std::vector<char> buf(alloc_size * 2);

-    return output.str();
-}
-
-inline std::string format_chatml(std::vector<json> messages)
-{
-    std::ostringstream chatml_msgs;
+    // run the first time to get the total output length
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());

-    for (auto it = messages.begin(); it != messages.end(); ++it) {
-        chatml_msgs << "<|im_start|>"
-                    << json_value(*it, "role", std::string("user")) << '\n';
-        chatml_msgs << json_value(*it, "content", std::string(""))
-                    << "<|im_end|>\n";
+    // if it turns out that our buffer is too small, we resize it
+    if ((size_t) res > buf.size()) {
+        buf.resize(res);
+        res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
    }

-    chatml_msgs << "<|im_start|>assistant" << '\n';
-
-    LOG_VERBOSE("format_chatml", {{"text", chatml_msgs.str()}});
+    std::string formatted_chat(buf.data(), buf.size());
+    LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});

-    return chatml_msgs.str();
+    return formatted_chat;
}

//
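For illustration, a minimal sketch of calling format_chat directly, assuming a loaded llama_model * named model and the surrounding utils.hpp context (nlohmann::json available as json); this is not part of the diff:

// Illustrative usage: format a short conversation.
std::vector<json> messages = {
    {{"role", "system"}, {"content", "You are a helpful assistant."}},
    {{"role", "user"},   {"content", "Hello!"}},
};
// An empty template string means "use the chat template from the model's metadata".
std::string prompt = format_chat(model, "", messages);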