Commit b378d28

upgrade to llama.cpp b2797

Signed-off-by: Konstantin Herud <[email protected]>

1 parent c5e1e38 commit b378d28

5 files changed, 43 insertions(+), 9 deletions(-)


CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ FetchContent_MakeAvailable(json)
 FetchContent_Declare(
     llama.cpp
     GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
-    GIT_TAG b2702
+    GIT_TAG b2797
 )
 FetchContent_MakeAvailable(llama.cpp)

README.md

Lines changed: 3 additions & 3 deletions
@@ -1,5 +1,5 @@
 ![Java 11+](https://img.shields.io/badge/Java-11%2B-informational)
-![llama.cpp b2702](https://img.shields.io/badge/llama.cpp-%23b2702-informational)
+![llama.cpp b2797](https://img.shields.io/badge/llama.cpp-%23b2797-informational)

 # Java Bindings for [llama.cpp](https://github.com/ggerganov/llama.cpp)

@@ -18,7 +18,7 @@ This repository provides Java bindings for the C++ library.
 3. [Android](#importing-in-android)

 > [!NOTE]
-> Now with Llama 3 support
+> Now with support for Llama 3, Phi-3, and flash attention

 ## Quick Start

@@ -28,7 +28,7 @@ Access this library via Maven:
 <dependency>
     <groupId>de.kherud</groupId>
     <artifactId>llama</artifactId>
-    <version>3.0.1</version>
+    <version>3.0.2</version>
 </dependency>
 ```


pom.xml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@

     <groupId>de.kherud</groupId>
     <artifactId>llama</artifactId>
-    <version>3.0.1</version>
+    <version>3.0.2</version>
     <packaging>jar</packaging>

     <name>${project.groupId}:${project.artifactId}</name>

src/main/cpp/server.hpp

Lines changed: 29 additions & 4 deletions
@@ -910,7 +910,7 @@ struct server_context
         slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
         slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
         slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
-        slot.params.seed = json_value(data, "seed", default_params.seed);
+        slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
         slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
         slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
@@ -1209,7 +1209,7 @@ struct server_context
     bool process_token(completion_token_output &result, server_slot &slot)
     {
         // remember which tokens were sampled - used for repetition penalties during sampling
-        const std::string token_str = llama_token_to_piece(ctx, result.tok);
+        const std::string token_str = llama_token_to_piece(ctx, result.tok, false);
         slot.sampled = result.tok;

         // search stop word and delete it
@@ -1314,6 +1314,27 @@ struct server_context
             LOG_VERBOSE("eos token found", {});
         }

+        auto n_ctx_train = llama_n_ctx_train(model);
+        if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1
+            && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
+            LOG_WARNING("n_predict is not set and self-context extend is disabled."
+                        " Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", {
+                            { "id_slot", slot.id },
+                            { "params.n_predict", slot.params.n_predict },
+                            { "slot.n_prompt_tokens", slot.n_prompt_tokens },
+                            { "slot.n_decoded", slot.n_decoded },
+                            { "slot.n_predict", slot.n_predict },
+                            { "n_slots", params.n_parallel },
+                            { "slot.n_ctx", slot.n_ctx },
+                            { "n_ctx", n_ctx },
+                            { "n_ctx_train", n_ctx_train },
+                            { "ga_n", slot.ga_n },
+                        });
+            slot.truncated = true;
+            slot.stopped_limit = true;
+            slot.has_next_token = false; // stop prediction
+        }
+
         LOG_VERBOSE("next token", {
             {"id_slot", slot.id},
             {"id_task", slot.id_task},
@@ -1475,8 +1496,9 @@ struct server_context
             {
                 const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);

+                size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
                 probs = std::vector<completion_token_output>(slot.generated_token_probs.begin(),
-                                                             slot.generated_token_probs.end() - stop_word_toks.size());
+                                                             slot.generated_token_probs.end() - safe_offset);
             }
             else
             {
@@ -2313,7 +2335,7 @@ struct server_context
        });

        // process the created batch of tokens
-       for (int32_t i = 0; i < (int32_t)batch.n_tokens; i += n_batch)
+       for (int32_t i = 0; i < batch.n_tokens; i += n_batch)
        {
            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

@@ -2534,6 +2556,7 @@ static void server_params_parse(json jparams, server_params &sparams, gpt_params
     params.embedding = json_value(jparams, "embedding", default_params.embedding);
     params.escape = json_value(jparams, "escape", default_params.escape);
     params.cont_batching = json_value(jparams, "cont_batching", default_params.cont_batching);
+    params.flash_attn = json_value(jparams, "flash_attn", default_params.flash_attn);
     params.input_prefix_bos = json_value(jparams, "input_prefix_bos", default_params.input_prefix_bos);
     params.ignore_eos = json_value(jparams, "ignore_eos", default_params.ignore_eos);
     params.use_mmap = json_value(jparams, "use_mmap", default_params.use_mmap);
@@ -2596,4 +2619,6 @@ static void server_params_parse(json jparams, server_params &sparams, gpt_params
         LOG_WARNING("llama.cpp was compiled without CUDA. It is not possible to set a main GPU.", {});
 #endif
     }
+
+    gpt_params_handle_model_default(params);
 }

src/main/java/de/kherud/llama/ModelParameters.java

Lines changed: 9 additions & 0 deletions
@@ -61,6 +61,7 @@ public final class ModelParameters extends JsonParameters {
     private static final String PARAM_LORA_BASE = "lora_base";
     private static final String PARAM_EMBEDDING = "embedding";
     private static final String PARAM_CONT_BATCHING = "cont_batching";
+    private static final String PARAM_FLASH_ATTENTION = "flash_attn";
     private static final String PARAM_INPUT_PREFIX_BOS = "input_prefix_bos";
     private static final String PARAM_IGNORE_EOS = "ignore_eos";
     private static final String PARAM_USE_MMAP = "use_mmap";
@@ -526,6 +527,14 @@ public ModelParameters setContinuousBatching(boolean contBatching) {
         return this;
     }

+    /**
+     * Whether to enable Flash Attention (default: disabled)
+     */
+    public ModelParameters setFlashAttention(boolean flashAttention) {
+        parameters.put(PARAM_FLASH_ATTENTION, String.valueOf(flashAttention));
+        return this;
+    }
+
     /**
      * Whether to add prefix BOS to user inputs, preceding the `--in-prefix` string
      */

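Usage note (not part of the diff above): the setFlashAttention setter added to ModelParameters writes the same "flash_attn" key that server_params_parse now reads in server.hpp, so flash attention can be toggled entirely from the Java side. A minimal sketch, assuming the class's default constructor and fluent setters; how the resulting parameters object is handed to the model is unchanged by this commit:

```java
import de.kherud.llama.ModelParameters;

// Sketch only: enable the flash attention option introduced in this commit.
// setContinuousBatching() is a pre-existing setter shown for context;
// setFlashAttention(true) serializes to the "flash_attn" JSON key that the
// native parser (server_params_parse) maps onto params.flash_attn.
ModelParameters params = new ModelParameters()
        .setContinuousBatching(true)
        .setFlashAttention(true);
```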