Skip to content

Commit cc45878

Browse files
ochafikfuryhawk
authored andcommitted
server: update deepseek reasoning format (pass reasoning_content as diffs) (ggml-org#13933)
* server: update deepseek reasoning format (now in reasoning_content diffs), add legacy option for compat * update unit/test_tool_call.py::test_thoughts
1 parent 169839b commit cc45878

File tree

8 files changed

+30
-19
lines changed

8 files changed

+30
-19
lines changed

common/arg.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2869,6 +2869,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
28692869
"(default: deepseek)",
28702870
[](common_params & params, const std::string & value) {
28712871
/**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
2872+
else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; }
28722873
else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
28732874
else { throw std::invalid_argument("invalid value"); }
28742875
}

common/chat.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,10 @@ json common_chat_msg::to_json_oaicompat() const
8282

8383
std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) {
8484
std::vector<common_chat_msg_diff> diffs;
85-
// if (previous_msg.reasoning_content != current.reasoning_content) {
86-
// auto & diff = diffs.emplace_back();
87-
// diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, current.reasoning_content);
88-
// }
85+
if (previous_msg.reasoning_content != new_msg.reasoning_content) {
86+
auto & diff = diffs.emplace_back();
87+
diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content);
88+
}
8989
if (previous_msg.content != new_msg.content) {
9090
auto & diff = diffs.emplace_back();
9191
diff.content_delta = string_diff(previous_msg.content, new_msg.content);
@@ -385,9 +385,9 @@ json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & t
385385

386386
template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
387387
json delta = json::object();
388-
// if (!diff.reasoning_content_delta.empty()) {
389-
// delta["reasoning_content"] = msg.reasoning_content;
390-
// }
388+
if (!diff.reasoning_content_delta.empty()) {
389+
delta["reasoning_content"] = diff.reasoning_content_delta;
390+
}
391391
if (!diff.content_delta.empty()) {
392392
delta["content"] = diff.content_delta;
393393
}
@@ -598,6 +598,7 @@ const char * common_reasoning_format_name(common_reasoning_format format) {
598598
switch (format) {
599599
case COMMON_REASONING_FORMAT_NONE: return "none";
600600
case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
601+
case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
601602
default:
602603
throw std::runtime_error("Unknown reasoning format");
603604
}

common/chat.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ struct common_chat_msg {
7070
};
7171

7272
struct common_chat_msg_diff {
73-
// std::string reasoning_content_delta;
73+
std::string reasoning_content_delta;
7474
std::string content_delta;
7575
size_t tool_call_index = std::string::npos;
7676
common_chat_tool_call tool_call_delta;

common/common.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,8 @@ struct common_params_vocoder {
215215

216216
enum common_reasoning_format {
217217
COMMON_REASONING_FORMAT_NONE,
218-
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`
218+
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
219+
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
219220
};
220221

221222
struct common_params {

tests/test-chat.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
using json = nlohmann::ordered_json;
2020

2121
static std::ostream & operator<<(std::ostream & os, const common_chat_msg_diff & diff) {
22-
// os << "reasoning_content_delta: " << diff.reasoning_content_delta << '\n';
2322
os << "{ content_delta: " << diff.content_delta << "; ";
23+
os << "reasoning_content_delta: " << diff.reasoning_content_delta << "; ";
2424
if (diff.tool_call_index != std::string::npos) {
2525
os << "tool_call_index: " << diff.tool_call_index << "; ";
2626
os << "tool_call_delta.name: " << diff.tool_call_delta.name << "; ";

tools/server/server.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ struct server_task {
360360
params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
361361
}
362362
params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format;
363-
params.oaicompat_chat_syntax.reasoning_in_content = params.stream;
363+
params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (params_base.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
364364
params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
365365
params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
366366
}

tools/server/tests/unit/test_tool_call.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -499,13 +499,12 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
499499

500500

501501
@pytest.mark.slow
502-
@pytest.mark.parametrize("n_predict,reasoning_format,stream,expect_reasoning_content,expect_content,hf_repo,template_override", [
503-
(128, 'deepseek', CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
504-
(128, None, CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
505-
(1024, 'deepseek', CompletionMode.NORMAL, "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
506-
(1024, 'deepseek', CompletionMode.STREAMED, None, "^<think>I need to calculate [\\s\\S]*?</think>To find the sum of [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
507-
(1024, 'deepseek', CompletionMode.NORMAL, "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
508-
(1024, 'deepseek', CompletionMode.STREAMED, None, "^<think>First, I [\\s\\S]*?</think>To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
502+
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
503+
@pytest.mark.parametrize("n_predict,reasoning_format,expect_reasoning_content,expect_content,hf_repo,template_override", [
504+
(128, 'deepseek', None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
505+
(128, None, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
506+
(1024, 'deepseek', "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
507+
(1024, 'deepseek', "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
509508
# (1024, 'none', CompletionMode.NORMAL, None, "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
510509
# (128, 'deepseek', None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*", "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M", None),
511510
])

tools/server/tests/utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,10 +308,12 @@ def make_any_request(
308308
stream = data.get('stream', False)
309309
if stream:
310310
content: list[str] = []
311+
reasoning_content: list[str] = []
311312
tool_calls: list[dict] = []
312313
finish_reason: Optional[str] = None
313314

314315
content_parts = 0
316+
reasoning_content_parts = 0
315317
tool_call_parts = 0
316318
arguments_parts = 0
317319

@@ -322,6 +324,10 @@ def make_any_request(
322324
assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
323325
content.append(choice['delta']['content'])
324326
content_parts += 1
327+
if choice['delta'].get('reasoning_content') is not None:
328+
assert len(choice['delta']['reasoning_content']) > 0, f'Expected non empty reasoning_content delta!'
329+
reasoning_content.append(choice['delta']['reasoning_content'])
330+
reasoning_content_parts += 1
325331
if choice['delta'].get('finish_reason') is not None:
326332
finish_reason = choice['delta']['finish_reason']
327333
for tc in choice['delta'].get('tool_calls', []):
@@ -349,8 +355,10 @@ def make_any_request(
349355
tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
350356
if fct.get('arguments') is not None:
351357
tool_call['function']['arguments'] += fct['arguments']
358+
arguments_parts += 1
359+
tool_call_parts += 1
352360

353-
print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
361+
print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
354362
result = dict(
355363
choices=[
356364
dict(
@@ -359,6 +367,7 @@ def make_any_request(
359367
message=dict(
360368
role='assistant',
361369
content=''.join(content) if content else None,
370+
reasoning_content=''.join(reasoning_content) if reasoning_content else None,
362371
tool_calls=tool_calls if tool_calls else None,
363372
),
364373
)

0 commit comments

Comments
 (0)