talk-llama : sync llama.cpp
- examples/talk-llama/CMakeLists.txt +3 -0
- examples/talk-llama/llama-arch.cpp +3 -0
- examples/talk-llama/llama-arch.h +2 -0
- examples/talk-llama/llama-batch.cpp +19 -12
- examples/talk-llama/llama-batch.h +15 -10
- examples/talk-llama/llama-context.cpp +226 -151
- examples/talk-llama/llama-context.h +25 -8
- examples/talk-llama/llama-graph.cpp +69 -61
- examples/talk-llama/llama-graph.h +25 -24
- examples/talk-llama/llama-hparams.h +3 -0
- examples/talk-llama/llama-kv-cache-recurrent.cpp +1132 -0
- examples/talk-llama/llama-kv-cache-recurrent.h +191 -0
- examples/talk-llama/llama-kv-cache-unified-iswa.cpp +249 -0
- examples/talk-llama/llama-kv-cache-unified-iswa.h +136 -0
- examples/talk-llama/llama-kv-cache-unified.cpp +1717 -0
- examples/talk-llama/llama-kv-cache-unified.h +278 -0
- examples/talk-llama/llama-kv-cache.cpp +0 -2738
- examples/talk-llama/llama-kv-cache.h +14 -472
- examples/talk-llama/llama-kv-cells.h +37 -6
- examples/talk-llama/llama-memory.h +44 -0
- examples/talk-llama/llama-model.cpp +62 -59
- examples/talk-llama/llama.h +12 -8
examples/talk-llama/CMakeLists.txt
CHANGED
@@ -17,6 +17,9 @@ if (WHISPER_SDL2)
     llama-impl.cpp
     llama-io.cpp
     llama-kv-cache.cpp
+    llama-kv-cache-unified.cpp
+    llama-kv-cache-unified-iswa.cpp
+    llama-kv-cache-recurrent.cpp
     llama-memory.cpp
     llama-mmap.cpp
     llama-model-loader.cpp
examples/talk-llama/llama-arch.cpp
CHANGED
@@ -174,6 +174,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" },
     { LLM_KV_CONVNEXT_BLOCK_COUNT,      "%s.convnext.block_count"      },
 
+    { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },
+
     { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model"  },
     { LLM_KV_TOKENIZER_PRE,   "tokenizer.ggml.pre"    },
     { LLM_KV_TOKENIZER_LIST,  "tokenizer.ggml.tokens" },
@@ -448,6 +450,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
     { LLM_TENSOR_TOKEN_TYPES,   "token_types" },
     { LLM_TENSOR_POS_EMBD,      "position_embd" },
     { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
+    { LLM_TENSOR_ATTN_QKV,      "blk.%d.attn_qkv" },
     { LLM_TENSOR_ATTN_Q,        "blk.%d.attn_q" },
     { LLM_TENSOR_ATTN_K,        "blk.%d.attn_k" },
     { LLM_TENSOR_ATTN_V,        "blk.%d.attn_v" },
examples/talk-llama/llama-arch.h
CHANGED
@@ -213,6 +213,8 @@ enum llm_kv {
     LLM_KV_CONVNEXT_EMBEDDING_LENGTH,
     LLM_KV_CONVNEXT_BLOCK_COUNT,
 
+    LLM_KV_CLASSIFIER_OUTPUT_LABELS,
+
     // deprecated:
     LLM_KV_TOKENIZER_PREFIX_ID,
     LLM_KV_TOKENIZER_SUFFIX_ID,
examples/talk-llama/llama-batch.cpp
CHANGED
@@ -15,24 +15,31 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
             break;
         }
     }
-
-
-
-
-
-
+
+    udatas.push_back({});
+
+    auto & udata = udatas.back();
+
+    udata.token.resize(!has_embd ? n_ubatch : 0);
+    udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
+    udata.pos.resize(n_ubatch);
+    udata.n_seq_id.resize(n_ubatch);
+    udata.seq_id.resize(n_ubatch);
+    udata.output.resize(n_ubatch);
+
     llama_ubatch ubatch = {
         /*equal_seqs   =*/ true,
         /*n_tokens     =*/ 0,
         /*n_seq_tokens =*/ 0,
         /*n_seqs       =*/ 0,
-        /*token        =*/ !has_embd ?
-        /*embd         =*/ has_embd ?
-        /*pos          =*/
-        /*n_seq_id     =*/
-        /*seq_id       =*/
-        /*output       =*/
+        /*token        =*/ !has_embd ? udata.token.data() : nullptr,
+        /*embd         =*/ has_embd  ? udata.embd.data()  : nullptr,
+        /*pos          =*/ udata.pos.data(),
+        /*n_seq_id     =*/ udata.n_seq_id.data(),
+        /*seq_id       =*/ udata.seq_id.data(),
+        /*output       =*/ udata.output.data(),
     };
+
     return ubatch;
 }
 
examples/talk-llama/llama-batch.h
CHANGED
@@ -11,15 +11,15 @@ struct llama_ubatch {
     bool equal_seqs;
     // TODO: whole_seqs for embeddings?
 
-    uint32_t n_tokens;
+    uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
     uint32_t n_seq_tokens; // tokens per sequence
     uint32_t n_seqs;
 
     llama_token  *  token;    // [n_tokens]
     float        *  embd;     // [n_embd, n_tokens]
     llama_pos    *  pos;      // [n_tokens]
-    int32_t      *  n_seq_id; // [n_seqs]
-    llama_seq_id ** seq_id;   // [n_seqs]
+    int32_t      *  n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
+    llama_seq_id ** seq_id;   // [n_seqs] // TODO: become llama_seq_id * seq_id;
     int8_t       *  output;   // [n_tokens]
 };
 
@@ -49,13 +49,18 @@ struct llama_sbatch {
 
     const llama_batch * batch = nullptr;
 
-    // buffers for the
-
-
-
-
-
-
+    // buffers for the ubatches
+    // TODO: very hacky, this needs a complete rework
+    struct ubatch_data {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<int8_t>         output;
+    };
+
+    std::vector<ubatch_data> udatas;
 
     llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
 
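Note on the per-ubatch buffers added above: the llama_ubatch returned by reserve_ubatch() is a non-owning view, so the backing vectors now live in llama_sbatch::udatas and must outlive every view handed out. The standalone sketch below (stand-in types only, not the llama.cpp API; `view`, `builder` and `backing` are illustrative names) shows the pattern under that assumption.

    // Standalone illustration of backing a pointer-view struct with per-call
    // buffers kept alive in a vector, mirroring llama_sbatch::ubatch_data / udatas.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct view {            // stands in for llama_ubatch: non-owning pointers
        int32_t * pos;       // [n]
        int8_t  * output;    // [n]
    };

    struct builder {         // stands in for llama_sbatch
        struct backing {     // stands in for llama_sbatch::ubatch_data
            std::vector<int32_t> pos;
            std::vector<int8_t>  output;
        };
        std::vector<backing> datas;  // keeps every reserved buffer alive

        view reserve(size_t n) {
            datas.push_back({});
            auto & d = datas.back();
            d.pos.resize(n);
            d.output.resize(n);
            return { d.pos.data(), d.output.data() };
        }
    };

    int main() {
        builder b;
        view v0 = b.reserve(4);
        view v1 = b.reserve(8);  // does not invalidate v0's buffers
        v0.pos[0]    = 42;
        v1.output[7] = 1;
        std::printf("%d %d\n", v0.pos[0], (int) v1.output[7]);
    }

Because moving a std::vector does not relocate its heap storage, pushing a new backing entry does not invalidate pointers handed out for earlier ubatches, which is what makes this (admittedly hacky, per the TODO) scheme work.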
examples/talk-llama/llama-context.cpp
CHANGED
@@ -6,9 +6,10 @@
 #include "llama-model.h"
 #include "llama-kv-cache.h"
 
+#include <cinttypes>
 #include <cstring>
+#include <limits>
 #include <stdexcept>
-#include <cinttypes>
 
 //
 // llama_context
@@ -122,6 +123,11 @@ llama_context::llama_context(
             __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }
 
+    if (!params.swa_full && cparams.n_seq_max > 1) {
+        LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
+                __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
+    }
+
     if (!hparams.vocab_only) {
         // GPU backends
         for (auto * dev : model.devices) {
@@ -259,15 +265,9 @@ llama_context::llama_context(
 
     // reserve worst-case graph
     if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs =
+        const uint32_t n_seqs = cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-
-        // restore later
-        // TODO: something cleaner
-        const auto n_outputs_save = n_outputs;
-
         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
         int n_splits_pp = -1;
@@ -279,23 +279,17 @@ llama_context::llama_context(
         // simulate full KV cache
         llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-        kv_self->
+        const auto kv_state = kv_self->init_full();
+        if (!kv_state) {
+            throw std::runtime_error("failed to initialize KV cache");
+        }
 
         cross.v_embd.clear();
 
         // reserve pp graph first so that buffers are only allocated once
         {
-
-
-            // max number of outputs
-            n_outputs = ubatch_pp.n_tokens;
-
-            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
-
-            auto * gf = graph_init();
-            graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
-
-            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+            if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
 
@@ -305,16 +299,8 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-
-
-            n_outputs = ubatch_tg.n_tokens;
-
-            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
-
-            auto * gf = graph_init();
-            graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
-
-            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+            auto * gf = graph_reserve(1, 1, 1, kv_state.get());
+            if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
 
@@ -324,22 +310,12 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
        {
-
-
-            n_outputs = ubatch_pp.n_tokens;
-
-            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
-
-            auto * gf = graph_init();
-            graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
-
-            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+            if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
         }
 
-        n_outputs = n_outputs_save;
-
         for (size_t i = 0; i < backend_ptrs.size(); ++i) {
             ggml_backend_t backend = backend_ptrs[i];
             ggml_backend_buffer_type_t buft = backend_buft[i];
@@ -453,36 +429,33 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-
-
+bool llama_context::kv_self_update() {
+    if (!memory) {
+        return false;
+    }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-
-
-
-
-    LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
-
-    // build worst-case graph
-    uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-    uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-    // simulate full KV cache
-    kv_self->set_full();
+    if (!kv_self->update(*this)) {
+        // no updates have been performed
+        return false;
+    }
 
-
-
+    // if the KV cache did any computation, we have to reserve a new worst-case graph
+    const auto kv_state = kv_self->init_full();
+    if (!kv_state) {
+        throw std::runtime_error("failed to initialize KV cache");
+    }
 
-
-
+    const uint32_t n_seqs   = cparams.n_seq_max;
+    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-
-
-
-        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-    }
+    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
     }
+
+    return true;
 }
 
 enum llama_pooling_type llama_context::pooling_type() const {
@@ -676,6 +649,49 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
+llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
+    if (mstate && !mstate->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
+        ret = GGML_STATUS_FAILED;
+        return nullptr;
+    }
+
+    auto * gf = graph_init();
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+        ret = GGML_STATUS_FAILED;
+        return nullptr;
+    }
+
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
+    if (!res) {
+        LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
+        ret = GGML_STATUS_FAILED;
+        return nullptr;
+    }
+
+    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+
+    if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+        ret = GGML_STATUS_ALLOC_FAILED;
+        return nullptr;
+    }
+
+    res->set_inputs(&ubatch);
+
+    const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
+        ret = status;
+        return nullptr;
+    }
+
+    ret = GGML_STATUS_SUCCESS;
+
+    return res;
+}
+
 int llama_context::encode(llama_batch & inp_batch) {
     if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -737,8 +753,6 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     n_outputs = n_tokens;
 
-    //batch_manager->prepare(ubatch);
-
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
@@ -749,26 +763,18 @@ int llama_context::encode(llama_batch & inp_batch) {
     // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
     cparams.causal_attn = false;
 
-
-    auto res =
-
-    ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-    res->set_inputs(&ubatch);
+    ggml_status status;
+    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
 
     cparams.causal_attn = causal_attn_org;
 
-
-
-
-
-
-
-
-        return -2;
-    case GGML_STATUS_FAILED:
-    default:
-        return -3;
+    if (!res) {
+        switch (status) {
+            case GGML_STATUS_ABORTED:      return  2;
+            case GGML_STATUS_ALLOC_FAILED: return -2;
+            case GGML_STATUS_FAILED:       return -3;
+            case GGML_STATUS_SUCCESS:      GGML_ABORT("should not happen");
+        }
     }
 
     auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
@@ -889,8 +895,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd       = hparams.n_embd;
 
-    llama_kv_cache_guard kv_guard(kv_self);
-
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
     // TODO: move the validation to the llama_batch_allocr
@@ -936,7 +940,48 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }
 
-
+    // handle any pending defrags/shifts
+    kv_self_update();
+
+    llama_memory_state_ptr kv_state;
+
+    bool did_defrag = false;
+
+    while (true) {
+        kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!kv_state) {
+            return -2;
+        }
+
+        switch (kv_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                } break;
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+                {
+                    if (!did_defrag) {
+                        did_defrag = true;
+
+                        kv_self->defrag_sched(-1.0f);
+                        if (kv_self_update()) {
+                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+
+                            continue;
+                        }
+                    }
+
+                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+                    return 1;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    return -2;
+                }
+        }
+
+        break;
+    }
 
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -944,13 +989,10 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -2;
     };
 
-    // handle any pending defrags/shifts
-    kv_self_update();
-
     int64_t n_outputs_prev = 0;
 
-
-
+    do {
+        const auto & ubatch = kv_state->get_ubatch();
 
         // count the outputs in this u_batch
         {
@@ -969,33 +1011,37 @@ int llama_context::decode(llama_batch & inp_batch) {
             n_outputs = n_outputs_new;
         }
 
-        // find KV slot
-        if (!kv_self->find_slot(ubatch)) {
-            return 1;
-        }
-
         ggml_backend_sched_reset(sched.get());
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-
-        auto res =
-
-
-
-
-
-
-
-
-
-
-
-
-        case
-
+        ggml_status status;
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
+
+        if (!res) {
+            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+            llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
+
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                const auto & seq_id = ubatch.seq_id[i][0];
+
+                pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
+            }
+
+            for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+                if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
+                    continue;
+                }
+
+                LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
+
+                llama_kv_self_seq_rm(this, s, pos_min[s], -1);
+            }
+
+            switch (status) {
+                case GGML_STATUS_ABORTED:      return 2;
+                case GGML_STATUS_ALLOC_FAILED: return -2;
+                case GGML_STATUS_FAILED:       return -3;
+                case GGML_STATUS_SUCCESS:      GGML_ABORT("should not happen");
+            }
         }
     }
 
@@ -1082,10 +1128,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
 
         n_outputs_prev += n_outputs;
-    }
-
-    // finalize the batch processing
-    kv_guard.commit();
+    } while (kv_state->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
@@ -1094,7 +1137,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     {
         bool sorted_output = true;
 
-        auto & out_ids =
+        auto & out_ids = kv_state->out_ids();
 
         GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
@@ -1254,11 +1297,52 @@ ggml_cgraph * llama_context::graph_init() {
    return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
 }
 
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
+    LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
+
+    if (n_tokens % n_seqs != 0) {
+        n_tokens = (n_tokens / n_seqs) * n_seqs;
+        n_outputs = std::min(n_outputs, n_tokens);
+
+        LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
+    }
+
+    // store the n_outputs as it is, and restore it afterwards
+    // TODO: not sure if needed, might simplify in the future by removing this
+    const auto save_n_outputs = this->n_outputs;
+
+    this->n_outputs = n_outputs;
+
+    llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+    llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+    auto * gf = graph_init();
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
+
+    this->n_outputs = save_n_outputs;
+
+    if (!res) {
+        LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
+        return nullptr;
+    }
+
+    ggml_backend_sched_reset(sched.get());
+
+    // initialize scheduler with the specified graph
+    if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+        return nullptr;
+    }
+
+    return gf;
+}
+
 llm_graph_result_ptr llama_context::graph_build(
-
-
-
-
+                  ggml_context * ctx,
+                   ggml_cgraph * gf,
+            const llama_ubatch & ubatch,
+                llm_graph_type   gtype,
+    const llama_memory_state_i * mstate) {
     return model.build_graph(
         {
             /*.ctx         =*/ ctx,
@@ -1270,7 +1354,7 @@ llm_graph_result_ptr llama_context::graph_build(
             /*.backend_cpu =*/ backend_cpu,
            /*.cvec        =*/ &cvec,
             /*.loras       =*/ &loras,
-            /*.
+            /*.mstate      =*/ mstate,
             /*.cross       =*/ &cross,
             /*.n_outputs   =*/ n_outputs,
             /*.cb          =*/ graph_get_cb(),
@@ -1951,7 +2035,6 @@ void llama_context::opt_epoch_iter(
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
     kv_self->clear();
-    llama_kv_cache_guard kv_guard(kv_self);
 
     for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
         batch.n_tokens = n_batch;
@@ -1974,7 +2057,11 @@ void llama_context::opt_epoch_iter(
 
         int64_t n_outputs_all = n_tokens_all;
 
-
+        auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+            LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
+            break;
+        }
 
         // reserve output buffer
         if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -1982,20 +2069,19 @@ void llama_context::opt_epoch_iter(
             GGML_ABORT("TODO: handle this error");
         };
 
-
-
+        uint32_t pos_batch = 0;
+        do {
+            const auto & ubatch = kv_state->get_ubatch();
 
            n_outputs = ubatch.n_tokens;
 
-
-
-
-
-                GGML_ABORT("TODO: handle this error");
+            if (!kv_state->apply()) {
+                LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
+                break;
             }
 
             auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
 
             struct ggml_context * ctx_compute_opt;
             {
@@ -2010,6 +2096,7 @@ void llama_context::opt_epoch_iter(
             }
             ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
             ggml_opt_alloc(opt_ctx, train);
+
             res->set_inputs(&ubatch);
             {
                 struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
@@ -2027,10 +2114,10 @@ void llama_context::opt_epoch_iter(
                 callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
             }
             ggml_free(ctx_compute_opt);
-        }
-    }
 
-
+            pos_batch += ubatch.n_tokens;
+        } while (kv_state->next());
+    }
 }
 
 void llama_context::opt_epoch(
@@ -2194,6 +2281,7 @@ llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
    return ctx->get_kv_self();
 }
 
+// deprecated
 void llama_kv_self_update(llama_context * ctx) {
     ctx->kv_self_update();
 }
@@ -2448,6 +2536,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
     return kv->seq_pos_max(seq_id);
 }
 
+// deprecated
 void llama_kv_self_defrag(llama_context * ctx) {
     auto * kv = ctx->get_kv_self();
     if (!kv) {
@@ -2589,22 +2678,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
           llama_batch   batch) {
-    int ret = ctx->decode(batch);
-
-    // defrag and try again
-    // TODO: distinguish return code when we are sure that even after defrag there is no space available
-    if (ret == 1) {
-        llama_kv_self_defrag(ctx);
-        ret = ctx->decode(batch);
-
-        if (ret == 1) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
-            return ret;
-        }
-    }
-
-    if (ret != 0) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
 
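The decode path above replaces llama_kv_cache_guard and the per-ubatch find_slot() call with a memory-state object: init_batch() prepares the whole batch, a FAILED_PREPARE status triggers one defrag_sched()/kv_self_update() retry, and the ubatches are then walked with get_ubatch()/apply()/next(). The standalone sketch below mirrors that control flow with stand-in types (mem_state, kv_cache, decode are illustrative, not the llama.cpp API), and it simplifies one detail: the real code only retries when kv_self_update() reports that the cache actually changed.

    // Simplified control-flow sketch of the new decode loop.
    #include <cstdio>
    #include <memory>

    enum class mem_status { success, failed_prepare, failed_compute };

    struct mem_state {                      // stands in for llama_memory_state_i
        mem_status status = mem_status::success;
        int n_ubatches = 2, cur = 0;
        mem_status get_status() const { return status; }
        bool apply() { return true; }       // bind the current ubatch to the cache
        bool next()  { return ++cur < n_ubatches; }
    };

    struct kv_cache {                       // stands in for llama_kv_cache
        bool fragmented = true;
        std::unique_ptr<mem_state> init_batch() {
            auto st = std::make_unique<mem_state>();
            st->status = fragmented ? mem_status::failed_prepare : mem_status::success;
            return st;
        }
        void defrag_sched() { fragmented = false; }
    };

    int decode(kv_cache & kv) {
        std::unique_ptr<mem_state> st;
        bool did_defrag = false;

        while (true) {
            st = kv.init_batch();
            if (st->get_status() == mem_status::success) break;
            if (st->get_status() == mem_status::failed_prepare && !did_defrag) {
                did_defrag = true;
                kv.defrag_sched();          // retry once after a scheduled defrag
                continue;
            }
            return st->get_status() == mem_status::failed_prepare ? 1 : -2;
        }

        do {
            if (!st->apply()) return -3;
            // ... build and compute the graph for the current ubatch here ...
        } while (st->next());

        return 0;
    }

    int main() {
        kv_cache kv;
        std::printf("decode -> %d\n", decode(kv));
    }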
examples/talk-llama/llama-context.h
CHANGED
@@ -18,6 +18,9 @@ struct llama_kv_cache;
 class llama_io_read_i;
 class llama_io_write_i;
 
+class llama_memory_i;
+class llama_memory_state_i;
+
 struct llama_context {
     // init scheduler and compute buffers, reserve worst-case graphs
     llama_context(
@@ -47,7 +50,9 @@ struct llama_context {
           llama_kv_cache * get_kv_self();
     const llama_kv_cache * get_kv_self() const;
 
-
+    // return true of the KV cache was updated
+    // TODO: remove
+    bool kv_self_update();
 
     enum llama_pooling_type pooling_type() const;
 
@@ -88,6 +93,16 @@ struct llama_context {
             int32_t il_start,
             int32_t il_end);
 
+    // process a single ubatch with a specific graph type
+    // if memory_state is provided, it will be applied first to the context's memory
+    // ret contains the status of the graph computation
+    // returns nullptr only if ret != GGML_STATUS_SUCCESS
+    llm_graph_result_ptr process_ubatch(
+            const llama_ubatch &   ubatch,
+            llm_graph_type         gtype,
+            llama_memory_state_i * mstate,
+            ggml_status &          ret);
+
     int encode(llama_batch & inp_batch);
     int decode(llama_batch & inp_batch);
 
@@ -180,16 +195,18 @@ public:
     ggml_cgraph * graph_init();
 
     // returns the result of ggml_backend_sched_graph_compute_async execution
-    ggml_status graph_compute(
-
-
+    ggml_status graph_compute(ggml_cgraph * gf, bool batched);
+
+    // reserve a graph with a dummy ubatch of the specified size
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
 
 private:
     llm_graph_result_ptr graph_build(
-
-
-
-
+                      ggml_context * ctx,
+                       ggml_cgraph * gf,
+                const llama_ubatch & ubatch,
+                    llm_graph_type   gtype,
+        const llama_memory_state_i * mstate);
 
     llm_graph_cb graph_get_cb() const;
 
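The process_ubatch() declaration above fixes a simple contract: the returned llm_graph_result_ptr is nullptr exactly when the ggml_status out-parameter is not GGML_STATUS_SUCCESS, and callers such as encode()/decode() map that status to their integer return codes (2, -2, -3). A minimal standalone sketch of that contract follows, using stand-in types rather than the real llama.cpp ones.

    // Sketch of the "nullptr iff status != success" contract and its mapping
    // to caller return codes, with illustrative stand-in types.
    #include <cstdio>
    #include <memory>

    enum class status { success, aborted, alloc_failed, failed };

    struct result {};                                  // stands in for llm_graph_result

    std::unique_ptr<result> process_ubatch(bool fail_alloc, status & ret) {
        if (fail_alloc) {
            ret = status::alloc_failed;
            return nullptr;                            // nullptr <=> ret != success
        }
        ret = status::success;
        return std::make_unique<result>();
    }

    int decode_step(bool fail_alloc) {
        status st;
        auto res = process_ubatch(fail_alloc, st);
        if (!res) {
            switch (st) {
                case status::aborted:      return  2;
                case status::alloc_failed: return -2;
                case status::failed:       return -3;
                case status::success:      return  0;  // unreachable by contract
            }
        }
        return 0;
    }

    int main() {
        std::printf("%d %d\n", decode_step(false), decode_step(true));
    }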
examples/talk-llama/llama-graph.cpp
CHANGED
@@ -3,7 +3,10 @@
 #include "llama-impl.h"
 #include "llama-batch.h"
 #include "llama-cparams.h"
-
+
+#include "llama-kv-cache-unified.h"
+#include "llama-kv-cache-unified-iswa.h"
+#include "llama-kv-cache-recurrent.h"
 
 #include <cassert>
 #include <cmath>
@@ -83,7 +86,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
-
+        kv_state->set_input_pos_bucket(pos_bucket, ubatch);
     }
 }
 
@@ -234,7 +237,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
-    const int64_t n_kv =
+    const int64_t n_kv = kv_state->get_n_kv();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@@ -242,7 +245,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_kv; ++i) {
-            data[i] =
+            data[i] = kv_state->s_copy(i);
         }
     }
 }
@@ -250,7 +253,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
-    const int64_t n_kv =
+    const int64_t n_kv = kv_state->get_n_kv();
 
     if (s_mask) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
@@ -258,7 +261,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
 
         // clear unused states
        for (int i = 0; i < n_kv; ++i) {
-            data[i] =
+            data[i] = kv_state->s_mask(i);
         }
     }
 }
@@ -362,17 +365,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-
+        kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-
+        kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
     if (self_kq_mask_swa) {
-
+        kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
 
@@ -448,14 +451,14 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     backend_cpu (params.backend_cpu),
     cvec        (params.cvec),
     loras       (params.loras),
-
+    mstate      (params.mstate),
     cross       (params.cross),
     cb_func     (params.cb),
     res         (std::make_unique<llm_graph_result>()) {
 }
 
 int64_t llm_graph_context::n_pos_per_embd() const {
-    return
+    return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
 }
 
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
@@ -954,11 +957,11 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_s_copy>(
+    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
 
-    const auto n_kv =
+    const auto n_kv = kv_state->get_n_kv();
 
     auto & cur = inp->s_copy;
 
@@ -971,11 +974,11 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_s_mask>(
+    auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
 
-    const auto n_kv =
+    const auto n_kv = kv_state->get_n_kv();
 
     auto & cur = inp->s_mask;
 
@@ -1025,11 +1028,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams,
+    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
 
-    const auto n_kv =
+    const auto n_kv = kv_state->get_n_kv();
 
     auto & cur = inp->pos_bucket;
 
@@ -1231,14 +1234,14 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams,
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv =
+        const auto n_kv = kv_state->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1268,19 +1271,19 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf,
-        ggml_build_forward_expand(gf,
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
    }
 
     const auto & kq_mask = inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k =
-    ggml_tensor * v =
+    ggml_tensor * k = kv_state->get_k(ctx0, il);
+    ggml_tensor * v = kv_state->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ -1301,12 +1304,12 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams,
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
 
     {
-        const auto n_kv =
+        const auto n_kv = kv_state->get_base()->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1318,7 +1321,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
 
-        const auto n_kv =
+        const auto n_kv = kv_state->get_swa()->get_n_kv();
 
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@@ -1348,23 +1351,23 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const
+    const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
 
-    const
+    const bool is_swa = hparams.is_swa(il);
 
-    const auto *
+    const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf,
-        ggml_build_forward_expand(gf,
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k =
-    ggml_tensor * v =
+    ggml_tensor * k = kv_state->get_k(ctx0, il);
+    ggml_tensor * v = kv_state->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ -1446,12 +1449,12 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
         ggml_tensor * state_mask,
              int32_t   n_state,
             int32_t   n_seqs) const {
-    const
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
-    const auto n_kv =
-    const auto kv_head =
+    const auto n_kv    = kv_state->get_n_kv();
+    const auto kv_head = kv_state->get_head();
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state,
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
 
     // copy states
     // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
@@ -1478,13 +1481,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         ggml_tensor * state_mask,
   const llama_ubatch & ubatch,
                  int   il) const {
-    const
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
     const auto token_shift_count = hparams.token_shift_count;
 
     const int64_t n_seqs = ubatch.n_seqs;
 
-    ggml_tensor * token_shift_all =
+    ggml_tensor * token_shift_all = kv_state->get_k_l(il);
 
     ggml_tensor * token_shift = build_copy_mask_state(
             gf, token_shift_all, state_copy, state_mask,
@@ -1499,19 +1502,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
   const llama_ubatch & ubatch,
                  int   il) const {
-    const
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
 
     const int64_t n_seqs = ubatch.n_seqs;
 
-    const auto kv_head =
+    const auto kv_head = kv_state->get_head();
 
     return ggml_cpy(
         ctx0,
         ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
-        ggml_view_1d(ctx0,
+        ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
     );
 }
 
@@ -1562,20 +1565,25 @@ void llm_graph_context::build_pooling(
                 ggml_tensor * inp_cls = build_inp_cls();
                 inp = ggml_get_rows(ctx0, inp, inp_cls);
 
+                if (cls != nullptr && cls_b != nullptr) {
+                    // classification head
+                    // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
| 1571 |
+
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
|
| 1572 |
+
cur = ggml_tanh(ctx0, cur);
|
| 1573 |
+
|
| 1574 |
+
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
| 1575 |
+
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
|
| 1576 |
+
if (cls_out) {
|
| 1577 |
+
GGML_ASSERT(cls_out_b != nullptr);
|
| 1578 |
+
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
|
| 1579 |
+
}
|
| 1580 |
+
} else if (cls_out) {
|
| 1581 |
+
// Single layer classification head (direct projection)
|
| 1582 |
+
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
|
| 1583 |
GGML_ASSERT(cls_out_b != nullptr);
|
| 1584 |
+
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
|
| 1585 |
+
} else {
|
| 1586 |
+
GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
|
| 1587 |
}
|
| 1588 |
} break;
|
| 1589 |
default:
|
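
Note on the RANK-pooling hunk above: the graph either runs a two-stage classification head (cls/cls_b, then tanh, then an optional cls_out/cls_out_b projection), falls back to a single direct cls_out/cls_out_b projection, or aborts when neither weight set is present. The standalone sketch below reproduces that branch structure with plain std::vector arithmetic so the selection logic is easier to follow; the dense()/rank_head() helpers and all weight values are illustrative assumptions, not llama.cpp code.

    // Minimal, self-contained sketch (illustrative names, not llama.cpp API).
    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <stdexcept>
    #include <vector>

    using vec = std::vector<float>;
    using mat = std::vector<vec>;   // row-major: mat[row][col], rows = outputs

    // y = W*x + b
    static vec dense(const mat & W, const vec & b, const vec & x) {
        vec y(W.size());
        for (std::size_t r = 0; r < W.size(); ++r) {
            float acc = b[r];
            for (std::size_t c = 0; c < x.size(); ++c) {
                acc += W[r][c]*x[c];
            }
            y[r] = acc;
        }
        return y;
    }

    // same branch structure as the RANK pooling case:
    //  - cls/cls_b present      -> dense + tanh, then an optional cls_out projection
    //  - only cls_out/cls_out_b -> single direct projection
    //  - neither                -> error (the graph code aborts)
    static vec rank_head(const vec & pooled,
                         const mat * cls,     const vec * cls_b,
                         const mat * cls_out, const vec * cls_out_b) {
        if (cls && cls_b) {
            vec cur = dense(*cls, *cls_b, pooled);
            for (float & v : cur) {
                v = std::tanh(v);
            }
            if (cls_out && cls_out_b) {
                cur = dense(*cls_out, *cls_out_b, cur);
            }
            return cur;
        }
        if (cls_out && cls_out_b) {
            return dense(*cls_out, *cls_out_b, pooled);
        }
        throw std::runtime_error("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
    }

    int main() {
        const vec pooled    = {0.5f, -1.0f};                 // toy pooled embedding
        const mat cls       = {{1.0f, 0.0f}, {0.0f, 1.0f}};
        const vec cls_b     = {0.0f, 0.0f};
        const mat cls_out   = {{0.3f, -0.7f}};               // one output, reranker-style
        const vec cls_out_b = {0.1f};

        const vec score = rank_head(pooled, &cls, &cls_b, &cls_out, &cls_out_b);
        std::printf("score = %f\n", score[0]);
        return 0;
    }
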
examples/talk-llama/llama-graph.h
CHANGED
@@ -17,10 +17,11 @@ struct ggml_tensor;
 struct llama_ubatch;
 struct llama_cparams;

-class
-
-class
-class
+class llama_memory_state_i;
+
+class llama_kv_cache_unified_state;
+class llama_kv_cache_unified_iswa_state;
+class llama_kv_cache_recurrent_state;

 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {

@@ -133,7 +134,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const
+            const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;

     void set_input(const llama_ubatch * ubatch) override;

@@ -141,7 +142,7 @@ public:
     ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]

     const llama_hparams & hparams;
-    const
+    const llama_kv_cache_unified_state * kv_state;
 };

 class llm_graph_input_out_ids : public llm_graph_input_i {

@@ -188,26 +189,26 @@ public:

 class llm_graph_input_s_copy : public llm_graph_input_i {
 public:
-    llm_graph_input_s_copy(const
+    llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
     virtual ~llm_graph_input_s_copy() = default;

     void set_input(const llama_ubatch * ubatch) override;

     ggml_tensor * s_copy; // I32 [kv_size]

-    const
+    const llama_kv_cache_recurrent_state * kv_state;
 };

 class llm_graph_input_s_mask : public llm_graph_input_i {
 public:
-    llm_graph_input_s_mask(const
+    llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
     virtual ~llm_graph_input_s_mask() = default;

     void set_input(const llama_ubatch * ubatch) override;

     ggml_tensor * s_mask; // F32 [1, n_kv]

-    const
+    const llama_kv_cache_recurrent_state * kv_state;
 };

 class llm_graph_input_cross_embd : public llm_graph_input_i {

@@ -247,10 +248,10 @@ public:
     llm_graph_input_attn_kv_unified(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const
+            const llama_kv_cache_unified_state * kv_state) :
         hparams(hparams),
         cparams(cparams),
-
+        kv_state(kv_state) {
     }
     ~llm_graph_input_attn_kv_unified() = default;

@@ -264,7 +265,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;

-    const
+    const llama_kv_cache_unified_state * kv_state;
 };

 class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {

@@ -272,10 +273,10 @@ public:
     llm_graph_input_attn_kv_unified_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const
+            const llama_kv_cache_unified_iswa_state * kv_state) :
         hparams(hparams),
         cparams(cparams),
-
+        kv_state(kv_state) {
     }
     ~llm_graph_input_attn_kv_unified_iswa() = default;

@@ -292,7 +293,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;

-    const
+    const llama_kv_cache_unified_iswa_state * kv_state;
 };

 class llm_graph_input_attn_cross : public llm_graph_input_i {

@@ -383,10 +384,10 @@ struct llm_graph_params {
     ggml_backend_sched_t sched;
     ggml_backend_t backend_cpu;

-    const llama_adapter_cvec
-    const llama_adapter_loras
-    const
-    const llama_cross
+    const llama_adapter_cvec * cvec;
+    const llama_adapter_loras * loras;
+    const llama_memory_state_i * mstate;
+    const llama_cross * cross;

     int32_t n_outputs;

@@ -435,10 +436,10 @@ struct llm_graph_context {

     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?

-    const llama_adapter_cvec
-    const llama_adapter_loras
-    const
-    const llama_cross
+    const llama_adapter_cvec * cvec;
+    const llama_adapter_loras * loras;
+    const llama_memory_state_i * mstate;
+    const llama_cross * cross;

     const llm_graph_cb & cb_func;

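
Note on the header changes above: llm_graph_params and llm_graph_context now carry a type-erased const llama_memory_state_i * mstate, while the individual graph-input classes hold typed pointers to the forward-declared per-cache state types (llama_kv_cache_unified_state, llama_kv_cache_unified_iswa_state, llama_kv_cache_recurrent_state). The builders in llama-graph.cpp recover the concrete type with a static_cast, as seen in the .cpp hunks earlier. The snippet below is a minimal illustration of that pattern using made-up stand-in types; it is not the actual llama.cpp class hierarchy.

    // Minimal, self-contained illustration (stand-in types, not the llama.cpp classes).
    #include <cstdio>

    struct memory_state_i {                             // stands in for llama_memory_state_i
        virtual ~memory_state_i() = default;
    };

    struct kv_cache_unified_state : memory_state_i {    // stands in for llama_kv_cache_unified_state
        int get_n_kv() const { return 32; }
    };

    // a graph input keeps a typed pointer to the state it was built from,
    // like llm_graph_input_attn_kv_unified::kv_state in the header above
    struct input_attn_kv_unified {
        explicit input_attn_kv_unified(const kv_cache_unified_state * kv_state) : kv_state(kv_state) {}

        const kv_cache_unified_state * kv_state;
    };

    // the builder only sees the type-erased base pointer (mstate) and downcasts it;
    // this is valid only when the caller created the matching cache/state type
    static input_attn_kv_unified build_attn_inp_kv_unified(const memory_state_i * mstate) {
        const auto * kv_state = static_cast<const kv_cache_unified_state *>(mstate);
        return input_attn_kv_unified(kv_state);
    }

    int main() {
        kv_cache_unified_state state;
        const input_attn_kv_unified inp = build_attn_inp_kv_unified(&state);
        std::printf("n_kv = %d\n", inp.kv_state->get_n_kv());
        return 0;
    }
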
examples/talk-llama/llama-hparams.h
CHANGED
@@ -131,6 +131,9 @@ struct llama_hparams {
     bool attn_soft_cap = false;
     bool use_kq_norm = true;

+    // for Classifiers
+    uint32_t n_cls_out = 1;
+
     // llama4
     uint32_t n_moe_layer_step = 0;
     uint32_t n_no_rope_layer_step = 4;
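
Note on the new hparam above: n_cls_out defaults to 1 and appears to record how many outputs the model's classification head produces. A hedged sketch of how a caller might interpret such a score vector follows; the argmax() helper and the sample values are hypothetical, not part of llama.cpp.

    // Hypothetical post-processing of n_cls_out scores (not a llama.cpp API).
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    static std::size_t argmax(const std::vector<float> & scores) {
        std::size_t best = 0;
        for (std::size_t i = 1; i < scores.size(); ++i) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    int main() {
        // n_cls_out == 1 (the default): a single score, e.g. a reranker relevance value
        const std::vector<float> single = {0.73f};
        std::printf("score = %f\n", single[0]);

        // n_cls_out > 1: one score per label; reduce with argmax to pick the label index
        const std::vector<float> multi = {0.1f, 2.3f, -0.4f};
        std::printf("label = %zu\n", argmax(multi));
        return 0;
    }
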
examples/talk-llama/llama-kv-cache-recurrent.cpp
ADDED
@@ -0,0 +1,1132 @@
#include "llama-kv-cache-recurrent.h"

#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-model.h"

#include <algorithm>
#include <cassert>
#include <limits>
#include <map>
#include <stdexcept>

//
// llama_kv_cache_recurrent
//

llama_kv_cache_recurrent::llama_kv_cache_recurrent(
        const llama_model & model,
        ggml_type type_k,
        ggml_type type_v,
        bool offload,
        uint32_t kv_size,
        uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
    const int32_t n_layer = hparams.n_layer;

    LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
            __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);

    head = 0;
    size = kv_size;
    used = 0;

    cells.clear();
    cells.resize(kv_size);

    // create a context for each buffer type
    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
        auto it = ctx_map.find(buft);
        if (it == ctx_map.end()) {
            ggml_init_params params = {
                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
                /*.mem_buffer =*/ NULL,
                /*.no_alloc   =*/ true,
            };

            ggml_context * ctx = ggml_init(params);
            if (!ctx) {
                return nullptr;
            }

            ctx_map[buft] = ctx;
            ctxs.emplace_back(ctx);

            return ctx;
        }

        return it->second;
    };

    k_l.reserve(n_layer);
    v_l.reserve(n_layer);

    for (int i = 0; i < n_layer; i++) {
        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();

        const char * dev_name = "CPU";

        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();

        if (offload) {
            auto * dev = model.dev_layer(i);
            buft = ggml_backend_dev_buffer_type(dev);

            dev_name = ggml_backend_dev_name(dev);
        }

        LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name);

        ggml_context * ctx = ctx_for_buft(buft);
        if (!ctx) {
            throw std::runtime_error("failed to create ggml context for kv cache");
        }

        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
        ggml_format_name(k, "cache_k_l%d", i);
        ggml_format_name(v, "cache_v_l%d", i);
        k_l.push_back(k);
        v_l.push_back(v);
    }

    // allocate tensors and initialize the buffers to avoid NaNs in the padding
    for (auto it : ctx_map) {
        auto * buft = it.first;
        auto * ctx  = it.second;

        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
        if (!buf) {
            throw std::runtime_error("failed to allocate buffer for kv cache");
        }
        ggml_backend_buffer_clear(buf, 0);
        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
        bufs.emplace_back(buf);
    }

    {
        const size_t memory_size_k = size_k_bytes();
        const size_t memory_size_v = size_v_bytes();

        LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
    }
}

void llama_kv_cache_recurrent::clear() {
    for (int32_t i = 0; i < (int32_t) size; ++i) {
        cells[i].pos = -1;
        cells[i].seq_id.clear();
        cells[i].src = -1;
        cells[i].tail = -1;
    }
    head = 0;
    used = 0;

    for (auto & buf : bufs) {
        ggml_backend_buffer_clear(buf.get(), 0);
    }
}

bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    uint32_t new_head = size;

    if (p0 < 0) {
        p0 = 0;
    }

    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    // models like Mamba or RWKV can't have a state partially erased
    if (seq_id >= (int64_t) size) {
        // could be fatal
        return false;
    }
    if (0 <= seq_id) {
        int32_t & tail_id = cells[seq_id].tail;
        if (tail_id >= 0) {
            const kv_cell & cell = cells[tail_id];
            // partial intersection is invalid
            if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
                return false;
            }
            // invalidate tails which will be cleared
            if (p0 <= cell.pos && cell.pos < p1) {
                tail_id = -1;
            }
        }
    } else {
        // seq_id is negative, then the range should include everything or nothing
        if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
            return false;
        }
    }

    for (uint32_t i = 0; i < size; ++i) {
        if (cells[i].pos >= p0 && cells[i].pos < p1) {
            if (seq_id < 0) {
                cells[i].seq_id.clear();
            } else if (cells[i].has_seq_id(seq_id)) {
                cells[i].seq_id.erase(seq_id);
            } else {
                continue;
            }
            if (cells[i].is_empty()) {
                // keep count of the number of used cells
                if (cells[i].pos >= 0) {
                    used--;
                }
                cells[i].pos = -1;
                cells[i].src = -1;
                if (new_head == size) {
                    new_head = i;
                }
            }
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    if (new_head != size && new_head < head) {
        head = new_head;
    }

    return true;
}

void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    if (seq_id_src == seq_id_dst) {
        return;
    }

    if (p0 < 0) {
        p0 = 0;
    }

    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
        kv_cell & tail_src = cells[seq_id_src];
        kv_cell & tail_dst = cells[seq_id_dst];
        if (tail_dst.tail >= 0) {
            // clear destination seq_id if it wasn't empty
            kv_cell & cell_dst = cells[tail_dst.tail];

            cell_dst.seq_id.erase(seq_id_dst);
            tail_dst.tail = -1;
            if (cell_dst.seq_id.empty()) {
                cell_dst.pos = -1;
                cell_dst.src = -1;
                used -= 1;
            }
        }
        if (tail_src.tail >= 0) {
            kv_cell & cell_src = cells[tail_src.tail];

            cell_src.seq_id.insert(seq_id_dst);
            tail_dst.tail = tail_src.tail;
        }
    }
}

void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
    uint32_t new_head = size;

    for (uint32_t i = 0; i < size; ++i) {
        if ((llama_seq_id) i != seq_id) {
            cells[i].tail = -1;
        }

        if (!cells[i].has_seq_id(seq_id)) {
            if (cells[i].pos >= 0) {
                used--;
            }

            cells[i].pos = -1;
            cells[i].src = -1;
            cells[i].seq_id.clear();

            if (new_head == size){
                new_head = i;
            }
        } else {
            cells[i].seq_id.clear();
            cells[i].seq_id.insert(seq_id);
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    if (new_head != size && new_head < head) {
        head = new_head;
    }
}

void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    if (shift == 0) {
        return;
    }

    if (p0 < 0) {
        p0 = 0;
    }

    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    // If there is no range then return early to avoid looping over the
    if (p0 == p1) {
        return;
    }

    // for Mamba-like or RWKV models, only the pos needs to be shifted
    if (0 <= seq_id && seq_id < (int64_t) size) {
        const int32_t tail_id = cells[seq_id].tail;
        if (tail_id >= 0) {
            kv_cell & cell = cells[tail_id];
            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
                cell.pos += shift;
            }
        }
    }
}

void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    if (d == 1) {
        return;
    }

    if (p0 < 0) {
        p0 = 0;
    }

    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    // If there is no range then return early to avoid looping over the cache.
    if (p0 == p1) {
        return;
    }

    // for Mamba-like or RWKV models, only the pos needs to be changed
    if (0 <= seq_id && seq_id < (int64_t) size) {
        const int32_t tail_id = cells[seq_id].tail;
        if (tail_id >= 0) {
            kv_cell & cell = cells[tail_id];
            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
                cell.pos /= d;
            }
        }
    }
}

llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
    llama_pos result = std::numeric_limits<llama_pos>::max();

    for (uint32_t i = 0; i < size; ++i) {
        if (cells[i].has_seq_id(seq_id)) {
            result = std::min(result, cells[i].pos);
        }
    }

    if (result == std::numeric_limits<llama_pos>::max()) {
        result = -1;
    }

    return result;
}

llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
    llama_pos result = -1;

    for (uint32_t i = 0; i < size; ++i) {
        if (cells[i].has_seq_id(seq_id)) {
            result = std::max(result, cells[i].pos);
        }
    }

    return result;
}

llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
    GGML_UNUSED(embd_pooled);

    auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);

    std::vector<llama_ubatch> ubatches;

    while (sbatch.n_tokens > 0) {
        llama_ubatch ubatch;

        if (embd_pooled) {
            // Pooled embeddings cannot be split across ubatches (yet)
            ubatch = sbatch.split_seq(n_ubatch);
        } else {
            ubatch = sbatch.split_equal(n_ubatch);
        }

        ubatches.push_back(ubatch);
    }

    if (!prepare(ubatches)) {
        return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
    }

    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches));
}

llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
}

bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
    // simply remember the full state because it is very small for this type of cache
    // TODO: optimize
    auto org_cells = cells;
    auto org_used = used;
    auto org_head = head;

    bool success = true;

    // TODO: here we have to verify that all ubatches can fit in the cells
    // however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
    // during the compute of each ubatch. to reproduce, uncomment the following loop and run:
    //
    //   $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
    //
    // recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
    //
    GGML_UNUSED(ubatches);
    //for (const auto & ubatch : ubatches) {
    //    if (!find_slot(ubatch)) {
    //        success = false;
    //        break;
    //    }
    //}

    // restore the original state
    cells = std::move(org_cells);
    used = org_used;
    head = org_head;

    return success;
}

bool llama_kv_cache_recurrent::update(llama_context & lctx) {
    GGML_UNUSED(lctx);
    // noop
    return false;
}

void llama_kv_cache_recurrent::defrag_sched(float thold) {
    GGML_UNUSED(thold);
    // noop
}

bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
    const uint32_t n_tokens = ubatch.n_tokens;
    const uint32_t n_seqs = ubatch.n_seqs;

    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;

    // if we have enough unused cells before the current head ->
    //   better to start searching from the beginning of the cache, hoping to fill it
    if (head > used + 2*n_tokens) {
        head = 0;
    }

    // For recurrent state architectures (like Mamba or RWKV),
    // each cache cell can store the state for a whole sequence.
    // A slot should be always be contiguous.

    // can only process batches with an equal number of new tokens in each sequence
    GGML_ASSERT(ubatch.equal_seqs);

    int32_t min = size - 1;
    int32_t max = 0;

    // everything should fit if all seq_ids are smaller than the max
    for (uint32_t s = 0; s < n_seqs; ++s) {
        const uint32_t n_seq_id = ubatch.n_seq_id[s];
        for (uint32_t j = 0; j < n_seq_id; ++j) {
            const llama_seq_id seq_id = ubatch.seq_id[s][j];

            if (seq_id < 0 || (uint32_t) seq_id >= size) {
                // too big seq_id
                // TODO: would it be possible to resize the cache instead?
                LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
                return false;
            }
            if (j > 0) {
                kv_cell & seq = cells[seq_id];
                if (seq.tail >= 0) {
                    kv_cell & cell = cells[seq.tail];
                    // clear cells from seq_ids that become shared
                    // (should not normally happen, but let's handle it anyway)
                    cell.seq_id.erase(seq_id);
                    seq.tail = -1;
                    if (cell.seq_id.empty()) {
                        cell.pos = -1;
                        cell.src = -1;
                        used -= 1;
                    }
                }
            }
        }
    }

#ifndef NDEBUG
    {
        std::vector<int32_t> tails_verif;
        tails_verif.assign(size, -1);
        for (uint32_t i = 0; i < size; ++i) {
            kv_cell & cell = cells[i];
            for (llama_seq_id seq_id : cell.seq_id) {
                if (tails_verif[seq_id] != -1) {
                    LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
                }
                tails_verif[seq_id] = i;
            }
        }
        for (uint32_t i = 0; i < size; ++i) {
            if (tails_verif[i] != cells[i].tail) {
                LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
            }
        }
    }
#endif

    // find next empty cell
    uint32_t next_empty_cell = head;

    for (uint32_t i = 0; i < size; ++i) {
        if (next_empty_cell >= size) { next_empty_cell -= size; }
        kv_cell & cell = cells[next_empty_cell];
        if (cell.is_empty()) { break; }
        next_empty_cell += 1;
    }

    // find usable cell range
    for (uint32_t s = 0; s < n_seqs; ++s) {
        const llama_seq_id seq_id = ubatch.seq_id[s][0];
        kv_cell & seq_meta = cells[seq_id];
        bool has_cell = false;
        if (seq_meta.tail >= 0) {
            kv_cell & cell = cells[seq_meta.tail];
            GGML_ASSERT(cell.has_seq_id(seq_id));
            // does this seq_id "own" the cell?
            if (cell.seq_id.size() == 1) { has_cell = true; }
        }
        if (!has_cell) {
            kv_cell & empty_cell = cells[next_empty_cell];
            GGML_ASSERT(empty_cell.is_empty());
            // copy old tail into the empty cell
            if (seq_meta.tail >= 0) {
                kv_cell & orig_cell = cells[seq_meta.tail];
                empty_cell.pos = orig_cell.pos;
                empty_cell.src = orig_cell.src;
                orig_cell.seq_id.erase(seq_id);
                empty_cell.seq_id.insert(seq_id); // will be overwritten
            }
            seq_meta.tail = next_empty_cell;
            // find next empty cell
            if (s + 1 < n_seqs) {
                next_empty_cell += 1;
                for (uint32_t i = 0; i < size; ++i) {
                    if (next_empty_cell >= size) { next_empty_cell -= size; }
                    kv_cell & cell = cells[next_empty_cell];
                    if (cell.is_empty()) { break; }
                    next_empty_cell += 1;
                }
            }
        }
        if (min > seq_meta.tail) { min = seq_meta.tail; }
        if (max < seq_meta.tail) { max = seq_meta.tail; }
    }

    // gather and re-order
    for (uint32_t s = 0; s < n_seqs; ++s) {
        int32_t dst_id = s + min;
        int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
        if (dst_id != src_id) {
            kv_cell & dst_cell = cells[dst_id];
            kv_cell & src_cell = cells[src_id];

            std::swap(dst_cell.pos, src_cell.pos);
            std::swap(dst_cell.src, src_cell.src);
            std::swap(dst_cell.seq_id, src_cell.seq_id);

            // swap tails (assuming they NEVER overlap)
            for (const llama_seq_id seq_id : src_cell.seq_id) {
                cells[seq_id].tail = src_id;
            }
            for (const llama_seq_id seq_id : dst_cell.seq_id) {
                cells[seq_id].tail = dst_id;
            }
        }
    }

    // update the pos of the used seqs
    for (uint32_t s = 0; s < n_seqs; ++s) {
        const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
        int32_t cell_id = s + min;
        kv_cell & cell = cells[cell_id];

        if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
            // What should happen when the pos backtracks or skips a value?
            // Clearing the state mid-batch would require special-casing which isn't done.
            LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
                __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
        }
        cell.pos = last_pos;
        cell.seq_id.clear();
        for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
            const llama_seq_id seq_id = ubatch.seq_id[s][j];
            cell.seq_id.insert(seq_id);
            cells[seq_id].tail = cell_id;
        }
    }

    // allow getting the range of used cells, from head to head + n
    head = min;
    n = max - min + 1;
    used = std::count_if(cells.begin(), cells.end(),
        [](const kv_cell & cell){ return !cell.is_empty(); });

    // sanity check
    return n >= n_seqs;
}

bool llama_kv_cache_recurrent::get_can_shift() const {
    return false;
}

int32_t llama_kv_cache_recurrent::s_copy(int i) const {
    const uint32_t cell_id = i + head;

    //////////////////////////////////////////////
    // TODO: this should not mutate the KV cache !
    kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);

    // prevent out-of-bound sources
    if (cell.src < 0 || (uint32_t) cell.src >= size) {
        cell.src = cell_id;
    }

    int32_t res = cell.src;

    // TODO: do not mutate the KV cache
    // ensure copy only happens once
    if (cell.src != (int32_t) cell_id) {
        cell.src = cell_id;
    }

    return res;
}

float llama_kv_cache_recurrent::s_mask(int i) const {
    const uint32_t cell_id = i + head;

    //////////////////////////////////////////////
    // TODO: this should not mutate the KV cache !
    kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);

    float res = (float) (cell.src >= 0);

    // only clear once
    if (cell.src < 0) {
        cell.src = cell_id;
    }

    return res;
}

size_t llama_kv_cache_recurrent::total_size() const {
    size_t size = 0;
    for (const auto & buf : bufs) {
        size += ggml_backend_buffer_get_size(buf.get());
    }

    return size;
}

size_t llama_kv_cache_recurrent::size_k_bytes() const {
    size_t size_k_bytes = 0;

    for (const auto & k : k_l) {
        size_k_bytes += ggml_nbytes(k);
    }

    return size_k_bytes;
}

size_t llama_kv_cache_recurrent::size_v_bytes() const {
    size_t size_v_bytes = 0;

    for (const auto & v : v_l) {
        size_v_bytes += ggml_nbytes(v);
    }

    return size_v_bytes;
}

void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
    uint32_t cell_count = 0;

    // Count the number of cells with the specified seq_id
    // Find all the ranges of cells with this seq id (or all, when -1)
    uint32_t cell_range_begin = size;
    for (uint32_t i = 0; i < size; ++i) {
        const auto & cell = cells[i];
        if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
            ++cell_count;
            if (cell_range_begin == size) {
                cell_range_begin = i;
            }
        } else {
            if (cell_range_begin != size) {
                cell_ranges.emplace_back(cell_range_begin, i);
                cell_range_begin = size;
            }
        }
    }
    if (cell_range_begin != size) {
        cell_ranges.emplace_back(cell_range_begin, size);
    }

    // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
    uint32_t cell_count_check = 0;
    for (const auto & range : cell_ranges) {
        cell_count_check += range.second - range.first;
    }
    GGML_ASSERT(cell_count == cell_count_check);

    io.write(&cell_count, sizeof(cell_count));

    state_write_meta(io, cell_ranges, seq_id);
    state_write_data(io, cell_ranges);
}

void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
    uint32_t cell_count;
    io.read_to(&cell_count, sizeof(cell_count));

    bool res = true;

    res = res && state_read_meta(io, cell_count, seq_id);
    res = res && state_read_data(io, cell_count);

    if (!res) {
        if (seq_id == -1) {
            clear();
        } else {
            seq_rm(seq_id, -1, -1);
        }
        throw std::runtime_error("failed to restore kv cache");
    }
}

void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
    for (const auto & range : cell_ranges) {
        for (uint32_t i = range.first; i < range.second; ++i) {
            const auto & cell = cells[i];
            const llama_pos pos = cell.pos;
            const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;

            io.write(&pos, sizeof(pos));
            io.write(&n_seq_id, sizeof(n_seq_id));

            if (n_seq_id) {
                for (auto seq_id : cell.seq_id) {
                    io.write(&seq_id, sizeof(seq_id));
                }
            }
        }
    }
}

void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
    const uint32_t v_trans = 0;
    const uint32_t n_layer = hparams.n_layer;

    io.write(&v_trans, sizeof(v_trans));
    io.write(&n_layer, sizeof(n_layer));

    std::vector<uint8_t> tmp_buf;

    // Iterate and write all the keys first, each row is a cell
    // Get whole range at a time
    for (uint32_t il = 0; il < n_layer; ++il) {
        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();

        // Write key type
        const int32_t k_type_i = (int32_t)k_l[il]->type;
        io.write(&k_type_i, sizeof(k_type_i));

        // Write row size of key
        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
        io.write(&k_size_row, sizeof(k_size_row));

        // Read each range of cells of k_size length each into tmp_buf and write out
        for (const auto & range : cell_ranges) {
            const size_t range_size = range.second - range.first;
            const size_t buf_size = range_size * k_size_row;
            io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
        }
    }

    if (!v_trans) {
        for (uint32_t il = 0; il < n_layer; ++il) {
            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

            // Write value type
            const int32_t v_type_i = (int32_t)v_l[il]->type;
            io.write(&v_type_i, sizeof(v_type_i));

            // Write row size of value
            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
            io.write(&v_size_row, sizeof(v_size_row));

            // Read each range of cells of v_size length each into tmp_buf and write out
            for (const auto & range : cell_ranges) {
                const size_t range_size = range.second - range.first;
                const size_t buf_size = range_size * v_size_row;
                io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
            }
        }
    } else {
        // When v is transposed, we also need the element size and get the element ranges from each row
        const uint32_t kv_size = size;
        for (uint32_t il = 0; il < n_layer; ++il) {
            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

            // Write value type
            const int32_t v_type_i = (int32_t)v_l[il]->type;
            io.write(&v_type_i, sizeof(v_type_i));

            // Write element size
            const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
            io.write(&v_size_el, sizeof(v_size_el));

            // Write GQA embedding size
            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));

            // For each row, we get the element values of each cell
            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                // Read each range of cells of v_size_el length each into tmp_buf and write out
                for (const auto & range : cell_ranges) {
                    const size_t range_size = range.second - range.first;
                    const size_t src_offset = (range.first + j * kv_size) * v_size_el;
                    const size_t buf_size = range_size * v_size_el;
                    io.write_tensor(v_l[il], src_offset, buf_size);
                }
            }
        }
    }
}

bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
    if (dest_seq_id != -1) {
        // single sequence

        seq_rm(dest_seq_id, -1, -1);

        llama_sbatch sbatch;
        llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);

        batch.n_tokens = cell_count;
        batch.n_seq_tokens = cell_count;
        batch.n_seqs = 1;

        for (uint32_t i = 0; i < cell_count; ++i) {
            llama_pos pos;
            uint32_t n_seq_id;

            io.read_to(&pos, sizeof(pos));
            io.read_to(&n_seq_id, sizeof(n_seq_id));

            if (n_seq_id != 0) {
                LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
                return false;
            }

            batch.pos[i] = pos;
        }
        batch.n_seq_id[0] = 1;
        batch.seq_id[0] = &dest_seq_id;

        if (!find_slot(batch)) {
            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
            return false;
        }

        // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
        // Assume that this is one contiguous block of cells
        GGML_ASSERT(head + cell_count <= size);
        GGML_ASSERT(cells[head].pos == batch.pos[0]);
        GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
        GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
        GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
    } else {
        // whole KV cache restore

        if (cell_count > size) {
            LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
            return false;
        }

        clear();

        for (uint32_t i = 0; i < cell_count; ++i) {
            kv_cell & cell = cells[i];

            llama_pos pos;
            uint32_t n_seq_id;

            io.read_to(&pos, sizeof(pos));
            io.read_to(&n_seq_id, sizeof(n_seq_id));

            cell.pos = pos;

            for (uint32_t j = 0; j < n_seq_id; ++j) {
                llama_seq_id seq_id;
                io.read_to(&seq_id, sizeof(seq_id));

                // TODO: llama_kv_cache_recurrent should have a notion of max sequences
                //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
                if (seq_id < 0) {
                    //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
                    return false;
                }

                cell.seq_id.insert(seq_id);

                int32_t & tail = cells[seq_id].tail;
                if (tail != -1) {
                    LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
                    return false;
                }
                tail = i;
            }
        }

        head = 0;
        used = cell_count;
    }

    for (uint32_t i = 0; i < cell_count; ++i) {
        uint32_t cell_id = head + i;
        // make sure the recurrent states will keep their restored state
        cells[cell_id].src = cell_id;
    }

    return true;
}

bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
    uint32_t v_trans;
    uint32_t n_layer;
    io.read_to(&v_trans, sizeof(v_trans));
    io.read_to(&n_layer, sizeof(n_layer));

    if (n_layer != hparams.n_layer) {
        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
        return false;
    }
    if (cell_count > size) {
        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
        return false;
    }
    if (false != (bool) v_trans) {
        LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
        return false;
    }

    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
    for (uint32_t il = 0; il < n_layer; ++il) {
        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();

        // Read type of key
        int32_t k_type_i_ref;
        io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
        const int32_t k_type_i = (int32_t) k_l[il]->type;
        if (k_type_i != k_type_i_ref) {
            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
            return false;
        }

        // Read row size of key
        uint64_t k_size_row_ref;
        io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
        if (k_size_row != k_size_row_ref) {
            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
            return false;
        }

        if (cell_count) {
            // Read and set the keys for the whole cell range
            ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
        }
    }

    if (!v_trans) {
        for (uint32_t il = 0; il < n_layer; ++il) {
            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

            // Read type of value
            int32_t v_type_i_ref;
            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
            const int32_t v_type_i = (int32_t)v_l[il]->type;
            if (v_type_i != v_type_i_ref) {
                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                return false;
            }

            // Read row size of value
            uint64_t v_size_row_ref;
            io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
            const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
            if (v_size_row != v_size_row_ref) {
                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
                return false;
            }

            if (cell_count) {
                // Read and set the values for the whole cell range
                ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
            }
        }
    } else {
        // For each layer, read the values for each cell (transposed)
        for (uint32_t il = 0; il < n_layer; ++il) {
            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

            // Read type of value
            int32_t v_type_i_ref;
            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
            const int32_t v_type_i = (int32_t)v_l[il]->type;
            if (v_type_i != v_type_i_ref) {
                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                return false;
            }

            // Read element size of value
            uint32_t v_size_el_ref;
            io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
            const size_t v_size_el = ggml_type_size(v_l[il]->type);
            if (v_size_el != v_size_el_ref) {
                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
                return false;
            }

            // Read GQA embedding size
            uint32_t n_embd_v_gqa_ref;
            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
                return false;
            }

            if (cell_count) {
                // For each row in the transposed matrix, read the values for the whole cell range
                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                    const size_t dst_offset = (head + j * size) * v_size_el;
                    ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
                }
            }
        }
    }

    return true;
}

//
// llama_kv_cache_recurrent_state
|
| 1055 |
+
//
|
| 1056 |
+
|
| 1057 |
+
llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {}
|
| 1058 |
+
|
| 1059 |
+
llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
|
| 1060 |
+
llama_memory_status status,
|
| 1061 |
+
llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
|
| 1062 |
+
}
|
| 1063 |
+
|
| 1064 |
+
llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
|
| 1065 |
+
llama_memory_status status,
|
| 1066 |
+
llama_kv_cache_recurrent * kv,
|
| 1067 |
+
llama_sbatch sbatch,
|
| 1068 |
+
std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
|
| 1069 |
+
|
| 1070 |
+
llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default;
|
| 1071 |
+
|
| 1072 |
+
bool llama_kv_cache_recurrent_state::next() {
|
| 1073 |
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 1074 |
+
|
| 1075 |
+
if (++i_next >= ubatches.size()) {
|
| 1076 |
+
return false;
|
| 1077 |
+
}
|
| 1078 |
+
|
| 1079 |
+
return true;
|
| 1080 |
+
}
|
| 1081 |
+
|
| 1082 |
+
bool llama_kv_cache_recurrent_state::apply() {
|
| 1083 |
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 1084 |
+
|
| 1085 |
+
kv->find_slot(ubatches[i_next]);
|
| 1086 |
+
|
| 1087 |
+
return true;
|
| 1088 |
+
}
|
| 1089 |
+
|
| 1090 |
+
std::vector<int64_t> & llama_kv_cache_recurrent_state::out_ids() {
|
| 1091 |
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 1092 |
+
|
| 1093 |
+
return sbatch.out_ids;
|
| 1094 |
+
}
|
| 1095 |
+
|
| 1096 |
+
llama_memory_status llama_kv_cache_recurrent_state::get_status() const {
|
| 1097 |
+
return status;
|
| 1098 |
+
}
|
| 1099 |
+
|
| 1100 |
+
const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const {
|
| 1101 |
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 1102 |
+
|
| 1103 |
+
return ubatches[i_next];
|
| 1104 |
+
}
|
| 1105 |
+
|
| 1106 |
+
uint32_t llama_kv_cache_recurrent_state::get_n_kv() const {
|
| 1107 |
+
return is_full ? kv->size : kv->n;
|
| 1108 |
+
}
|
| 1109 |
+
|
| 1110 |
+
uint32_t llama_kv_cache_recurrent_state::get_head() const {
|
| 1111 |
+
return is_full ? 0 : kv->head;
|
| 1112 |
+
}
|
| 1113 |
+
|
| 1114 |
+
uint32_t llama_kv_cache_recurrent_state::get_size() const {
|
| 1115 |
+
return kv->size;
|
| 1116 |
+
}
|
| 1117 |
+
|
| 1118 |
+
ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const {
|
| 1119 |
+
return kv->k_l[il];
|
| 1120 |
+
}
|
| 1121 |
+
|
| 1122 |
+
ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
|
| 1123 |
+
return kv->v_l[il];
|
| 1124 |
+
}
|
| 1125 |
+
|
| 1126 |
+
int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
|
| 1127 |
+
return kv->s_copy(i);
|
| 1128 |
+
}
|
| 1129 |
+
|
| 1130 |
+
float llama_kv_cache_recurrent_state::s_mask(int i) const {
|
| 1131 |
+
return kv->s_mask(i);
|
| 1132 |
+
}
|
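
The transposed value-cache restore above writes one row of V at a time: for row j, the destination starts at byte offset (head + j*size) * v_size_el. Below is a minimal standalone sketch of that offset arithmetic; the sizes are toy values chosen for illustration, not taken from the diff.

// Standalone sketch (not part of the diff): offset arithmetic used when restoring
// a transposed V cache. All sizes below are illustrative assumptions.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t size      = 8; // total cells in the cache
    const uint32_t head      = 2; // first cell being restored
    const uint32_t n_embd    = 4; // rows of the transposed V matrix (n_embd_v_gqa)
    const size_t   v_size_el = 2; // bytes per element (e.g. f16)

    // each row j of the transposed matrix stores `size` elements contiguously,
    // so the cells starting at `head` land at byte offset (head + j*size) * v_size_el
    for (uint32_t j = 0; j < n_embd; ++j) {
        const size_t dst_offset = (head + (size_t) j*size) * v_size_el;
        printf("row %u -> dst_offset = %zu bytes\n", (unsigned) j, dst_offset);
    }
    return 0;
}
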
examples/talk-llama/llama-kv-cache-recurrent.h
ADDED
|
@@ -0,0 +1,191 @@
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "llama-batch.h"
|
| 4 |
+
#include "llama-graph.h"
|
| 5 |
+
#include "llama-kv-cache.h"
|
| 6 |
+
|
| 7 |
+
#include <set>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
//
|
| 11 |
+
// llama_kv_cache_recurrent
|
| 12 |
+
//
|
| 13 |
+
|
| 14 |
+
// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
|
| 15 |
+
// see the implementation of llama_kv_cache_unified_state_i for an example of how to do it
|
| 16 |
+
class llama_kv_cache_recurrent : public llama_kv_cache {
|
| 17 |
+
public:
|
| 18 |
+
llama_kv_cache_recurrent(
|
| 19 |
+
const llama_model & model,
|
| 20 |
+
ggml_type type_k,
|
| 21 |
+
ggml_type type_v,
|
| 22 |
+
bool offload,
|
| 23 |
+
uint32_t kv_size,
|
| 24 |
+
uint32_t n_seq_max);
|
| 25 |
+
|
| 26 |
+
~llama_kv_cache_recurrent() = default;
|
| 27 |
+
|
| 28 |
+
//
|
| 29 |
+
// llama_memory_i
|
| 30 |
+
//
|
| 31 |
+
|
| 32 |
+
void clear() override;
|
| 33 |
+
|
| 34 |
+
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
| 35 |
+
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
| 36 |
+
void seq_keep(llama_seq_id seq_id) override;
|
| 37 |
+
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
| 38 |
+
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
| 39 |
+
|
| 40 |
+
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
| 41 |
+
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
| 42 |
+
|
| 43 |
+
//
|
| 44 |
+
// llama_kv_cache
|
| 45 |
+
//
|
| 46 |
+
|
| 47 |
+
llama_memory_state_ptr init_batch(
|
| 48 |
+
const llama_batch & batch,
|
| 49 |
+
uint32_t n_ubatch,
|
| 50 |
+
bool embd_pooled,
|
| 51 |
+
bool logits_all) override;
|
| 52 |
+
|
| 53 |
+
llama_memory_state_ptr init_full() override;
|
| 54 |
+
|
| 55 |
+
bool update(llama_context & lctx) override;
|
| 56 |
+
|
| 57 |
+
void defrag_sched(float thold) override;
|
| 58 |
+
|
| 59 |
+
bool prepare(const std::vector<llama_ubatch> & ubatches);
|
| 60 |
+
|
| 61 |
+
// find a contiguous slot of kv cells and emplace the ubatch there
|
| 62 |
+
bool find_slot(const llama_ubatch & ubatch);
|
| 63 |
+
|
| 64 |
+
bool get_can_shift() const override;
|
| 65 |
+
|
| 66 |
+
// TODO: temporary methods - they are not really const as they do const_cast<>, fix this
|
| 67 |
+
int32_t s_copy(int i) const;
|
| 68 |
+
float s_mask(int i) const;
|
| 69 |
+
|
| 70 |
+
// state write/load
|
| 71 |
+
|
| 72 |
+
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
| 73 |
+
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
| 74 |
+
|
| 75 |
+
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
| 76 |
+
uint32_t size = 0; // total number of cells, shared across all sequences
|
| 77 |
+
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
| 78 |
+
|
| 79 |
+
// computed before each graph build
|
| 80 |
+
uint32_t n = 0;
|
| 81 |
+
|
| 82 |
+
// TODO: optimize for recurrent state needs
|
| 83 |
+
struct kv_cell {
|
| 84 |
+
llama_pos pos = -1;
|
| 85 |
+
int32_t src = -1; // used to copy states
|
| 86 |
+
int32_t tail = -1;
|
| 87 |
+
|
| 88 |
+
std::set<llama_seq_id> seq_id;
|
| 89 |
+
|
| 90 |
+
bool has_seq_id(const llama_seq_id & id) const {
|
| 91 |
+
return seq_id.find(id) != seq_id.end();
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
bool is_empty() const {
|
| 95 |
+
return seq_id.empty();
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
bool is_same_seq(const kv_cell & other) const {
|
| 99 |
+
return seq_id == other.seq_id;
|
| 100 |
+
}
|
| 101 |
+
};
|
| 102 |
+
|
| 103 |
+
std::vector<kv_cell> cells;
|
| 104 |
+
|
| 105 |
+
std::vector<ggml_tensor *> k_l; // per layer
|
| 106 |
+
std::vector<ggml_tensor *> v_l;
|
| 107 |
+
|
| 108 |
+
private:
|
| 109 |
+
//const llama_model & model;
|
| 110 |
+
const llama_hparams & hparams;
|
| 111 |
+
|
| 112 |
+
const uint32_t n_seq_max = 1;
|
| 113 |
+
|
| 114 |
+
std::vector<ggml_context_ptr> ctxs;
|
| 115 |
+
std::vector<ggml_backend_buffer_ptr> bufs;
|
| 116 |
+
|
| 117 |
+
size_t total_size() const;
|
| 118 |
+
|
| 119 |
+
size_t size_k_bytes() const;
|
| 120 |
+
size_t size_v_bytes() const;
|
| 121 |
+
|
| 122 |
+
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
| 123 |
+
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
| 124 |
+
|
| 125 |
+
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
| 126 |
+
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
| 127 |
+
};
|
| 128 |
+
|
| 129 |
+
class llama_kv_cache_recurrent_state : public llama_memory_state_i {
|
| 130 |
+
public:
|
| 131 |
+
// used for errors
|
| 132 |
+
llama_kv_cache_recurrent_state(llama_memory_status status);
|
| 133 |
+
|
| 134 |
+
// used to create a full-cache state
|
| 135 |
+
llama_kv_cache_recurrent_state(
|
| 136 |
+
llama_memory_status status,
|
| 137 |
+
llama_kv_cache_recurrent * kv);
|
| 138 |
+
|
| 139 |
+
// used to create a state from a batch
|
| 140 |
+
llama_kv_cache_recurrent_state(
|
| 141 |
+
llama_memory_status status,
|
| 142 |
+
llama_kv_cache_recurrent * kv,
|
| 143 |
+
llama_sbatch sbatch,
|
| 144 |
+
std::vector<llama_ubatch> ubatches);
|
| 145 |
+
|
| 146 |
+
virtual ~llama_kv_cache_recurrent_state();
|
| 147 |
+
|
| 148 |
+
//
|
| 149 |
+
// llama_memory_state_i
|
| 150 |
+
//
|
| 151 |
+
|
| 152 |
+
bool next() override;
|
| 153 |
+
bool apply() override;
|
| 154 |
+
|
| 155 |
+
std::vector<int64_t> & out_ids() override;
|
| 156 |
+
|
| 157 |
+
llama_memory_status get_status() const override;
|
| 158 |
+
const llama_ubatch & get_ubatch() const override;
|
| 159 |
+
|
| 160 |
+
//
|
| 161 |
+
// llama_kv_cache_recurrent_state specific API
|
| 162 |
+
//
|
| 163 |
+
|
| 164 |
+
uint32_t get_n_kv() const;
|
| 165 |
+
uint32_t get_head() const;
|
| 166 |
+
uint32_t get_size() const;
|
| 167 |
+
|
| 168 |
+
ggml_tensor * get_k_l(int32_t il) const;
|
| 169 |
+
ggml_tensor * get_v_l(int32_t il) const;
|
| 170 |
+
|
| 171 |
+
int32_t s_copy(int i) const;
|
| 172 |
+
float s_mask(int i) const;
|
| 173 |
+
|
| 174 |
+
private:
|
| 175 |
+
const llama_memory_status status;
|
| 176 |
+
|
| 177 |
+
llama_kv_cache_recurrent * kv;
|
| 178 |
+
|
| 179 |
+
llama_sbatch sbatch;
|
| 180 |
+
|
| 181 |
+
size_t i_next = 0;
|
| 182 |
+
|
| 183 |
+
std::vector<llama_ubatch> ubatches;
|
| 184 |
+
|
| 185 |
+
//
|
| 186 |
+
// data needed for building the compute graph for the current ubatch:
|
| 187 |
+
// TODO: extract all the state like `head` and `n` here
|
| 188 |
+
//
|
| 189 |
+
|
| 190 |
+
const bool is_full = false;
|
| 191 |
+
};
|
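
The llama_kv_cache_recurrent_state declared above exposes the llama_memory_state_i iteration pattern: the state holds a list of ubatches and the consumer drives it with apply() and next(). A simplified mock of that loop follows; mock_state and its int "ubatches" are hypothetical stand-ins, not the llama.cpp types.

// Simplified mock of the state-iteration pattern (hypothetical types, not the llama.cpp API).
#include <cstdio>
#include <vector>

struct mock_state {
    std::vector<int> ubatches; // stand-in for std::vector<llama_ubatch>
    size_t i_next = 0;

    bool apply() { printf("apply ubatch %d\n", ubatches[i_next]); return true; }
    bool next()  { return ++i_next < ubatches.size(); }
};

int main() {
    mock_state state{{10, 20, 30}};
    // consumer loop: apply the current ubatch, then advance until exhausted
    do {
        if (!state.apply()) return 1;
    } while (state.next());
    return 0;
}
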
examples/talk-llama/llama-kv-cache-unified-iswa.cpp
ADDED
|
@@ -0,0 +1,249 @@
|
| 1 |
+
#include "llama-kv-cache-unified-iswa.h"
|
| 2 |
+
|
| 3 |
+
#include "llama-impl.h"
|
| 4 |
+
#include "llama-batch.h"
|
| 5 |
+
#include "llama-model.h"
|
| 6 |
+
|
| 7 |
+
#include <algorithm>
|
| 8 |
+
#include <cassert>
|
| 9 |
+
|
| 10 |
+
//
|
| 11 |
+
// llama_kv_cache_unified_iswa
|
| 12 |
+
//
|
| 13 |
+
|
| 14 |
+
llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
| 15 |
+
const llama_model & model,
|
| 16 |
+
ggml_type type_k,
|
| 17 |
+
ggml_type type_v,
|
| 18 |
+
bool v_trans,
|
| 19 |
+
bool offload,
|
| 20 |
+
bool swa_full,
|
| 21 |
+
uint32_t kv_size,
|
| 22 |
+
uint32_t n_seq_max,
|
| 23 |
+
uint32_t n_ubatch,
|
| 24 |
+
uint32_t n_pad) : hparams(model.hparams) {
|
| 25 |
+
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
|
| 26 |
+
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
|
| 27 |
+
|
| 28 |
+
const uint32_t size_base = kv_size;
|
| 29 |
+
|
| 30 |
+
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
|
| 31 |
+
|
| 32 |
+
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
|
| 33 |
+
if (swa_full) {
|
| 34 |
+
LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
|
| 35 |
+
__func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
|
| 36 |
+
|
| 37 |
+
size_swa = size_base;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
|
| 41 |
+
|
| 42 |
+
kv_base = std::make_unique<llama_kv_cache_unified>(
|
| 43 |
+
model, std::move(filter_base), type_k, type_v,
|
| 44 |
+
v_trans, offload, size_base, n_seq_max, n_pad,
|
| 45 |
+
0, LLAMA_SWA_TYPE_NONE);
|
| 46 |
+
|
| 47 |
+
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
|
| 48 |
+
|
| 49 |
+
kv_swa = std::make_unique<llama_kv_cache_unified>(
|
| 50 |
+
model, std::move(filter_swa), type_k, type_v,
|
| 51 |
+
v_trans, offload, size_swa, n_seq_max, n_pad,
|
| 52 |
+
hparams.n_swa, hparams.swa_type);
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
void llama_kv_cache_unified_iswa::clear() {
|
| 56 |
+
kv_base->clear();
|
| 57 |
+
kv_swa ->clear();
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
| 61 |
+
bool res = true;
|
| 62 |
+
|
| 63 |
+
res = res & kv_base->seq_rm(seq_id, p0, p1);
|
| 64 |
+
res = res & kv_swa ->seq_rm(seq_id, p0, p1);
|
| 65 |
+
|
| 66 |
+
return res;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
| 70 |
+
kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
| 71 |
+
kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
|
| 75 |
+
kv_base->seq_keep(seq_id);
|
| 76 |
+
kv_swa ->seq_keep(seq_id);
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
| 80 |
+
kv_base->seq_add(seq_id, p0, p1, shift);
|
| 81 |
+
kv_swa ->seq_add(seq_id, p0, p1, shift);
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
| 85 |
+
kv_base->seq_div(seq_id, p0, p1, d);
|
| 86 |
+
kv_swa ->seq_div(seq_id, p0, p1, d);
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
|
| 90 |
+
// the base cache is a superset of the SWA cache, so we can just check the SWA cache
|
| 91 |
+
return kv_swa->seq_pos_min(seq_id);
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
|
| 95 |
+
return kv_swa->seq_pos_max(seq_id);
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
|
| 99 |
+
GGML_UNUSED(embd_pooled);
|
| 100 |
+
|
| 101 |
+
// TODO: if we fail with split_simple, we should attempt different splitting strategies
|
| 102 |
+
// but to do that properly, we first have to refactor the batches to be more flexible
|
| 103 |
+
|
| 104 |
+
auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
|
| 105 |
+
|
| 106 |
+
std::vector<llama_ubatch> ubatches;
|
| 107 |
+
|
| 108 |
+
while (sbatch.n_tokens > 0) {
|
| 109 |
+
auto ubatch = sbatch.split_simple(n_ubatch);
|
| 110 |
+
|
| 111 |
+
ubatches.push_back(ubatch);
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
auto heads_base = kv_base->prepare(ubatches);
|
| 115 |
+
if (heads_base.empty()) {
|
| 116 |
+
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
auto heads_swa = kv_swa->prepare(ubatches);
|
| 120 |
+
if (heads_swa.empty()) {
|
| 121 |
+
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
assert(heads_base.size() == heads_swa.size());
|
| 125 |
+
|
| 126 |
+
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
|
| 127 |
+
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
|
| 131 |
+
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
|
| 135 |
+
bool res = false;
|
| 136 |
+
|
| 137 |
+
res = res | kv_base->update(lctx);
|
| 138 |
+
res = res | kv_swa ->update(lctx);
|
| 139 |
+
|
| 140 |
+
return res;
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
|
| 144 |
+
kv_base->defrag_sched(thold);
|
| 145 |
+
kv_swa ->defrag_sched(thold);
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
bool llama_kv_cache_unified_iswa::get_can_shift() const {
|
| 149 |
+
return kv_base->get_size() == kv_swa->get_size();
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
| 153 |
+
kv_base->state_write(io, seq_id);
|
| 154 |
+
kv_swa ->state_write(io, seq_id);
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
| 158 |
+
kv_base->state_read(io, seq_id);
|
| 159 |
+
kv_swa ->state_read(io, seq_id);
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
|
| 163 |
+
return kv_base.get();
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
|
| 167 |
+
return kv_swa.get();
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
//
|
| 171 |
+
// llama_kv_cache_unified_iswa_state
|
| 172 |
+
//
|
| 173 |
+
|
| 174 |
+
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
|
| 175 |
+
|
| 176 |
+
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
| 177 |
+
llama_memory_status status,
|
| 178 |
+
llama_kv_cache_unified_iswa * kv) : status(status) {
|
| 179 |
+
state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
|
| 180 |
+
state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
| 184 |
+
llama_memory_status status,
|
| 185 |
+
llama_kv_cache_unified_iswa * kv,
|
| 186 |
+
llama_sbatch sbatch,
|
| 187 |
+
std::vector<uint32_t> heads_base,
|
| 188 |
+
std::vector<uint32_t> heads_swa,
|
| 189 |
+
std::vector<llama_ubatch> ubatches)
|
| 190 |
+
: status(status),
|
| 191 |
+
sbatch(std::move(sbatch)),
|
| 192 |
+
ubatches(std::move(ubatches)) {
|
| 193 |
+
// note: here we copy the ubatches. not sure if this is ideal
|
| 194 |
+
state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
|
| 195 |
+
state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
|
| 199 |
+
|
| 200 |
+
bool llama_kv_cache_unified_iswa_state::next() {
|
| 201 |
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 202 |
+
|
| 203 |
+
state_base->next();
|
| 204 |
+
state_swa ->next();
|
| 205 |
+
|
| 206 |
+
if (++i_next >= ubatches.size()) {
|
| 207 |
+
return false;
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
return true;
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
bool llama_kv_cache_unified_iswa_state::apply() {
|
| 214 |
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 215 |
+
|
| 216 |
+
bool res = true;
|
| 217 |
+
|
| 218 |
+
res = res & state_base->apply();
|
| 219 |
+
res = res & state_swa ->apply();
|
| 220 |
+
|
| 221 |
+
return res;
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
|
| 225 |
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 226 |
+
|
| 227 |
+
return sbatch.out_ids;
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
|
| 231 |
+
return status;
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
|
| 235 |
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 236 |
+
return ubatches[i_next];
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
|
| 240 |
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 241 |
+
|
| 242 |
+
return state_base.get();
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
|
| 246 |
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 247 |
+
|
| 248 |
+
return state_swa.get();
|
| 249 |
+
}
|
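
The constructor above sizes the SWA cache as the smaller of the base cache size and the padded window budget n_swa*n_seq_max + n_ubatch. A small sketch of that sizing with toy numbers; pad_up is a stand-in for GGML_PAD and assumes plain round-up-to-a-multiple semantics.

// Sketch of the SWA cache sizing above (toy values; pad_up stands in for GGML_PAD).
#include <algorithm>
#include <cstdint>
#include <cstdio>

// round x up to a multiple of n (assumes n > 0)
static uint32_t pad_up(uint32_t x, uint32_t n) { return ((x + n - 1) / n) * n; }

int main() {
    const uint32_t kv_size   = 8192; // base (non-SWA) cache size
    const uint32_t n_swa     = 1024; // sliding window width
    const uint32_t n_seq_max = 2;
    const uint32_t n_ubatch  = 512;
    const uint32_t n_pad     = 256;

    const uint32_t size_base = kv_size;
    const uint32_t size_swa  = std::min(size_base, pad_up(n_swa*n_seq_max + n_ubatch, n_pad));

    printf("base cells: %u, SWA cells: %u\n", (unsigned) size_base, (unsigned) size_swa); // 8192, 2560
    return 0;
}

With swa_full set, the wrapper skips this and simply reuses the base size, as the warning in the constructor notes.
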
examples/talk-llama/llama-kv-cache-unified-iswa.h
ADDED
|
@@ -0,0 +1,136 @@
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "llama-kv-cache-unified.h"
|
| 4 |
+
|
| 5 |
+
#include <vector>
|
| 6 |
+
|
| 7 |
+
//
|
| 8 |
+
// llama_kv_cache_unified_iswa
|
| 9 |
+
//
|
| 10 |
+
|
| 11 |
+
// utilizes two instances of llama_kv_cache_unified
|
| 12 |
+
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
|
| 13 |
+
|
| 14 |
+
class llama_kv_cache_unified_iswa : public llama_kv_cache {
|
| 15 |
+
public:
|
| 16 |
+
llama_kv_cache_unified_iswa(
|
| 17 |
+
const llama_model & model,
|
| 18 |
+
ggml_type type_k,
|
| 19 |
+
ggml_type type_v,
|
| 20 |
+
bool v_trans,
|
| 21 |
+
bool offload,
|
| 22 |
+
bool swa_full,
|
| 23 |
+
uint32_t kv_size,
|
| 24 |
+
uint32_t n_seq_max,
|
| 25 |
+
uint32_t n_ubatch,
|
| 26 |
+
uint32_t n_pad);
|
| 27 |
+
|
| 28 |
+
~llama_kv_cache_unified_iswa() = default;
|
| 29 |
+
|
| 30 |
+
//
|
| 31 |
+
// llama_memory_i
|
| 32 |
+
//
|
| 33 |
+
|
| 34 |
+
void clear() override;
|
| 35 |
+
|
| 36 |
+
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
| 37 |
+
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
| 38 |
+
void seq_keep(llama_seq_id seq_id) override;
|
| 39 |
+
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
| 40 |
+
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
| 41 |
+
|
| 42 |
+
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
| 43 |
+
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
| 44 |
+
|
| 45 |
+
//
|
| 46 |
+
// llama_kv_cache
|
| 47 |
+
//
|
| 48 |
+
|
| 49 |
+
llama_memory_state_ptr init_batch(
|
| 50 |
+
const llama_batch & batch,
|
| 51 |
+
uint32_t n_ubatch,
|
| 52 |
+
bool embd_pooled,
|
| 53 |
+
bool logits_all) override;
|
| 54 |
+
|
| 55 |
+
llama_memory_state_ptr init_full() override;
|
| 56 |
+
|
| 57 |
+
bool update(llama_context & lctx) override;
|
| 58 |
+
|
| 59 |
+
void defrag_sched(float thold) override;
|
| 60 |
+
|
| 61 |
+
bool get_can_shift() const override;
|
| 62 |
+
|
| 63 |
+
// state write/load
|
| 64 |
+
|
| 65 |
+
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
| 66 |
+
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
| 67 |
+
|
| 68 |
+
//
|
| 69 |
+
// llama_kv_cache_unified_iswa specific API
|
| 70 |
+
//
|
| 71 |
+
|
| 72 |
+
llama_kv_cache_unified * get_base() const;
|
| 73 |
+
llama_kv_cache_unified * get_swa () const;
|
| 74 |
+
|
| 75 |
+
private:
|
| 76 |
+
const llama_hparams & hparams;
|
| 77 |
+
|
| 78 |
+
std::unique_ptr<llama_kv_cache_unified> kv_base;
|
| 79 |
+
std::unique_ptr<llama_kv_cache_unified> kv_swa;
|
| 80 |
+
};
|
| 81 |
+
|
| 82 |
+
class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
|
| 83 |
+
public:
|
| 84 |
+
// used for errors
|
| 85 |
+
llama_kv_cache_unified_iswa_state(llama_memory_status status);
|
| 86 |
+
|
| 87 |
+
// used to create a full-cache state
|
| 88 |
+
llama_kv_cache_unified_iswa_state(
|
| 89 |
+
llama_memory_status status,
|
| 90 |
+
llama_kv_cache_unified_iswa * kv);
|
| 91 |
+
|
| 92 |
+
// used to create a state from a batch
|
| 93 |
+
llama_kv_cache_unified_iswa_state(
|
| 94 |
+
llama_memory_status status,
|
| 95 |
+
llama_kv_cache_unified_iswa * kv,
|
| 96 |
+
llama_sbatch sbatch,
|
| 97 |
+
std::vector<uint32_t> heads_base,
|
| 98 |
+
std::vector<uint32_t> heads_swa,
|
| 99 |
+
std::vector<llama_ubatch> ubatches);
|
| 100 |
+
|
| 101 |
+
virtual ~llama_kv_cache_unified_iswa_state();
|
| 102 |
+
|
| 103 |
+
//
|
| 104 |
+
// llama_memory_state_i
|
| 105 |
+
//
|
| 106 |
+
|
| 107 |
+
bool next() override;
|
| 108 |
+
bool apply() override;
|
| 109 |
+
|
| 110 |
+
std::vector<int64_t> & out_ids() override;
|
| 111 |
+
|
| 112 |
+
llama_memory_status get_status() const override;
|
| 113 |
+
const llama_ubatch & get_ubatch() const override;
|
| 114 |
+
|
| 115 |
+
//
|
| 116 |
+
// llama_kv_cache_unified_iswa_state specific API
|
| 117 |
+
//
|
| 118 |
+
|
| 119 |
+
const llama_kv_cache_unified_state * get_base() const;
|
| 120 |
+
const llama_kv_cache_unified_state * get_swa() const;
|
| 121 |
+
|
| 122 |
+
private:
|
| 123 |
+
const llama_memory_status status;
|
| 124 |
+
|
| 125 |
+
//llama_kv_cache_unified_iswa * kv;
|
| 126 |
+
|
| 127 |
+
llama_sbatch sbatch;
|
| 128 |
+
|
| 129 |
+
// the index of the next ubatch to process
|
| 130 |
+
size_t i_next = 0;
|
| 131 |
+
|
| 132 |
+
std::vector<llama_ubatch> ubatches;
|
| 133 |
+
|
| 134 |
+
std::unique_ptr<llama_kv_cache_unified_state> state_base;
|
| 135 |
+
std::unique_ptr<llama_kv_cache_unified_state> state_swa;
|
| 136 |
+
};
|
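
As the implementation above shows, the iSWA cache is mostly a forwarding wrapper: every memory operation is applied to both the base and the SWA sub-cache, and boolean results are combined. A generic sketch of that delegation pattern with a hypothetical sub_cache type, not the real llama_kv_cache_unified:

// Generic sketch of the two-cache delegation pattern used by the iSWA wrapper.
#include <cstdio>
#include <memory>

struct sub_cache {
    bool seq_rm(int seq_id) { printf("seq_rm(%d)\n", seq_id); return true; }
    void clear()            { printf("clear\n"); }
};

struct iswa_like {
    std::unique_ptr<sub_cache> base = std::make_unique<sub_cache>();
    std::unique_ptr<sub_cache> swa  = std::make_unique<sub_cache>();

    void clear() {
        base->clear();
        swa ->clear();
    }

    bool seq_rm(int seq_id) {
        bool res = true;
        res = res & base->seq_rm(seq_id); // non-short-circuiting: both caches always run
        res = res & swa ->seq_rm(seq_id);
        return res;
    }
};

int main() {
    iswa_like kv;
    kv.clear();
    return kv.seq_rm(0) ? 0 : 1;
}
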
examples/talk-llama/llama-kv-cache-unified.cpp
ADDED
|
@@ -0,0 +1,1717 @@
|
| 1 |
+
#include "llama-kv-cache-unified.h"
|
| 2 |
+
|
| 3 |
+
#include "llama-impl.h"
|
| 4 |
+
#include "llama-model.h"
|
| 5 |
+
#include "llama-context.h"
|
| 6 |
+
|
| 7 |
+
#include <algorithm>
|
| 8 |
+
#include <cassert>
|
| 9 |
+
#include <cmath>
|
| 10 |
+
#include <limits>
|
| 11 |
+
#include <map>
|
| 12 |
+
#include <stdexcept>
|
| 13 |
+
|
| 14 |
+
//
|
| 15 |
+
// llama_kv_cache_unified
|
| 16 |
+
//
|
| 17 |
+
|
| 18 |
+
llama_kv_cache_unified::llama_kv_cache_unified(
|
| 19 |
+
const llama_model & model,
|
| 20 |
+
layer_filter_cb && filter,
|
| 21 |
+
ggml_type type_k,
|
| 22 |
+
ggml_type type_v,
|
| 23 |
+
bool v_trans,
|
| 24 |
+
bool offload,
|
| 25 |
+
uint32_t kv_size,
|
| 26 |
+
uint32_t n_seq_max,
|
| 27 |
+
uint32_t n_pad,
|
| 28 |
+
uint32_t n_swa,
|
| 29 |
+
llama_swa_type swa_type) :
|
| 30 |
+
model(model), hparams(model.hparams), v_trans(v_trans),
|
| 31 |
+
n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
|
| 32 |
+
|
| 33 |
+
GGML_ASSERT(kv_size % n_pad == 0);
|
| 34 |
+
|
| 35 |
+
// create a context for each buffer type
|
| 36 |
+
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
| 37 |
+
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
| 38 |
+
auto it = ctx_map.find(buft);
|
| 39 |
+
if (it == ctx_map.end()) {
|
| 40 |
+
ggml_init_params params = {
|
| 41 |
+
/*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
|
| 42 |
+
/*.mem_buffer =*/ NULL,
|
| 43 |
+
/*.no_alloc =*/ true,
|
| 44 |
+
};
|
| 45 |
+
|
| 46 |
+
ggml_context * ctx = ggml_init(params);
|
| 47 |
+
if (!ctx) {
|
| 48 |
+
return nullptr;
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
ctx_map[buft] = ctx;
|
| 52 |
+
ctxs.emplace_back(ctx);
|
| 53 |
+
|
| 54 |
+
return ctx;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
return it->second;
|
| 58 |
+
};
|
| 59 |
+
|
| 60 |
+
head = 0;
|
| 61 |
+
|
| 62 |
+
cells.resize(kv_size);
|
| 63 |
+
|
| 64 |
+
for (uint32_t il = 0; il < hparams.n_layer; il++) {
|
| 65 |
+
if (filter && !filter(il)) {
|
| 66 |
+
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
|
| 67 |
+
continue;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 71 |
+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 72 |
+
|
| 73 |
+
const char * dev_name = "CPU";
|
| 74 |
+
|
| 75 |
+
ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
|
| 76 |
+
|
| 77 |
+
if (offload) {
|
| 78 |
+
auto * dev = model.dev_layer(il);
|
| 79 |
+
buft = ggml_backend_dev_buffer_type(dev);
|
| 80 |
+
|
| 81 |
+
dev_name = ggml_backend_dev_name(dev);
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);
|
| 85 |
+
|
| 86 |
+
ggml_context * ctx = ctx_for_buft(buft);
|
| 87 |
+
if (!ctx) {
|
| 88 |
+
throw std::runtime_error("failed to create ggml context for kv cache");
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
ggml_tensor * k;
|
| 92 |
+
ggml_tensor * v;
|
| 93 |
+
|
| 94 |
+
k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
|
| 95 |
+
v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
|
| 96 |
+
|
| 97 |
+
ggml_format_name(k, "cache_k_l%d", il);
|
| 98 |
+
ggml_format_name(v, "cache_v_l%d", il);
|
| 99 |
+
|
| 100 |
+
map_layer_ids[il] = layers.size();
|
| 101 |
+
layers.push_back({ il, k, v });
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
| 105 |
+
for (auto it : ctx_map) {
|
| 106 |
+
auto * buft = it.first;
|
| 107 |
+
auto * ctx = it.second;
|
| 108 |
+
|
| 109 |
+
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
| 110 |
+
if (!buf) {
|
| 111 |
+
throw std::runtime_error("failed to allocate buffer for kv cache");
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
| 115 |
+
|
| 116 |
+
ggml_backend_buffer_clear(buf, 0);
|
| 117 |
+
bufs.emplace_back(buf);
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
{
|
| 121 |
+
const size_t memory_size_k = size_k_bytes();
|
| 122 |
+
const size_t memory_size_v = size_v_bytes();
|
| 123 |
+
|
| 124 |
+
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
| 125 |
+
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
|
| 126 |
+
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
| 127 |
+
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
| 128 |
+
}
|
| 129 |
+
}
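
ctx_for_buft in the constructor above memoizes one ggml context per buffer type, creating it on first use and reusing it afterwards. The same idiom, sketched with plain standard-library types standing in for the ggml handles:

// Standalone sketch of the memoized "one context per buffer type" idiom,
// with ints and strings standing in for ggml_backend_buffer_type_t / ggml_context*.
#include <cstdio>
#include <map>
#include <string>

int main() {
    std::map<int, std::string> ctx_map; // buffer type -> context (stand-ins)

    auto ctx_for_buft = [&](int buft) -> std::string & {
        auto it = ctx_map.find(buft);
        if (it == ctx_map.end()) {
            // first request for this buffer type: create and remember the context
            it = ctx_map.emplace(buft, "ctx_" + std::to_string(buft)).first;
        }
        return it->second;
    };

    printf("%s\n", ctx_for_buft(0).c_str()); // creates ctx_0
    printf("%s\n", ctx_for_buft(1).c_str()); // creates ctx_1
    printf("%s\n", ctx_for_buft(0).c_str()); // reuses ctx_0
    return 0;
}
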
|
| 130 |
+
|
| 131 |
+
void llama_kv_cache_unified::clear() {
|
| 132 |
+
cells.reset();
|
| 133 |
+
|
| 134 |
+
head = 0;
|
| 135 |
+
|
| 136 |
+
for (auto & buf : bufs) {
|
| 137 |
+
ggml_backend_buffer_clear(buf.get(), 0);
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
| 142 |
+
uint32_t new_head = cells.size();
|
| 143 |
+
|
| 144 |
+
if (p0 < 0) {
|
| 145 |
+
p0 = 0;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
if (p1 < 0) {
|
| 149 |
+
p1 = std::numeric_limits<llama_pos>::max();
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 153 |
+
if (!cells.pos_in(i, p0, p1)) {
|
| 154 |
+
continue;
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
|
| 158 |
+
if (new_head == cells.size()) {
|
| 159 |
+
new_head = i;
|
| 160 |
+
}
|
| 161 |
+
}
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
// If we freed up a slot, set head to it so searching can start there.
|
| 165 |
+
if (new_head != cells.size() && new_head < head) {
|
| 166 |
+
head = new_head;
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
return true;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
| 173 |
+
if (seq_id_src == seq_id_dst) {
|
| 174 |
+
return;
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
if (p0 < 0) {
|
| 178 |
+
p0 = 0;
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
if (p1 < 0) {
|
| 182 |
+
p1 = std::numeric_limits<llama_pos>::max();
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 186 |
+
if (!cells.pos_in(i, p0, p1)) {
|
| 187 |
+
continue;
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
if (cells.seq_has(i, seq_id_src)) {
|
| 191 |
+
cells.seq_add(i, seq_id_dst);
|
| 192 |
+
}
|
| 193 |
+
}
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
|
| 197 |
+
uint32_t new_head = cells.size();
|
| 198 |
+
|
| 199 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 200 |
+
if (cells.seq_keep(i, seq_id)) {
|
| 201 |
+
if (new_head == cells.size()) {
|
| 202 |
+
new_head = i;
|
| 203 |
+
}
|
| 204 |
+
}
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
// If we freed up a slot, set head to it so searching can start there.
|
| 208 |
+
if (new_head != cells.size() && new_head < head) {
|
| 209 |
+
head = new_head;
|
| 210 |
+
}
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
| 214 |
+
if (shift == 0) {
|
| 215 |
+
return;
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
uint32_t new_head = cells.size();
|
| 219 |
+
|
| 220 |
+
if (p0 < 0) {
|
| 221 |
+
p0 = 0;
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
if (p1 < 0) {
|
| 225 |
+
p1 = std::numeric_limits<llama_pos>::max();
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
// If there is no range then return early to avoid looping over all cells.
|
| 229 |
+
if (p0 == p1) {
|
| 230 |
+
return;
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 234 |
+
if (!cells.pos_in(i, p0, p1)) {
|
| 235 |
+
continue;
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
if (cells.seq_has(i, seq_id)) {
|
| 239 |
+
if (cells.pos_add(i, shift)) {
|
| 240 |
+
if (new_head == cells.size()) {
|
| 241 |
+
new_head = i;
|
| 242 |
+
}
|
| 243 |
+
}
|
| 244 |
+
}
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
// If we freed up a slot, set head to it so searching can start there.
|
| 248 |
+
// Otherwise we just start the next search from the beginning.
|
| 249 |
+
head = new_head != cells.size() ? new_head : 0;
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
| 253 |
+
if (d == 1) {
|
| 254 |
+
return;
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
if (p0 < 0) {
|
| 258 |
+
p0 = 0;
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
if (p1 < 0) {
|
| 262 |
+
p1 = std::numeric_limits<llama_pos>::max();
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
// If there is no range then return early to avoid looping over the cache.
|
| 266 |
+
if (p0 == p1) {
|
| 267 |
+
return;
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 271 |
+
if (!cells.pos_in(i, p0, p1)) {
|
| 272 |
+
continue;
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
if (cells.seq_has(i, seq_id)) {
|
| 276 |
+
cells.pos_div(i, d);
|
| 277 |
+
}
|
| 278 |
+
}
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
|
| 282 |
+
return cells.seq_pos_min(seq_id);
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
| 286 |
+
return cells.seq_pos_max(seq_id);
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
llama_memory_state_ptr llama_kv_cache_unified::init_batch(
|
| 290 |
+
const llama_batch & batch,
|
| 291 |
+
uint32_t n_ubatch,
|
| 292 |
+
bool embd_pooled,
|
| 293 |
+
bool logits_all) {
|
| 294 |
+
GGML_UNUSED(embd_pooled);
|
| 295 |
+
|
| 296 |
+
auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
|
| 297 |
+
|
| 298 |
+
std::vector<llama_ubatch> ubatches;
|
| 299 |
+
while (sbatch.n_tokens > 0) {
|
| 300 |
+
ubatches.push_back(sbatch.split_simple(n_ubatch));
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
auto heads = prepare(ubatches);
|
| 304 |
+
if (heads.empty()) {
|
| 305 |
+
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS,
|
| 309 |
+
this, std::move(sbatch), std::move(heads), std::move(ubatches));
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
llama_memory_state_ptr llama_kv_cache_unified::init_full() {
|
| 313 |
+
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
|
| 317 |
+
std::vector<uint32_t> res;
|
| 318 |
+
|
| 319 |
+
struct state {
|
| 320 |
+
uint32_t head_old; // old position of the head, before placing the ubatch
|
| 321 |
+
uint32_t head_new; // new position of the head, after placing the ubatch
|
| 322 |
+
|
| 323 |
+
llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
|
| 324 |
+
};
|
| 325 |
+
|
| 326 |
+
// remember the old state of the cells so we can restore it in the end
|
| 327 |
+
std::vector<state> states;
|
| 328 |
+
|
| 329 |
+
bool success = true;
|
| 330 |
+
|
| 331 |
+
for (const auto & ubatch : ubatches) {
|
| 332 |
+
// only find a suitable slot for the ubatch. don't modify the cells yet
|
| 333 |
+
const int32_t head_new = find_slot(ubatch);
|
| 334 |
+
if (head_new < 0) {
|
| 335 |
+
success = false;
|
| 336 |
+
break;
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
// remember the position that we found
|
| 340 |
+
res.push_back(head_new);
|
| 341 |
+
|
| 342 |
+
// store the old state of the cells in the recovery stack
|
| 343 |
+
states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
|
| 344 |
+
|
| 345 |
+
// now emplace the ubatch
|
| 346 |
+
apply_ubatch(head_new, ubatch);
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
// iterate backwards and restore the cells to their original state
|
| 350 |
+
for (auto it = states.rbegin(); it != states.rend(); ++it) {
|
| 351 |
+
cells.set(it->head_new, it->cells);
|
| 352 |
+
head = it->head_old;
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
if (!success) {
|
| 356 |
+
return {};
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
return res;
|
| 360 |
+
}
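
prepare() above places each ubatch speculatively, remembers the previous head and the overwritten cells, and then restores everything in reverse order, so the cache is left untouched whether or not all ubatches fit. A toy sketch of that save/apply/rollback shape, using a plain vector of ints in place of the KV cells:

// Toy sketch of "apply speculatively, then roll back in reverse" (illustrative only).
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> cells = {1, 2, 3, 4};

    struct saved { size_t pos; int old_value; };
    std::vector<saved> recovery;

    // speculative pass: modify cells while recording what they used to hold
    for (size_t pos : {1u, 3u}) {
        recovery.push_back({pos, cells[pos]});
        cells[pos] = 0; // stand-in for apply_ubatch()
    }

    // roll back in reverse order so earlier snapshots win
    for (auto it = recovery.rbegin(); it != recovery.rend(); ++it) {
        cells[it->pos] = it->old_value;
    }

    for (int v : cells) printf("%d ", v); // 1 2 3 4 : original state restored
    printf("\n");
    return 0;
}
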
|
| 361 |
+
|
| 362 |
+
bool llama_kv_cache_unified::update(llama_context & lctx) {
|
| 363 |
+
bool updated = false;
|
| 364 |
+
|
| 365 |
+
auto * sched = lctx.get_sched();
|
| 366 |
+
|
| 367 |
+
if (cells.get_has_shift()) {
|
| 368 |
+
if (!get_can_shift()) {
|
| 369 |
+
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
|
| 373 |
+
|
| 374 |
+
// apply K-shift if needed
|
| 375 |
+
if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
|
| 376 |
+
ggml_backend_sched_reset(sched);
|
| 377 |
+
|
| 378 |
+
auto * gf = lctx.graph_init();
|
| 379 |
+
|
| 380 |
+
auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
|
| 381 |
+
if (!res) {
|
| 382 |
+
LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
|
| 383 |
+
return updated;
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
| 387 |
+
LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
|
| 388 |
+
return updated;
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
res->set_inputs(nullptr);
|
| 392 |
+
|
| 393 |
+
if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
|
| 394 |
+
LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
|
| 395 |
+
return updated;
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
updated = true;
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
cells.reset_shift();
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
if (do_defrag) {
|
| 405 |
+
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
|
| 406 |
+
|
| 407 |
+
if (defrag_prepare(lctx.graph_max_nodes())) {
|
| 408 |
+
ggml_backend_sched_reset(sched);
|
| 409 |
+
|
| 410 |
+
auto * gf = lctx.graph_init();
|
| 411 |
+
|
| 412 |
+
auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
|
| 413 |
+
if (!res) {
|
| 414 |
+
LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
|
| 415 |
+
return updated;
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
| 419 |
+
LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
|
| 420 |
+
return updated;
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
res->set_inputs(nullptr);
|
| 424 |
+
|
| 425 |
+
if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
|
| 426 |
+
LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
|
| 427 |
+
return updated;
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
updated = true;
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
do_defrag = false;
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
return updated;
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
void llama_kv_cache_unified::defrag_sched(float thold) {
|
| 440 |
+
const auto n_kv = cells.used_max_p1();
|
| 441 |
+
|
| 442 |
+
// - do not defrag small contexts (i.e. < 2048 tokens)
|
| 443 |
+
// - count the padding towards the number of used tokens
|
| 444 |
+
const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
|
| 445 |
+
|
| 446 |
+
// queue defragmentation for next llama_kv_cache_update
|
| 447 |
+
if (fragmentation > thold) {
|
| 448 |
+
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
|
| 449 |
+
|
| 450 |
+
do_defrag = true;
|
| 451 |
+
}
|
| 452 |
+
}
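
For a concrete feel of the threshold check above, here is a worked example with made-up numbers (used_max_p1() = 4096, 3000 used cells, n_pad = 32); a cache in this state would queue a defrag for any thold below roughly 0.26:

#include <algorithm>
#include <cstdio>

int main() {
    const float n_kv  = 4096.0f; // cells.used_max_p1()
    const float used  = 3000.0f; // cells.get_used()
    const float n_pad =   32.0f;

    const float fragmentation = std::max(0.0f, 1.0f - (used + n_pad)/n_kv);
    printf("fragmentation = %.3f\n", fragmentation); // ~0.260
}
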
|
| 453 |
+
|
| 454 |
+
int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
| 455 |
+
const uint32_t n_tokens = ubatch.n_tokens;
|
| 456 |
+
|
| 457 |
+
uint32_t head_cur = this->head;
|
| 458 |
+
|
| 459 |
+
// if we have enough unused cells before the current head ->
|
| 460 |
+
// better to start searching from the beginning of the cache, hoping to fill it
|
| 461 |
+
if (head_cur > cells.get_used() + 2*ubatch.n_tokens) {
|
| 462 |
+
head_cur = 0;
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
// otherwise, one cell per token.
|
| 466 |
+
|
| 467 |
+
if (n_tokens > cells.size()) {
|
| 468 |
+
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
|
| 469 |
+
return -1;
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
//#define FIND_SLOT_DEBUG 1
|
| 473 |
+
#if FIND_SLOT_DEBUG
|
| 474 |
+
LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", cells.used_max_p1(), cells.get_used(), head, n_swa);
|
| 475 |
+
|
| 476 |
+
// for debugging
|
| 477 |
+
{
|
| 478 |
+
std::string ss;
|
| 479 |
+
if (n_swa > 0) {
|
| 480 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 481 |
+
if (cells.is_empty(i)) {
|
| 482 |
+
ss += '.';
|
| 483 |
+
} else {
|
| 484 |
+
ss += std::to_string(cells.seq_get(i));
|
| 485 |
+
}
|
| 486 |
+
if (i%256 == 255) {
|
| 487 |
+
ss += '\n';
|
| 488 |
+
}
|
| 489 |
+
}
|
| 490 |
+
}
|
| 491 |
+
LLAMA_LOG_WARN("\n%s\n", ss.c_str());
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
+
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
| 495 |
+
if (cells.seq_pos_min(s) < 0) {
|
| 496 |
+
continue;
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[%d] = %5d, max[%d] = %5d\n", n_swa, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
|
| 500 |
+
}
|
| 501 |
+
#endif
|
| 502 |
+
|
| 503 |
+
uint32_t n_tested = 0;
|
| 504 |
+
|
| 505 |
+
while (true) {
|
| 506 |
+
if (head_cur + n_tokens > cells.size()) {
|
| 507 |
+
n_tested += cells.size() - head_cur;
|
| 508 |
+
head_cur = 0;
|
| 509 |
+
continue;
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
// keep track of what the minimum sequence positions would be if we accept the ubatch
|
| 513 |
+
llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
|
| 514 |
+
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
| 515 |
+
seq_pos_min[s] = cells.seq_pos_min(s);
|
| 516 |
+
}
|
| 517 |
+
|
| 518 |
+
bool found = true;
|
| 519 |
+
for (uint32_t i = 0; i < n_tokens; i++) {
|
| 520 |
+
const llama_pos pos = ubatch.pos[i];
|
| 521 |
+
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
| 522 |
+
|
| 523 |
+
// can we use this cell? either:
|
| 524 |
+
// - the cell is empty
|
| 525 |
+
// - the cell is occupied only by one sequence:
|
| 526 |
+
// - mask causally, if the sequence is the same as the one we are inserting
|
| 527 |
+
// - mask SWA, using current max pos for that sequence in the cache
|
| 528 |
+
// always insert in the cell with minimum pos
|
| 529 |
+
bool can_use = cells.is_empty(head_cur + i);
|
| 530 |
+
|
| 531 |
+
if (!can_use && cells.seq_count(head_cur + i) == 1) {
|
| 532 |
+
const llama_pos pos_cell = cells.pos_get(head_cur + i);
|
| 533 |
+
|
| 534 |
+
// causal mask
|
| 535 |
+
if (cells.seq_has(head_cur + i, seq_id)) {
|
| 536 |
+
can_use = pos_cell >= pos;
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
if (!can_use) {
|
| 540 |
+
const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
|
| 541 |
+
|
| 542 |
+
// SWA mask
|
| 543 |
+
// note: we insert only in the cell with minimum pos in order to preserve the invariant that
|
| 544 |
+
// all positions between [pos_min, pos_max] for each sequence will be present in the cache
|
| 545 |
+
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
|
| 546 |
+
if (pos_cell == seq_pos_min[seq_id_cell] &&
|
| 547 |
+
is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
|
| 548 |
+
seq_pos_min[seq_id_cell]++;
|
| 549 |
+
can_use = true;
|
| 550 |
+
}
|
| 551 |
+
}
|
| 552 |
+
}
|
| 553 |
+
|
| 554 |
+
if (!can_use) {
|
| 555 |
+
found = false;
|
| 556 |
+
head_cur += i + 1;
|
| 557 |
+
n_tested += i + 1;
|
| 558 |
+
break;
|
| 559 |
+
}
|
| 560 |
+
}
|
| 561 |
+
|
| 562 |
+
if (found) {
|
| 563 |
+
break;
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
if (n_tested >= cells.size()) {
|
| 567 |
+
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
| 568 |
+
return -1;
|
| 569 |
+
}
|
| 570 |
+
}
|
| 571 |
+
|
| 572 |
+
return head_cur;
|
| 573 |
+
}
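
The search above is easier to see with the sequence and SWA conditions stripped away. A simplified, self-contained version that only treats empty cells as usable (the function name find_slot_simple is made up for illustration):

#include <cstdint>
#include <cstdio>
#include <vector>

// scan for n_tokens consecutive usable cells, wrapping the head back to 0 when it
// runs past the end, and give up once every cell has been tested
static int32_t find_slot_simple(const std::vector<bool> & is_empty, uint32_t head, uint32_t n_tokens) {
    const uint32_t size = (uint32_t) is_empty.size();

    if (n_tokens > size) {
        return -1;
    }

    uint32_t n_tested = 0;

    while (true) {
        if (head + n_tokens > size) {
            n_tested += size - head;
            head = 0;
            continue;
        }

        bool found = true;
        for (uint32_t i = 0; i < n_tokens; i++) {
            if (!is_empty[head + i]) {
                found = false;
                head     += i + 1;
                n_tested += i + 1;
                break;
            }
        }

        if (found) {
            return head;
        }

        if (n_tested >= size) {
            return -1;
        }
    }
}

int main() {
    //                             0      1     2      3     4     5     6      7
    const std::vector<bool> empty{false, true, false, true, true, true, false, true};
    printf("%d\n", find_slot_simple(empty, 0, 3)); // prints 3
}
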
|
| 574 |
+
|
| 575 |
+
void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
|
| 576 |
+
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
| 577 |
+
if (!cells.is_empty(head_cur + i)) {
|
| 578 |
+
cells.rm(head_cur + i);
|
| 579 |
+
}
|
| 580 |
+
|
| 581 |
+
cells.pos_set(head_cur + i, ubatch.pos[i]);
|
| 582 |
+
|
| 583 |
+
for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
|
| 584 |
+
cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
|
| 585 |
+
}
|
| 586 |
+
}
|
| 587 |
+
|
| 588 |
+
// move the head to the end of the slot
|
| 589 |
+
head = head_cur + ubatch.n_tokens;
|
| 590 |
+
}
|
| 591 |
+
|
| 592 |
+
bool llama_kv_cache_unified::get_can_shift() const {
|
| 593 |
+
return true;
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
+
uint32_t llama_kv_cache_unified::get_size() const {
|
| 597 |
+
return cells.size();
|
| 598 |
+
}
|
| 599 |
+
|
| 600 |
+
uint32_t llama_kv_cache_unified::get_n_kv() const {
|
| 601 |
+
return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
|
| 602 |
+
}
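
A quick worked example of the view size returned above, assuming GGML_PAD rounds its first argument up to a multiple of the second (the local pad() helper below stands in for it):

#include <algorithm>
#include <cstdio>

static unsigned pad(unsigned x, unsigned n) { return ((x + n - 1)/n)*n; } // stand-in for GGML_PAD

int main() {
    const unsigned size = 4096, n_pad = 32; // cells.size() and the required padding

    for (unsigned used_max_p1 : {0u, 1u, 100u, 4090u}) {
        const unsigned n_kv = std::min(size, std::max(n_pad, pad(used_max_p1, n_pad)));
        printf("used_max_p1 = %4u -> n_kv = %4u\n", used_max_p1, n_kv); // 32, 32, 128, 4096
    }
}
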
|
| 603 |
+
|
| 604 |
+
ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
|
| 605 |
+
const int32_t ikv = map_layer_ids.at(il);
|
| 606 |
+
|
| 607 |
+
auto * k = layers[ikv].k;
|
| 608 |
+
|
| 609 |
+
return ggml_view_3d(ctx, k,
|
| 610 |
+
hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
|
| 611 |
+
ggml_row_size(k->type, hparams.n_embd_head_k),
|
| 612 |
+
ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
|
| 613 |
+
0);
|
| 614 |
+
}
|
| 615 |
+
|
| 616 |
+
ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
|
| 617 |
+
const int32_t ikv = map_layer_ids.at(il);
|
| 618 |
+
|
| 619 |
+
auto * v = layers[ikv].v;
|
| 620 |
+
|
| 621 |
+
if (!v_trans) {
|
| 622 |
+
// note: v->nb[1] <= v->nb[2]
|
| 623 |
+
return ggml_view_3d(ctx, v,
|
| 624 |
+
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
|
| 625 |
+
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
|
| 626 |
+
ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
|
| 627 |
+
0);
|
| 628 |
+
}
|
| 629 |
+
|
| 630 |
+
// note: v->nb[1] > v->nb[2]
|
| 631 |
+
return ggml_view_3d(ctx, v,
|
| 632 |
+
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
|
| 633 |
+
ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
|
| 634 |
+
ggml_row_size(v->type, v->ne[1]), // v->nb[2]
|
| 635 |
+
0);
|
| 636 |
+
}
|
| 637 |
+
|
| 638 |
+
ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
|
| 639 |
+
const int32_t ikv = map_layer_ids.at(il);
|
| 640 |
+
|
| 641 |
+
auto * k = layers[ikv].k;
|
| 642 |
+
|
| 643 |
+
const int64_t n_tokens = k_cur->ne[2];
|
| 644 |
+
|
| 645 |
+
ggml_tensor * k_view = ggml_view_1d(ctx, k,
|
| 646 |
+
n_tokens*hparams.n_embd_k_gqa(il),
|
| 647 |
+
ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
|
| 648 |
+
|
| 649 |
+
return ggml_cpy(ctx, k_cur, k_view);
|
| 650 |
+
}
|
| 651 |
+
|
| 652 |
+
ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
|
| 653 |
+
const int32_t ikv = map_layer_ids.at(il);
|
| 654 |
+
|
| 655 |
+
auto * v = layers[ikv].v;
|
| 656 |
+
|
| 657 |
+
const int64_t n_tokens = v_cur->ne[2];
|
| 658 |
+
|
| 659 |
+
v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
|
| 660 |
+
|
| 661 |
+
ggml_tensor * v_view = nullptr;
|
| 662 |
+
|
| 663 |
+
if (!v_trans) {
|
| 664 |
+
v_view = ggml_view_1d(ctx, v,
|
| 665 |
+
n_tokens*hparams.n_embd_v_gqa(il),
|
| 666 |
+
ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
|
| 667 |
+
} else {
|
| 668 |
+
// note: the V cache is transposed when not using flash attention
|
| 669 |
+
v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
|
| 670 |
+
(v->ne[1])*ggml_element_size(v),
|
| 671 |
+
(head_cur)*ggml_element_size(v));
|
| 672 |
+
|
| 673 |
+
v_cur = ggml_transpose(ctx, v_cur);
|
| 674 |
+
}
|
| 675 |
+
|
| 676 |
+
return ggml_cpy(ctx, v_cur, v_view);
|
| 677 |
+
}
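
To make the transposed branch above concrete: the row stride of the destination view is v->ne[1] elements, which the strides in get_v() and build_graph_defrag() suggest is the number of KV cells, so component j of the value vector written for cell c lands at element offset j*kv_size + c. A small standalone illustration with made-up sizes:

#include <cstdio>

int main() {
    const int kv_size      = 8; // assumed to be v->ne[1] in the transposed layout
    const int n_embd_v_gqa = 4;
    const int head_cur     = 3;
    const int n_tokens     = 2;

    for (int j = 0; j < n_embd_v_gqa; ++j) {
        for (int t = 0; t < n_tokens; ++t) {
            const int cell   = head_cur + t;
            const int offset = j*kv_size + cell; // element offset targeted by the 2D view
            printf("v[cell %d][dim %d] -> element %2d\n", cell, j, offset);
        }
    }
}
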
|
| 678 |
+
|
| 679 |
+
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
| 680 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 681 |
+
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
| 682 |
+
const int64_t n_seqs = ubatch->n_seqs;
|
| 683 |
+
|
| 684 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
| 685 |
+
float * data = (float *) dst->data;
|
| 686 |
+
|
| 687 |
+
const auto n_kv = dst->ne[0];
|
| 688 |
+
|
| 689 |
+
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
| 690 |
+
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
| 691 |
+
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
|
| 692 |
+
// Causal mask:
|
| 693 |
+
// xxx-------
|
| 694 |
+
// xxxx------
|
| 695 |
+
// xxxxx-----
|
| 696 |
+
// Non-causal mask:
|
| 697 |
+
// xxxxx-----
|
| 698 |
+
// xxxxx-----
|
| 699 |
+
// xxxxx-----
|
| 700 |
+
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
| 701 |
+
for (int h = 0; h < 1; ++h) {
|
| 702 |
+
for (int s = 0; s < n_seqs; ++s) {
|
| 703 |
+
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
| 704 |
+
|
| 705 |
+
for (int j = 0; j < n_seq_tokens; ++j) {
|
| 706 |
+
const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
|
| 707 |
+
|
| 708 |
+
for (uint32_t i = 0; i < n_kv; ++i) {
|
| 709 |
+
float f = 0.0f;
|
| 710 |
+
|
| 711 |
+
bool masked = false;
|
| 712 |
+
|
| 713 |
+
if (cells.is_empty(i)) {
|
| 714 |
+
masked = true;
|
| 715 |
+
} else {
|
| 716 |
+
const llama_pos p0 = cells.pos_get(i);
|
| 717 |
+
|
| 718 |
+
// mask the token if not the same sequence
|
| 719 |
+
masked = masked || (!cells.seq_has(i, seq_id));
|
| 720 |
+
|
| 721 |
+
// mask future tokens
|
| 722 |
+
masked = masked || (causal_attn && p0 > p1);
|
| 723 |
+
|
| 724 |
+
// apply SWA if any
|
| 725 |
+
masked = masked || (is_masked_swa(p0, p1));
|
| 726 |
+
|
| 727 |
+
if (!masked && hparams.use_alibi) {
|
| 728 |
+
f = -std::abs(p0 - p1);
|
| 729 |
+
}
|
| 730 |
+
}
|
| 731 |
+
|
| 732 |
+
if (masked) {
|
| 733 |
+
f = -INFINITY;
|
| 734 |
+
}
|
| 735 |
+
|
| 736 |
+
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
| 737 |
+
}
|
| 738 |
+
}
|
| 739 |
+
}
|
| 740 |
+
|
| 741 |
+
// mask padded tokens
|
| 742 |
+
if (data) {
|
| 743 |
+
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
| 744 |
+
for (uint32_t j = 0; j < n_kv; ++j) {
|
| 745 |
+
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
| 746 |
+
}
|
| 747 |
+
}
|
| 748 |
+
}
|
| 749 |
+
}
|
| 750 |
+
}
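
The picture from the comment above (2 tokens already in the cache, 3 batch tokens of the same sequence, no SWA, no ALiBi) can be reproduced with a few lines; note that by the time this input is set, apply_ubatch() has already placed the batch, so the first 5 cells are occupied:

#include <cstdio>

int main() {
    const int n_kv     = 10;
    const int pos[10]  = {0, 1, 2, 3, 4, -1, -1, -1, -1, -1}; // cell positions, -1 = empty
    const int batch[3] = {2, 3, 4};                           // positions of the new tokens

    for (int j = 0; j < 3; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            const bool masked = pos[i] < 0 || pos[i] > batch[j]; // empty cell or future token
            putchar(masked ? '-' : 'x');
        }
        putchar('\n');
    }
    // prints:
    // xxx-------
    // xxxx------
    // xxxxx-----
}
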
|
| 751 |
+
|
| 752 |
+
void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
|
| 753 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
| 754 |
+
|
| 755 |
+
int32_t * data = (int32_t *) dst->data;
|
| 756 |
+
|
| 757 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 758 |
+
data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
|
| 759 |
+
}
|
| 760 |
+
}
|
| 761 |
+
|
| 762 |
+
void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
| 763 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 764 |
+
|
| 765 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
| 766 |
+
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
|
| 767 |
+
|
| 768 |
+
int32_t * data = (int32_t *) dst->data;
|
| 769 |
+
|
| 770 |
+
const int32_t n_kv = dst->ne[0];
|
| 771 |
+
|
| 772 |
+
for (int h = 0; h < 1; ++h) {
|
| 773 |
+
for (int j = 0; j < n_tokens; ++j) {
|
| 774 |
+
for (int i = 0; i < n_kv; ++i) {
|
| 775 |
+
// the position when the cell is empty is irrelevant - it will be masked out later in the attention
|
| 776 |
+
const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
|
| 777 |
+
|
| 778 |
+
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
|
| 779 |
+
}
|
| 780 |
+
}
|
| 781 |
+
}
|
| 782 |
+
}
|
| 783 |
+
|
| 784 |
+
size_t llama_kv_cache_unified::total_size() const {
|
| 785 |
+
size_t size = 0;
|
| 786 |
+
|
| 787 |
+
for (const auto & buf : bufs) {
|
| 788 |
+
size += ggml_backend_buffer_get_size(buf.get());
|
| 789 |
+
}
|
| 790 |
+
|
| 791 |
+
return size;
|
| 792 |
+
}
|
| 793 |
+
|
| 794 |
+
size_t llama_kv_cache_unified::size_k_bytes() const {
|
| 795 |
+
size_t size_k_bytes = 0;
|
| 796 |
+
|
| 797 |
+
for (const auto & layer : layers) {
|
| 798 |
+
size_k_bytes += ggml_nbytes(layer.k);
|
| 799 |
+
}
|
| 800 |
+
|
| 801 |
+
return size_k_bytes;
|
| 802 |
+
}
|
| 803 |
+
|
| 804 |
+
size_t llama_kv_cache_unified::size_v_bytes() const {
|
| 805 |
+
size_t size_v_bytes = 0;
|
| 806 |
+
|
| 807 |
+
for (const auto & layer : layers) {
|
| 808 |
+
size_v_bytes += ggml_nbytes(layer.v);
|
| 809 |
+
}
|
| 810 |
+
|
| 811 |
+
return size_v_bytes;
|
| 812 |
+
}
|
| 813 |
+
|
| 814 |
+
ggml_tensor * llama_kv_cache_unified::build_rope_shift(
|
| 815 |
+
const llama_cparams & cparams,
|
| 816 |
+
ggml_context * ctx,
|
| 817 |
+
ggml_tensor * cur,
|
| 818 |
+
ggml_tensor * shift,
|
| 819 |
+
ggml_tensor * factors,
|
| 820 |
+
float freq_base,
|
| 821 |
+
float freq_scale) const {
|
| 822 |
+
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
|
| 823 |
+
|
| 824 |
+
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
|
| 825 |
+
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
|
| 826 |
+
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
|
| 827 |
+
|
| 828 |
+
const auto & n_rot = hparams.n_rot;
|
| 829 |
+
const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
|
| 830 |
+
// @ngxson : this is a workaround
|
| 831 |
+
// for M-RoPE, we want to rotate the whole vector when doing KV shift
|
| 832 |
+
// a normal RoPE should work, we just need to use the correct ordering
|
| 833 |
+
// ref: https://github.com/ggml-org/llama.cpp/pull/13870
|
| 834 |
+
? LLAMA_ROPE_TYPE_NEOX
|
| 835 |
+
: hparams.rope_type;
|
| 836 |
+
|
| 837 |
+
// See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
|
| 838 |
+
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
|
| 839 |
+
const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
|
| 840 |
+
? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
|
| 841 |
+
: cparams.yarn_attn_factor;
|
| 842 |
+
|
| 843 |
+
ggml_tensor * tmp;
|
| 844 |
+
|
| 845 |
+
if (ggml_is_quantized(cur->type)) {
|
| 846 |
+
// dequantize to f32 -> RoPE -> quantize back
|
| 847 |
+
tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
|
| 848 |
+
|
| 849 |
+
tmp = ggml_rope_ext(ctx, tmp,
|
| 850 |
+
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
| 851 |
+
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
|
| 852 |
+
|
| 853 |
+
tmp = ggml_cpy(ctx, tmp, cur);
|
| 854 |
+
} else {
|
| 855 |
+
// we rotate only the first n_rot dimensions
|
| 856 |
+
tmp = ggml_rope_ext_inplace(ctx, cur,
|
| 857 |
+
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
| 858 |
+
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
|
| 859 |
+
}
|
| 860 |
+
|
| 861 |
+
return tmp;
|
| 862 |
+
}
|
| 863 |
+
|
| 864 |
+
class llm_graph_input_k_shift : public llm_graph_input_i {
|
| 865 |
+
public:
|
| 866 |
+
llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
|
| 867 |
+
virtual ~llm_graph_input_k_shift() = default;
|
| 868 |
+
|
| 869 |
+
void set_input(const llama_ubatch * ubatch) override;
|
| 870 |
+
|
| 871 |
+
ggml_tensor * k_shift; // I32 [kv_size]
|
| 872 |
+
|
| 873 |
+
const llama_kv_cache_unified * kv_self;
|
| 874 |
+
};
|
| 875 |
+
|
| 876 |
+
void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
|
| 877 |
+
GGML_UNUSED(ubatch);
|
| 878 |
+
|
| 879 |
+
if (k_shift) {
|
| 880 |
+
kv_self->set_input_k_shift(k_shift);
|
| 881 |
+
}
|
| 882 |
+
}
|
| 883 |
+
|
| 884 |
+
llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
|
| 885 |
+
const llama_cparams & cparams,
|
| 886 |
+
ggml_context * ctx,
|
| 887 |
+
ggml_cgraph * gf) const {
|
| 888 |
+
auto res = std::make_unique<llm_graph_result>();
|
| 889 |
+
|
| 890 |
+
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
| 891 |
+
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
| 892 |
+
|
| 893 |
+
//GGML_ASSERT(kv_self->size == n_ctx);
|
| 894 |
+
|
| 895 |
+
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
|
| 896 |
+
|
| 897 |
+
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
|
| 898 |
+
ggml_set_input(inp->k_shift);
|
| 899 |
+
|
| 900 |
+
for (const auto & layer : layers) {
|
| 901 |
+
const uint32_t il = layer.il;
|
| 902 |
+
|
| 903 |
+
const int64_t n_head_kv = hparams.n_head_kv(il);
|
| 904 |
+
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
| 905 |
+
|
| 906 |
+
const float freq_base_l = model.get_rope_freq_base (cparams, il);
|
| 907 |
+
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
|
| 908 |
+
|
| 909 |
+
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
|
| 910 |
+
|
| 911 |
+
ggml_tensor * k =
|
| 912 |
+
ggml_view_3d(ctx, layer.k,
|
| 913 |
+
n_embd_head_k, n_head_kv, cells.size(),
|
| 914 |
+
ggml_row_size(layer.k->type, n_embd_head_k),
|
| 915 |
+
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
| 916 |
+
0);
|
| 917 |
+
|
| 918 |
+
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
|
| 919 |
+
|
| 920 |
+
ggml_build_forward_expand(gf, cur);
|
| 921 |
+
}
|
| 922 |
+
|
| 923 |
+
res->add_input(std::move(inp));
|
| 924 |
+
|
| 925 |
+
return res;
|
| 926 |
+
}
|
| 927 |
+
|
| 928 |
+
llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
| 929 |
+
const llama_cparams & cparams,
|
| 930 |
+
ggml_context * ctx,
|
| 931 |
+
ggml_cgraph * gf) const {
|
| 932 |
+
auto res = std::make_unique<llm_graph_result>();
|
| 933 |
+
|
| 934 |
+
const auto & ids = defrag_info.ids;
|
| 935 |
+
|
| 936 |
+
#if 0
|
| 937 |
+
// CPU defrag
|
| 938 |
+
//
|
| 939 |
+
// TODO: optimizations are possible:
|
| 940 |
+
// - multiple threads
|
| 941 |
+
// - avoid copying to the host memory when already there
|
| 942 |
+
//
|
| 943 |
+
// likely not worth the effort, as we have ggml_graph based defrag
|
| 944 |
+
//
|
| 945 |
+
|
| 946 |
+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
| 947 |
+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
| 948 |
+
|
| 949 |
+
const uint32_t kv_size = size;
|
| 950 |
+
|
| 951 |
+
std::vector<uint8_t> buf_k;
|
| 952 |
+
std::vector<uint8_t> buf_v;
|
| 953 |
+
|
| 954 |
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 955 |
+
const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
|
| 956 |
+
const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
|
| 957 |
+
|
| 958 |
+
const size_t v_size_el = ggml_type_size(v_l[il]->type);
|
| 959 |
+
const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
|
| 960 |
+
|
| 961 |
+
buf_k.resize(k_size);
|
| 962 |
+
buf_v.resize(v_size);
|
| 963 |
+
|
| 964 |
+
ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
|
| 965 |
+
ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
|
| 966 |
+
|
| 967 |
+
// batch move [i, i+nm) to [id, id+nm)
|
| 968 |
+
// note: cells can move only to a lower index
|
| 969 |
+
for (uint32_t i = 0; i < n_kv; ++i) {
|
| 970 |
+
const uint32_t id = ids[i];
|
| 971 |
+
|
| 972 |
+
if (i == id || id == n_kv) {
|
| 973 |
+
continue;
|
| 974 |
+
}
|
| 975 |
+
|
| 976 |
+
uint32_t nm = 1;
|
| 977 |
+
|
| 978 |
+
while (i + nm < n_kv && ids[i + nm] == id + nm) {
|
| 979 |
+
nm++;
|
| 980 |
+
}
|
| 981 |
+
|
| 982 |
+
// move keys
|
| 983 |
+
{
|
| 984 |
+
const int64_t os = i*k_size_row;
|
| 985 |
+
const int64_t od = id*k_size_row;
|
| 986 |
+
|
| 987 |
+
memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
|
| 988 |
+
}
|
| 989 |
+
|
| 990 |
+
// move values (note: they are transposed)
|
| 991 |
+
{
|
| 992 |
+
const int64_t os = i;
|
| 993 |
+
const int64_t od = id;
|
| 994 |
+
|
| 995 |
+
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
| 996 |
+
memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
|
| 997 |
+
}
|
| 998 |
+
}
|
| 999 |
+
|
| 1000 |
+
i += nm - 1;
|
| 1001 |
+
}
|
| 1002 |
+
|
| 1003 |
+
ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
|
| 1004 |
+
ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
|
| 1005 |
+
}
|
| 1006 |
+
#else
|
| 1007 |
+
for (uint32_t i = 0; i < ids.size(); ++i) {
|
| 1008 |
+
const uint32_t id = ids[i];
|
| 1009 |
+
|
| 1010 |
+
if (i == id || id == ids.size()) {
|
| 1011 |
+
continue;
|
| 1012 |
+
}
|
| 1013 |
+
|
| 1014 |
+
uint32_t nm = 1;
|
| 1015 |
+
|
| 1016 |
+
while (i + nm < ids.size() && ids[i + nm] == id + nm) {
|
| 1017 |
+
nm++;
|
| 1018 |
+
}
|
| 1019 |
+
|
| 1020 |
+
for (const auto & layer : layers) {
|
| 1021 |
+
const uint32_t il = layer.il;
|
| 1022 |
+
|
| 1023 |
+
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
| 1024 |
+
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
| 1025 |
+
|
| 1026 |
+
ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k,
|
| 1027 |
+
n_embd_k_gqa, nm,
|
| 1028 |
+
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
| 1029 |
+
ggml_row_size(layer.k->type, n_embd_k_gqa*i));
|
| 1030 |
+
|
| 1031 |
+
ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k,
|
| 1032 |
+
n_embd_k_gqa, nm,
|
| 1033 |
+
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
| 1034 |
+
ggml_row_size(layer.k->type, n_embd_k_gqa*id));
|
| 1035 |
+
|
| 1036 |
+
ggml_tensor * view_v_src;
|
| 1037 |
+
ggml_tensor * view_v_dst;
|
| 1038 |
+
|
| 1039 |
+
if (cparams.flash_attn) {
|
| 1040 |
+
// NOTE: the V cache is not transposed when using flash attention
|
| 1041 |
+
view_v_src = ggml_view_2d(ctx, layer.v,
|
| 1042 |
+
n_embd_v_gqa, nm,
|
| 1043 |
+
ggml_row_size(layer.v->type, n_embd_v_gqa),
|
| 1044 |
+
ggml_row_size(layer.v->type, n_embd_v_gqa*i));
|
| 1045 |
+
|
| 1046 |
+
view_v_dst = ggml_view_2d(ctx, layer.v,
|
| 1047 |
+
n_embd_v_gqa, nm,
|
| 1048 |
+
ggml_row_size(layer.v->type, n_embd_v_gqa),
|
| 1049 |
+
ggml_row_size(layer.v->type, n_embd_v_gqa*id));
|
| 1050 |
+
} else {
|
| 1051 |
+
view_v_src = ggml_view_2d(ctx, layer.v,
|
| 1052 |
+
nm, n_embd_v_gqa,
|
| 1053 |
+
ggml_row_size(layer.v->type, cells.size()),
|
| 1054 |
+
ggml_row_size(layer.v->type, i));
|
| 1055 |
+
|
| 1056 |
+
view_v_dst = ggml_view_2d(ctx, layer.v,
|
| 1057 |
+
nm, n_embd_v_gqa,
|
| 1058 |
+
ggml_row_size(layer.v->type, cells.size()),
|
| 1059 |
+
ggml_row_size(layer.v->type, id));
|
| 1060 |
+
}
|
| 1061 |
+
|
| 1062 |
+
ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
|
| 1063 |
+
ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
|
| 1064 |
+
}
|
| 1065 |
+
|
| 1066 |
+
i += nm - 1;
|
| 1067 |
+
}
|
| 1068 |
+
|
| 1069 |
+
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
|
| 1070 |
+
#endif
|
| 1071 |
+
|
| 1072 |
+
return res;
|
| 1073 |
+
}
|
| 1074 |
+
|
| 1075 |
+
bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
| 1076 |
+
const uint32_t n_layer = layers.size();
|
| 1077 |
+
|
| 1078 |
+
const uint32_t n_kv = cells.used_max_p1();
|
| 1079 |
+
const uint32_t n_used = cells.get_used();
|
| 1080 |
+
|
| 1081 |
+
assert(n_used <= n_kv);
|
| 1082 |
+
|
| 1083 |
+
//const int64_t t_start = ggml_time_us();
|
| 1084 |
+
|
| 1085 |
+
// number of cells moved
|
| 1086 |
+
uint32_t n_moves = 0;
|
| 1087 |
+
|
| 1088 |
+
// each move requires 6*n_layer tensors (see build_graph_defrag)
|
| 1089 |
+
// - source view, destination view, copy operation
|
| 1090 |
+
// - x2 for keys and values
|
| 1091 |
+
//const uint32_t max_moves = max_nodes()/(6*n_layer);
|
| 1092 |
+
// TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
|
| 1093 |
+
const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
|
| 1094 |
+
|
| 1095 |
+
// determine which KV cells to move where
|
| 1096 |
+
//
|
| 1097 |
+
// cell i moves to ids[i]
|
| 1098 |
+
//
|
| 1099 |
+
// if ids[i] == i || ids[i] == n_kv, then cell i is not moved
|
| 1100 |
+
//
|
| 1101 |
+
auto & ids = defrag_info.ids;
|
| 1102 |
+
|
| 1103 |
+
ids.clear();
|
| 1104 |
+
ids.resize(n_kv, n_kv);
|
| 1105 |
+
|
| 1106 |
+
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
| 1107 |
+
if (!cells.is_empty(i0)) {
|
| 1108 |
+
ids[i0] = i0;
|
| 1109 |
+
|
| 1110 |
+
continue;
|
| 1111 |
+
}
|
| 1112 |
+
|
| 1113 |
+
// found a hole - fill it with data from the end of the cache
|
| 1114 |
+
|
| 1115 |
+
uint32_t nh = 1;
|
| 1116 |
+
|
| 1117 |
+
// determine the size of the hole
|
| 1118 |
+
while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
|
| 1119 |
+
nh++;
|
| 1120 |
+
}
|
| 1121 |
+
|
| 1122 |
+
uint32_t nf = 0;
|
| 1123 |
+
uint32_t is = n_kv - 1;
|
| 1124 |
+
|
| 1125 |
+
// starting from the end, find nh non-empty cells
|
| 1126 |
+
for (; is > i0; --is) {
|
| 1127 |
+
if (cells.is_empty(is) || ids[is] != n_kv) {
|
| 1128 |
+
continue;
|
| 1129 |
+
}
|
| 1130 |
+
|
| 1131 |
+
// non-empty cell which is not yet moved
|
| 1132 |
+
nf++;
|
| 1133 |
+
|
| 1134 |
+
if (nf == nh) {
|
| 1135 |
+
break;
|
| 1136 |
+
}
|
| 1137 |
+
}
|
| 1138 |
+
|
| 1139 |
+
// this can only happen if `n_used` is not accurate, which would be a bug
|
| 1140 |
+
GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
|
| 1141 |
+
|
| 1142 |
+
nf = 0;
|
| 1143 |
+
|
| 1144 |
+
uint32_t i1 = is;
|
| 1145 |
+
|
| 1146 |
+
// are we moving a contiguous block of memory?
|
| 1147 |
+
bool cont = false;
|
| 1148 |
+
|
| 1149 |
+
// should we stop searching for the next move?
|
| 1150 |
+
bool stop = false;
|
| 1151 |
+
|
| 1152 |
+
// go back and move the nf cells to the hole
|
| 1153 |
+
for (; i1 < n_kv; ++i1) {
|
| 1154 |
+
if (cells.is_empty(i1) || ids[i1] != n_kv) {
|
| 1155 |
+
if (n_moves == max_moves) {
|
| 1156 |
+
stop = true;
|
| 1157 |
+
break;
|
| 1158 |
+
}
|
| 1159 |
+
|
| 1160 |
+
cont = false;
|
| 1161 |
+
continue;
|
| 1162 |
+
}
|
| 1163 |
+
|
| 1164 |
+
// this cell goes to (i0 + nf)
|
| 1165 |
+
ids[i1] = i0 + nf;
|
| 1166 |
+
|
| 1167 |
+
// move the cell meta data
|
| 1168 |
+
cells.mv(i1, i0 + nf);
|
| 1169 |
+
|
| 1170 |
+
head = n_used;
|
| 1171 |
+
|
| 1172 |
+
if (!cont) {
|
| 1173 |
+
n_moves++;
|
| 1174 |
+
cont = true;
|
| 1175 |
+
}
|
| 1176 |
+
|
| 1177 |
+
nf++;
|
| 1178 |
+
|
| 1179 |
+
if (nf == nh) {
|
| 1180 |
+
break;
|
| 1181 |
+
}
|
| 1182 |
+
}
|
| 1183 |
+
|
| 1184 |
+
if (stop || n_moves == max_moves) {
|
| 1185 |
+
break;
|
| 1186 |
+
}
|
| 1187 |
+
|
| 1188 |
+
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
|
| 1189 |
+
|
| 1190 |
+
i0 += nh - 1;
|
| 1191 |
+
}
|
| 1192 |
+
|
| 1193 |
+
if (n_moves == 0) {
|
| 1194 |
+
return false;
|
| 1195 |
+
}
|
| 1196 |
+
|
| 1197 |
+
LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
|
| 1198 |
+
|
| 1199 |
+
LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
|
| 1200 |
+
|
| 1201 |
+
return true;
|
| 1202 |
+
}
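
A worked example of the move plan this produces: with n_kv = 6 spanned cells where cells 1 and 2 are empty, the two trailing used cells are pulled into the hole (cell 4 -> 1, cell 5 -> 2), and every other entry of ids stays at i or n_kv, meaning "not moved" (the occupancy pattern is made up for illustration):

#include <cstdio>
#include <vector>

int main() {
    const int n_kv = 6;

    // occupancy after some deletions: [x, ., ., x, x, x]
    // ids[] as computed by the loop above (ids[i] == i or n_kv means the cell stays put):
    const std::vector<int> ids = {0, 6, 6, 3, 1, 2};

    for (int i = 0; i < n_kv; ++i) {
        if (ids[i] != i && ids[i] != n_kv) {
            printf("move cell %d -> cell %d\n", i, ids[i]); // cell 4 -> 1, cell 5 -> 2
        }
    }
}
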
|
| 1203 |
+
|
| 1204 |
+
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
| 1205 |
+
assert(p0 >= 0 && p1 >= 0);
|
| 1206 |
+
|
| 1207 |
+
switch (swa_type) {
|
| 1208 |
+
case LLAMA_SWA_TYPE_NONE:
|
| 1209 |
+
{
|
| 1210 |
+
} break;
|
| 1211 |
+
case LLAMA_SWA_TYPE_STANDARD:
|
| 1212 |
+
{
|
| 1213 |
+
if (p1 - p0 >= (int32_t) n_swa) {
|
| 1214 |
+
return true;
|
| 1215 |
+
}
|
| 1216 |
+
} break;
|
| 1217 |
+
case LLAMA_SWA_TYPE_CHUNKED:
|
| 1218 |
+
{
|
| 1219 |
+
const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
|
| 1220 |
+
|
| 1221 |
+
if (p0 < pos_chunk_start) {
|
| 1222 |
+
return true;
|
| 1223 |
+
}
|
| 1224 |
+
} break;
|
| 1225 |
+
}
|
| 1226 |
+
|
| 1227 |
+
return false;
|
| 1228 |
+
}
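
Two worked cases of the predicate above with an illustrative window of n_swa = 4 (p0 is the cached position, p1 the query position): under LLAMA_SWA_TYPE_STANDARD a token attends to itself and the previous n_swa - 1 positions, while LLAMA_SWA_TYPE_CHUNKED only lets it attend within its own chunk of n_swa positions.

#include <cstdio>

static bool masked_standard(int p0, int p1, int n_swa) { return p1 - p0 >= n_swa; }
static bool masked_chunked (int p0, int p1, int n_swa) { return p0 < (p1 / n_swa) * n_swa; }

int main() {
    printf("standard: %d %d\n", masked_standard(6, 9, 4), masked_standard(5, 9, 4)); // 0 1
    printf("chunked:  %d %d\n", masked_chunked (8, 9, 4), masked_chunked (7, 9, 4)); // 0 1
}
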
|
| 1229 |
+
|
| 1230 |
+
void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
| 1231 |
+
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
| 1232 |
+
uint32_t cell_count = 0;
|
| 1233 |
+
|
| 1234 |
+
// Count the number of cells with the specified seq_id
|
| 1235 |
+
// Find all the ranges of cells with this seq id (or all, when -1)
|
| 1236 |
+
uint32_t cell_range_begin = cells.size();
|
| 1237 |
+
|
| 1238 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 1239 |
+
if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
|
| 1240 |
+
++cell_count;
|
| 1241 |
+
if (cell_range_begin == cells.size()) {
|
| 1242 |
+
cell_range_begin = i;
|
| 1243 |
+
}
|
| 1244 |
+
} else {
|
| 1245 |
+
if (cell_range_begin != cells.size()) {
|
| 1246 |
+
cell_ranges.emplace_back(cell_range_begin, i);
|
| 1247 |
+
cell_range_begin = cells.size();
|
| 1248 |
+
}
|
| 1249 |
+
}
|
| 1250 |
+
}
|
| 1251 |
+
|
| 1252 |
+
if (cell_range_begin != cells.size()) {
|
| 1253 |
+
cell_ranges.emplace_back(cell_range_begin, cells.size());
|
| 1254 |
+
}
|
| 1255 |
+
|
| 1256 |
+
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
| 1257 |
+
uint32_t cell_count_check = 0;
|
| 1258 |
+
for (const auto & range : cell_ranges) {
|
| 1259 |
+
cell_count_check += range.second - range.first;
|
| 1260 |
+
}
|
| 1261 |
+
GGML_ASSERT(cell_count == cell_count_check);
|
| 1262 |
+
|
| 1263 |
+
io.write(&cell_count, sizeof(cell_count));
|
| 1264 |
+
|
| 1265 |
+
state_write_meta(io, cell_ranges, seq_id);
|
| 1266 |
+
state_write_data(io, cell_ranges);
|
| 1267 |
+
}
|
| 1268 |
+
|
| 1269 |
+
void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
| 1270 |
+
uint32_t cell_count;
|
| 1271 |
+
io.read_to(&cell_count, sizeof(cell_count));
|
| 1272 |
+
|
| 1273 |
+
bool res = true;
|
| 1274 |
+
res = res && state_read_meta(io, cell_count, seq_id);
|
| 1275 |
+
res = res && state_read_data(io, cell_count);
|
| 1276 |
+
|
| 1277 |
+
if (!res) {
|
| 1278 |
+
if (seq_id == -1) {
|
| 1279 |
+
clear();
|
| 1280 |
+
} else {
|
| 1281 |
+
seq_rm(seq_id, -1, -1);
|
| 1282 |
+
}
|
| 1283 |
+
throw std::runtime_error("failed to restore kv cache");
|
| 1284 |
+
}
|
| 1285 |
+
}
|
| 1286 |
+
|
| 1287 |
+
void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
|
| 1288 |
+
for (const auto & range : cell_ranges) {
|
| 1289 |
+
for (uint32_t i = range.first; i < range.second; ++i) {
|
| 1290 |
+
std::vector<llama_seq_id> seq_ids;
|
| 1291 |
+
|
| 1292 |
+
for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
|
| 1293 |
+
if (cur == seq_id || seq_id == -1) {
|
| 1294 |
+
if (cells.seq_has(i, cur)) {
|
| 1295 |
+
seq_ids.push_back(cur);
|
| 1296 |
+
}
|
| 1297 |
+
}
|
| 1298 |
+
}
|
| 1299 |
+
|
| 1300 |
+
const llama_pos pos = cells.pos_get(i);
|
| 1301 |
+
const uint32_t n_seq_id = seq_ids.size();
|
| 1302 |
+
|
| 1303 |
+
io.write(&pos, sizeof(pos));
|
| 1304 |
+
io.write(&n_seq_id, sizeof(n_seq_id));
|
| 1305 |
+
|
| 1306 |
+
for (const auto & seq_id : seq_ids) {
|
| 1307 |
+
io.write(&seq_id, sizeof(seq_id));
|
| 1308 |
+
}
|
| 1309 |
+
}
|
| 1310 |
+
}
|
| 1311 |
+
}
|
| 1312 |
+
|
| 1313 |
+
void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
|
| 1314 |
+
const uint32_t v_trans = this->v_trans ? 1 : 0;
|
| 1315 |
+
const uint32_t n_layer = layers.size();
|
| 1316 |
+
|
| 1317 |
+
io.write(&v_trans, sizeof(v_trans));
|
| 1318 |
+
io.write(&n_layer, sizeof(n_layer));
|
| 1319 |
+
|
| 1320 |
+
std::vector<uint8_t> tmp_buf;
|
| 1321 |
+
|
| 1322 |
+
// Iterate and write all the keys first, each row is a cell
|
| 1323 |
+
// Get whole range at a time
|
| 1324 |
+
for (const auto & layer : layers) {
|
| 1325 |
+
const uint32_t il = layer.il;
|
| 1326 |
+
|
| 1327 |
+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 1328 |
+
|
| 1329 |
+
// Write key type
|
| 1330 |
+
const int32_t k_type_i = (int32_t)layer.k->type;
|
| 1331 |
+
io.write(&k_type_i, sizeof(k_type_i));
|
| 1332 |
+
|
| 1333 |
+
// Write row size of key
|
| 1334 |
+
const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
|
| 1335 |
+
io.write(&k_size_row, sizeof(k_size_row));
|
| 1336 |
+
|
| 1337 |
+
// Read each range of cells of k_size length each into tmp_buf and write out
|
| 1338 |
+
for (const auto & range : cell_ranges) {
|
| 1339 |
+
const size_t range_size = range.second - range.first;
|
| 1340 |
+
const size_t buf_size = range_size * k_size_row;
|
| 1341 |
+
io.write_tensor(layer.k, range.first * k_size_row, buf_size);
|
| 1342 |
+
}
|
| 1343 |
+
}
|
| 1344 |
+
|
| 1345 |
+
if (!v_trans) {
|
| 1346 |
+
for (const auto & layer : layers) {
|
| 1347 |
+
const uint32_t il = layer.il;
|
| 1348 |
+
|
| 1349 |
+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1350 |
+
|
| 1351 |
+
// Write value type
|
| 1352 |
+
const int32_t v_type_i = (int32_t)layer.v->type;
|
| 1353 |
+
io.write(&v_type_i, sizeof(v_type_i));
|
| 1354 |
+
|
| 1355 |
+
// Write row size of value
|
| 1356 |
+
const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
|
| 1357 |
+
io.write(&v_size_row, sizeof(v_size_row));
|
| 1358 |
+
|
| 1359 |
+
// Read each range of cells of v_size length each into tmp_buf and write out
|
| 1360 |
+
for (const auto & range : cell_ranges) {
|
| 1361 |
+
const size_t range_size = range.second - range.first;
|
| 1362 |
+
const size_t buf_size = range_size * v_size_row;
|
| 1363 |
+
io.write_tensor(layer.v, range.first * v_size_row, buf_size);
|
| 1364 |
+
}
|
| 1365 |
+
}
|
| 1366 |
+
} else {
|
| 1367 |
+
// When v is transposed, we also need the element size and get the element ranges from each row
|
| 1368 |
+
const uint32_t kv_size = cells.size();
|
| 1369 |
+
|
| 1370 |
+
for (const auto & layer : layers) {
|
| 1371 |
+
const uint32_t il = layer.il;
|
| 1372 |
+
|
| 1373 |
+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1374 |
+
|
| 1375 |
+
// Write value type
|
| 1376 |
+
const int32_t v_type_i = (int32_t)layer.v->type;
|
| 1377 |
+
io.write(&v_type_i, sizeof(v_type_i));
|
| 1378 |
+
|
| 1379 |
+
// Write element size
|
| 1380 |
+
const uint32_t v_size_el = ggml_type_size(layer.v->type);
|
| 1381 |
+
io.write(&v_size_el, sizeof(v_size_el));
|
| 1382 |
+
|
| 1383 |
+
// Write GQA embedding size
|
| 1384 |
+
io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
|
| 1385 |
+
|
| 1386 |
+
// For each row, we get the element values of each cell
|
| 1387 |
+
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
| 1388 |
+
// Read each range of cells of v_size_el length each into tmp_buf and write out
|
| 1389 |
+
for (const auto & range : cell_ranges) {
|
| 1390 |
+
const size_t range_size = range.second - range.first;
|
| 1391 |
+
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
|
| 1392 |
+
const size_t buf_size = range_size * v_size_el;
|
| 1393 |
+
io.write_tensor(layer.v, src_offset, buf_size);
|
| 1394 |
+
}
|
| 1395 |
+
}
|
| 1396 |
+
}
|
| 1397 |
+
}
|
| 1398 |
+
}
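
Putting state_write(), state_write_meta() and state_write_data() together, the resulting stream looks roughly like the sketch below (field widths follow the io.write() calls above; this is a reading aid, not an official format specification):

// uint32       cell_count
// repeated cell_count times (state_write_meta):
//     llama_pos    pos
//     uint32       n_seq_id
//     llama_seq_id seq_id[n_seq_id]
// uint32       v_trans                      (state_write_data)
// uint32       n_layer
// repeated n_layer times:                   // keys, one row per cell
//     int32        k_type
//     uint64       k_size_row
//     bytes        K rows for every saved cell range
// if v_trans == 0: per layer, the same (type, row size, rows) layout for V
// else:            per layer: int32 v_type, uint32 v_size_el, uint32 n_embd_v_gqa,
//                  then for each of the n_embd_v_gqa rows, the elements of every saved range
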
|
| 1399 |
+
|
| 1400 |
+
bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
|
| 1401 |
+
if (dest_seq_id != -1) {
|
| 1402 |
+
// single sequence
|
| 1403 |
+
|
| 1404 |
+
seq_rm(dest_seq_id, -1, -1);
|
| 1405 |
+
|
| 1406 |
+
llama_sbatch sbatch;
|
| 1407 |
+
llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
| 1408 |
+
|
| 1409 |
+
batch.n_tokens = cell_count;
|
| 1410 |
+
|
| 1411 |
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 1412 |
+
llama_pos pos;
|
| 1413 |
+
uint32_t n_seq_id;
|
| 1414 |
+
|
| 1415 |
+
io.read_to(&pos, sizeof(pos));
|
| 1416 |
+
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
| 1417 |
+
|
| 1418 |
+
if (n_seq_id != 1) {
|
| 1419 |
+
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
| 1420 |
+
return false;
|
| 1421 |
+
}
|
| 1422 |
+
|
| 1423 |
+
// read the sequence id, but directly discard it - we will use dest_seq_id instead
|
| 1424 |
+
{
|
| 1425 |
+
llama_seq_id seq_id;
|
| 1426 |
+
io.read_to(&seq_id, sizeof(seq_id));
|
| 1427 |
+
}
|
| 1428 |
+
|
| 1429 |
+
batch.pos[i] = pos;
|
| 1430 |
+
batch.n_seq_id[i] = n_seq_id;
|
| 1431 |
+
batch.seq_id[i] = &dest_seq_id;
|
| 1432 |
+
}
|
| 1433 |
+
|
| 1434 |
+
const auto head_cur = find_slot(batch);
|
| 1435 |
+
if (head_cur < 0) {
|
| 1436 |
+
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
| 1437 |
+
return false;
|
| 1438 |
+
}
|
| 1439 |
+
|
| 1440 |
+
apply_ubatch(head_cur, batch);
|
| 1441 |
+
|
| 1442 |
+
// keep the head at the old position because we will read the KV data into it in state_read_data()
|
| 1443 |
+
head = head_cur;
|
| 1444 |
+
|
| 1445 |
+
// DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
| 1446 |
+
// Assume that this is one contiguous block of cells
|
| 1447 |
+
GGML_ASSERT(head_cur + cell_count <= cells.size());
|
| 1448 |
+
GGML_ASSERT(cells.pos_get(head_cur) == batch.pos[0]);
|
| 1449 |
+
GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]);
|
| 1450 |
+
GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
|
| 1451 |
+
GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
|
| 1452 |
+
} else {
|
| 1453 |
+
// whole KV cache restore
|
| 1454 |
+
|
| 1455 |
+
if (cell_count > cells.size()) {
|
| 1456 |
+
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
|
| 1457 |
+
return false;
|
| 1458 |
+
}
|
| 1459 |
+
|
| 1460 |
+
clear();
|
| 1461 |
+
|
| 1462 |
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 1463 |
+
llama_pos pos;
|
| 1464 |
+
uint32_t n_seq_id;
|
| 1465 |
+
|
| 1466 |
+
io.read_to(&pos, sizeof(pos));
|
| 1467 |
+
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
| 1468 |
+
|
| 1469 |
+
cells.pos_set(i, pos);
|
| 1470 |
+
|
| 1471 |
+
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
| 1472 |
+
llama_seq_id seq_id;
|
| 1473 |
+
io.read_to(&seq_id, sizeof(seq_id));
|
| 1474 |
+
|
| 1475 |
+
if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
|
| 1476 |
+
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
|
| 1477 |
+
return false;
|
| 1478 |
+
}
|
| 1479 |
+
|
| 1480 |
+
cells.seq_add(i, seq_id);
|
| 1481 |
+
}
|
| 1482 |
+
}
|
| 1483 |
+
|
| 1484 |
+
head = 0;
|
| 1485 |
+
}
|
| 1486 |
+
|
| 1487 |
+
return true;
|
| 1488 |
+
}
|
| 1489 |
+
|
| 1490 |
+
bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
|
| 1491 |
+
uint32_t v_trans;
|
| 1492 |
+
uint32_t n_layer;
|
| 1493 |
+
|
| 1494 |
+
io.read_to(&v_trans, sizeof(v_trans));
|
| 1495 |
+
io.read_to(&n_layer, sizeof(n_layer));
|
| 1496 |
+
|
| 1497 |
+
if (n_layer != layers.size()) {
|
| 1498 |
+
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
|
| 1499 |
+
return false;
|
| 1500 |
+
}
|
| 1501 |
+
|
| 1502 |
+
if (cell_count > cells.size()) {
|
| 1503 |
+
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
|
| 1504 |
+
return false;
|
| 1505 |
+
}
|
| 1506 |
+
|
| 1507 |
+
if (this->v_trans != (bool) v_trans) {
|
| 1508 |
+
LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
|
| 1509 |
+
return false;
|
| 1510 |
+
}
|
| 1511 |
+
|
| 1512 |
+
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
|
| 1513 |
+
for (const auto & layer : layers) {
|
| 1514 |
+
const uint32_t il = layer.il;
|
| 1515 |
+
|
| 1516 |
+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 1517 |
+
|
| 1518 |
+
// Read type of key
|
| 1519 |
+
int32_t k_type_i_ref;
|
| 1520 |
+
io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
|
| 1521 |
+
const int32_t k_type_i = (int32_t) layer.k->type;
|
| 1522 |
+
if (k_type_i != k_type_i_ref) {
|
| 1523 |
+
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
|
| 1524 |
+
return false;
|
| 1525 |
+
}
|
| 1526 |
+
|
| 1527 |
+
// Read row size of key
|
| 1528 |
+
uint64_t k_size_row_ref;
|
| 1529 |
+
io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
|
| 1530 |
+
const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
|
| 1531 |
+
if (k_size_row != k_size_row_ref) {
|
| 1532 |
+
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
|
| 1533 |
+
return false;
|
| 1534 |
+
}
|
| 1535 |
+
|
| 1536 |
+
if (cell_count) {
|
| 1537 |
+
// Read and set the keys for the whole cell range
|
| 1538 |
+
ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
|
| 1539 |
+
}
|
| 1540 |
+
}
|
| 1541 |
+
|
| 1542 |
+
if (!this->v_trans) {
|
| 1543 |
+
for (const auto & layer : layers) {
|
| 1544 |
+
const uint32_t il = layer.il;
|
| 1545 |
+
|
| 1546 |
+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1547 |
+
|
| 1548 |
+
// Read type of value
|
| 1549 |
+
int32_t v_type_i_ref;
|
| 1550 |
+
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
| 1551 |
+
const int32_t v_type_i = (int32_t)layer.v->type;
|
| 1552 |
+
if (v_type_i != v_type_i_ref) {
|
| 1553 |
+
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
| 1554 |
+
return false;
|
| 1555 |
+
}
|
| 1556 |
+
|
| 1557 |
+
// Read row size of value
|
| 1558 |
+
uint64_t v_size_row_ref;
|
| 1559 |
+
io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
|
| 1560 |
+
const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
|
| 1561 |
+
if (v_size_row != v_size_row_ref) {
|
| 1562 |
+
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
|
| 1563 |
+
return false;
|
| 1564 |
+
}
|
| 1565 |
+
|
| 1566 |
+
if (cell_count) {
|
| 1567 |
+
// Read and set the values for the whole cell range
|
| 1568 |
+
ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
|
| 1569 |
+
}
|
| 1570 |
+
}
|
| 1571 |
+
} else {
|
| 1572 |
+
// For each layer, read the values for each cell (transposed)
|
| 1573 |
+
for (const auto & layer : layers) {
|
| 1574 |
+
const uint32_t il = layer.il;
|
| 1575 |
+
|
| 1576 |
+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1577 |
+
|
| 1578 |
+
// Read type of value
|
| 1579 |
+
int32_t v_type_i_ref;
|
| 1580 |
+
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
| 1581 |
+
const int32_t v_type_i = (int32_t)layer.v->type;
|
| 1582 |
+
if (v_type_i != v_type_i_ref) {
|
| 1583 |
+
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
| 1584 |
+
return false;
|
| 1585 |
+
}
|
| 1586 |
+
|
| 1587 |
+
// Read element size of value
|
| 1588 |
+
uint32_t v_size_el_ref;
|
| 1589 |
+
io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
|
| 1590 |
+
const size_t v_size_el = ggml_type_size(layer.v->type);
|
| 1591 |
+
if (v_size_el != v_size_el_ref) {
|
| 1592 |
+
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
|
| 1593 |
+
return false;
|
| 1594 |
+
}
|
| 1595 |
+
|
| 1596 |
+
// Read GQA embedding size
|
| 1597 |
+
uint32_t n_embd_v_gqa_ref;
|
| 1598 |
+
io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
|
| 1599 |
+
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
|
| 1600 |
+
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
|
| 1601 |
+
return false;
|
| 1602 |
+
}
|
| 1603 |
+
|
| 1604 |
+
if (cell_count) {
|
| 1605 |
+
// For each row in the transposed matrix, read the values for the whole cell range
|
| 1606 |
+
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
| 1607 |
+
const size_t dst_offset = (head + j * cells.size()) * v_size_el;
|
| 1608 |
+
ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
| 1609 |
+
}
|
| 1610 |
+
}
|
| 1611 |
+
}
|
| 1612 |
+
}
|
| 1613 |
+
|
| 1614 |
+
return true;
|
| 1615 |
+
}
|
| 1616 |
+
|
| 1617 |
+
//
|
| 1618 |
+
// llama_kv_cache_unified_state
|
| 1619 |
+
//
|
| 1620 |
+
|
| 1621 |
+
llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
|
| 1622 |
+
|
| 1623 |
+
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
| 1624 |
+
llama_memory_status status,
|
| 1625 |
+
llama_kv_cache_unified * kv) : status(status), kv(kv) {
|
| 1626 |
+
n_kv = kv->get_size();
|
| 1627 |
+
head = 0;
|
| 1628 |
+
}
|
| 1629 |
+
|
| 1630 |
+
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
| 1631 |
+
llama_memory_status status,
|
| 1632 |
+
llama_kv_cache_unified * kv,
|
| 1633 |
+
llama_sbatch sbatch,
|
| 1634 |
+
std::vector<uint32_t> heads,
|
| 1635 |
+
std::vector<llama_ubatch> ubatches)
|
| 1636 |
+
: status(status),
|
| 1637 |
+
kv(kv),
|
| 1638 |
+
sbatch(std::move(sbatch)),
|
| 1639 |
+
heads(std::move(heads)),
|
| 1640 |
+
ubatches(std::move(ubatches)) {
|
| 1641 |
+
}
|
| 1642 |
+
|
| 1643 |
+
llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
|
| 1644 |
+
|
| 1645 |
+
bool llama_kv_cache_unified_state::next() {
|
| 1646 |
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 1647 |
+
|
| 1648 |
+
if (++i_next >= ubatches.size()) {
|
| 1649 |
+
return false;
|
| 1650 |
+
}
|
| 1651 |
+
|
| 1652 |
+
return true;
|
| 1653 |
+
}
|
| 1654 |
+
|
| 1655 |
+
bool llama_kv_cache_unified_state::apply() {
|
| 1656 |
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 1657 |
+
|
| 1658 |
+
kv->apply_ubatch(heads[i_next], ubatches[i_next]);
|
| 1659 |
+
|
| 1660 |
+
n_kv = kv->get_n_kv();
|
| 1661 |
+
head = heads[i_next];
|
| 1662 |
+
|
| 1663 |
+
return true;
|
| 1664 |
+
}
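
The intended call pattern for these two methods is "apply the current ubatch, run the graph, then advance". A toy stand-in (no llama types; all names here are invented) that mirrors the same do/while contract:

#include <cstdio>
#include <vector>

struct toy_state {
    std::vector<int> ubatches = {3, 1, 4}; // token counts of the pending ubatches
    size_t i_next = 0;

    bool apply()            { printf("placing ubatch of %d tokens\n", ubatches[i_next]); return true; }
    int  get_ubatch() const { return ubatches[i_next]; }
    bool next()             { return ++i_next < ubatches.size(); }
};

int main() {
    toy_state st;
    do {
        st.apply();              // commit the pre-computed head for this ubatch
        // ... build and compute the graph for st.get_ubatch() here ...
    } while (st.next());         // advance; returns false after the last ubatch
}
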
|
| 1665 |
+
|
| 1666 |
+
std::vector<int64_t> & llama_kv_cache_unified_state::out_ids() {
|
| 1667 |
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 1668 |
+
|
| 1669 |
+
return sbatch.out_ids;
|
| 1670 |
+
}
|
| 1671 |
+
|
| 1672 |
+
llama_memory_status llama_kv_cache_unified_state::get_status() const {
|
| 1673 |
+
return status;
|
| 1674 |
+
}
|
| 1675 |
+
|
| 1676 |
+
const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const {
|
| 1677 |
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 1678 |
+
|
| 1679 |
+
return ubatches[i_next];
|
| 1680 |
+
}
|
| 1681 |
+
|
| 1682 |
+
uint32_t llama_kv_cache_unified_state::get_n_kv() const {
|
| 1683 |
+
return n_kv;
|
| 1684 |
+
}
|
| 1685 |
+
|
| 1686 |
+
ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const {
|
| 1687 |
+
return kv->get_k(ctx, il, n_kv);
|
| 1688 |
+
}
|
| 1689 |
+
|
| 1690 |
+
ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const {
|
| 1691 |
+
return kv->get_v(ctx, il, n_kv);
|
| 1692 |
+
}
|
| 1693 |
+
|
| 1694 |
+
ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
|
| 1695 |
+
return kv->cpy_k(ctx, k_cur, il, head);
|
| 1696 |
+
}
|
| 1697 |
+
|
| 1698 |
+
ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
|
| 1699 |
+
return kv->cpy_v(ctx, v_cur, il, head);
|
| 1700 |
+
}
|
| 1701 |
+
|
| 1702 |
+
void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const {
|
| 1703 |
+
kv->set_input_k_shift(dst);
|
| 1704 |
+
}
|
| 1705 |
+
|
| 1706 |
+
void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
| 1707 |
+
kv->set_input_kq_mask(dst, ubatch, causal_attn);
|
| 1708 |
+
}
|
| 1709 |
+
|
| 1710 |
+
void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
| 1711 |
+
kv->set_input_pos_bucket(dst, ubatch);
|
| 1712 |
+
}
|
| 1713 |
+
|
| 1714 |
+
uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
|
| 1715 |
+
// the FA kernels require padding to avoid extra runtime boundary checks
|
| 1716 |
+
return cparams.flash_attn ? 256u : 32u;
|
| 1717 |
+
}
|
examples/talk-llama/llama-kv-cache-unified.h
ADDED
|
@@ -0,0 +1,278 @@
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "llama-batch.h"
|
| 4 |
+
#include "llama-graph.h"
|
| 5 |
+
#include "llama-kv-cache.h"
|
| 6 |
+
#include "llama-kv-cells.h"
|
| 7 |
+
|
| 8 |
+
#include <unordered_map>
|
| 9 |
+
#include <vector>
|
| 10 |
+
|
| 11 |
+
struct llama_cparams;
|
| 12 |
+
struct llama_hparams;
|
| 13 |
+
struct llama_model;
|
| 14 |
+
struct llama_context;
|
| 15 |
+
|
| 16 |
+
//
|
| 17 |
+
// llama_kv_cache_unified
|
| 18 |
+
//
|
| 19 |
+
|
| 20 |
+
class llama_kv_cache_unified : public llama_kv_cache {
|
| 21 |
+
public:
|
| 22 |
+
static uint32_t get_padding(const llama_cparams & cparams);
|
| 23 |
+
|
| 24 |
+
// this callback is used to filter out layers that should not be included in the cache
|
| 25 |
+
using layer_filter_cb = std::function<bool(int32_t il)>;
|
| 26 |
+
|
| 27 |
+
llama_kv_cache_unified(
|
| 28 |
+
const llama_model & model,
|
| 29 |
+
layer_filter_cb && filter,
|
| 30 |
+
ggml_type type_k,
|
| 31 |
+
ggml_type type_v,
|
| 32 |
+
bool v_trans,
|
| 33 |
+
bool offload,
|
| 34 |
+
uint32_t kv_size,
|
| 35 |
+
uint32_t n_seq_max,
|
| 36 |
+
uint32_t n_pad,
|
| 37 |
+
uint32_t n_swa,
|
| 38 |
+
llama_swa_type swa_type);
|
| 39 |
+
|
| 40 |
+
~llama_kv_cache_unified() = default;
|
| 41 |
+
|
| 42 |
+
//
|
| 43 |
+
// llama_memory_i
|
| 44 |
+
//
|
| 45 |
+
|
| 46 |
+
void clear() override;
|
| 47 |
+
|
| 48 |
+
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
| 49 |
+
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
| 50 |
+
void seq_keep(llama_seq_id seq_id) override;
|
| 51 |
+
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
| 52 |
+
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
| 53 |
+
|
| 54 |
+
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
| 55 |
+
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
| 56 |
+
|
| 57 |
+
//
|
| 58 |
+
// llama_kv_cache
|
| 59 |
+
//
|
| 60 |
+
|
| 61 |
+
llama_memory_state_ptr init_batch(
|
| 62 |
+
const llama_batch & batch,
|
| 63 |
+
uint32_t n_ubatch,
|
| 64 |
+
bool embd_pooled,
|
| 65 |
+
bool logits_all) override;
|
| 66 |
+
|
| 67 |
+
llama_memory_state_ptr init_full() override;
|
| 68 |
+
|
| 69 |
+
bool update(llama_context & lctx) override;
|
| 70 |
+
|
| 71 |
+
void defrag_sched(float thold) override;
|
| 72 |
+
|
| 73 |
+
bool get_can_shift() const override;
|
| 74 |
+
|
| 75 |
+
// state write/load
|
| 76 |
+
|
| 77 |
+
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
| 78 |
+
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
| 79 |
+
|
| 80 |
+
//
|
| 81 |
+
// llama_kv_cache_unified specific API
|
| 82 |
+
//
|
| 83 |
+
|
| 84 |
+
uint32_t get_size() const;
|
| 85 |
+
|
| 86 |
+
//
|
| 87 |
+
// graph_build API
|
| 88 |
+
//
|
| 89 |
+
|
| 90 |
+
uint32_t get_n_kv() const;
|
| 91 |
+
|
| 92 |
+
// get views of the current state of the cache
|
| 93 |
+
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
| 94 |
+
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
| 95 |
+
|
| 96 |
+
// store k_cur and v_cur in the cache based on the provided head location
|
| 97 |
+
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
|
| 98 |
+
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
|
| 99 |
+
|
| 100 |
+
//
|
| 101 |
+
// preparation API
|
| 102 |
+
//
|
| 103 |
+
|
| 104 |
+
// find places for the provided ubatches in the cache, returns the head locations
|
| 105 |
+
// return empty vector on failure
|
| 106 |
+
std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
|
| 107 |
+
|
| 108 |
+
// return the cell position where we can insert the ubatch
|
| 109 |
+
// return -1 on failure to find a contiguous slot of kv cells
|
| 110 |
+
int32_t find_slot(const llama_ubatch & ubatch) const;
|
| 111 |
+
|
| 112 |
+
// emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
|
| 113 |
+
void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
|
| 114 |
+
|
| 115 |
+
//
|
| 116 |
+
// set_input API
|
| 117 |
+
//
|
| 118 |
+
|
| 119 |
+
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
| 120 |
+
void set_input_k_shift (ggml_tensor * dst) const;
|
| 121 |
+
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
| 122 |
+
|
| 123 |
+
private:
|
| 124 |
+
const llama_model & model;
|
| 125 |
+
const llama_hparams & hparams;
|
| 126 |
+
|
| 127 |
+
struct kv_layer {
|
| 128 |
+
// layer index in the model
|
| 129 |
+
// note: can be different from the layer index in the KV cache
|
| 130 |
+
uint32_t il;
|
| 131 |
+
|
| 132 |
+
ggml_tensor * k;
|
| 133 |
+
ggml_tensor * v;
|
| 134 |
+
};
|
| 135 |
+
|
| 136 |
+
bool do_defrag = false;
|
| 137 |
+
bool v_trans = true; // the value tensor is transposed
|
| 138 |
+
|
| 139 |
+
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
| 140 |
+
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
| 141 |
+
uint32_t head = 0;
|
| 142 |
+
|
| 143 |
+
const uint32_t n_seq_max = 1;
|
| 144 |
+
|
| 145 |
+
// required padding
|
| 146 |
+
const uint32_t n_pad = 1;
|
| 147 |
+
|
| 148 |
+
// SWA
|
| 149 |
+
const uint32_t n_swa = 0;
|
| 150 |
+
|
| 151 |
+
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
| 152 |
+
|
| 153 |
+
std::vector<ggml_context_ptr> ctxs;
|
| 154 |
+
std::vector<ggml_backend_buffer_ptr> bufs;
|
| 155 |
+
|
| 156 |
+
llama_kv_cells_unified cells;
|
| 157 |
+
|
| 158 |
+
std::vector<kv_layer> layers;
|
| 159 |
+
|
| 160 |
+
// model layer id -> KV cache layer id
|
| 161 |
+
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
| 162 |
+
|
| 163 |
+
// defrag
|
| 164 |
+
struct {
|
| 165 |
+
std::vector<uint32_t> ids;
|
| 166 |
+
} defrag_info;
|
| 167 |
+
|
| 168 |
+
// return true if cells have been moved
|
| 169 |
+
bool defrag_prepare(int32_t n_max_nodes);
|
| 170 |
+
|
| 171 |
+
size_t total_size() const;
|
| 172 |
+
|
| 173 |
+
size_t size_k_bytes() const;
|
| 174 |
+
size_t size_v_bytes() const;
|
| 175 |
+
|
| 176 |
+
bool is_masked_swa(llama_pos p0, llama_pos p1) const;
|
| 177 |
+
|
| 178 |
+
ggml_tensor * build_rope_shift(
|
| 179 |
+
const llama_cparams & cparams,
|
| 180 |
+
ggml_context * ctx,
|
| 181 |
+
ggml_tensor * cur,
|
| 182 |
+
ggml_tensor * shift,
|
| 183 |
+
ggml_tensor * factors,
|
| 184 |
+
float freq_base,
|
| 185 |
+
float freq_scale) const;
|
| 186 |
+
|
| 187 |
+
llm_graph_result_ptr build_graph_shift(
|
| 188 |
+
const llama_cparams & cparams,
|
| 189 |
+
ggml_context * ctx,
|
| 190 |
+
ggml_cgraph * gf) const;
|
| 191 |
+
|
| 192 |
+
llm_graph_result_ptr build_graph_defrag(
|
| 193 |
+
const llama_cparams & cparams,
|
| 194 |
+
ggml_context * ctx,
|
| 195 |
+
ggml_cgraph * gf) const;
|
| 196 |
+
|
| 197 |
+
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
| 198 |
+
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
| 199 |
+
|
| 200 |
+
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
| 201 |
+
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
| 202 |
+
};
|
| 203 |
+
|
| 204 |
+
class llama_kv_cache_unified_state : public llama_memory_state_i {
|
| 205 |
+
public:
|
| 206 |
+
// used for errors
|
| 207 |
+
llama_kv_cache_unified_state(llama_memory_status status);
|
| 208 |
+
|
| 209 |
+
// used to create a full-cache state
|
| 210 |
+
llama_kv_cache_unified_state(
|
| 211 |
+
llama_memory_status status,
|
| 212 |
+
llama_kv_cache_unified * kv);
|
| 213 |
+
|
| 214 |
+
// used to create a state from a batch
|
| 215 |
+
llama_kv_cache_unified_state(
|
| 216 |
+
llama_memory_status status,
|
| 217 |
+
llama_kv_cache_unified * kv,
|
| 218 |
+
llama_sbatch sbatch,
|
| 219 |
+
std::vector<uint32_t> heads,
|
| 220 |
+
std::vector<llama_ubatch> ubatches);
|
| 221 |
+
|
| 222 |
+
virtual ~llama_kv_cache_unified_state();
|
| 223 |
+
|
| 224 |
+
//
|
| 225 |
+
// llama_memory_state_i
|
| 226 |
+
//
|
| 227 |
+
|
| 228 |
+
bool next() override;
|
| 229 |
+
bool apply() override;
|
| 230 |
+
|
| 231 |
+
std::vector<int64_t> & out_ids() override;
|
| 232 |
+
|
| 233 |
+
llama_memory_status get_status() const override;
|
| 234 |
+
const llama_ubatch & get_ubatch() const override;
|
| 235 |
+
|
| 236 |
+
//
|
| 237 |
+
// llama_kv_cache_unified_state specific API
|
| 238 |
+
//
|
| 239 |
+
|
| 240 |
+
uint32_t get_n_kv() const;
|
| 241 |
+
|
| 242 |
+
// get views of the current state of the cache
|
| 243 |
+
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
|
| 244 |
+
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
|
| 245 |
+
|
| 246 |
+
// store k_cur and v_cur in the cache based on the provided head location
|
| 247 |
+
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
|
| 248 |
+
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
|
| 249 |
+
|
| 250 |
+
void set_input_k_shift(ggml_tensor * dst) const;
|
| 251 |
+
|
| 252 |
+
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
| 253 |
+
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
| 254 |
+
|
| 255 |
+
private:
|
| 256 |
+
const llama_memory_status status;
|
| 257 |
+
|
| 258 |
+
llama_kv_cache_unified * kv;
|
| 259 |
+
|
| 260 |
+
llama_sbatch sbatch;
|
| 261 |
+
|
| 262 |
+
// the index of the next ubatch to process
|
| 263 |
+
size_t i_next = 0;
|
| 264 |
+
|
| 265 |
+
std::vector<uint32_t> heads;
|
| 266 |
+
std::vector<llama_ubatch> ubatches;
|
| 267 |
+
|
| 268 |
+
//
|
| 269 |
+
// data needed for building the compute graph for the current ubatch:
|
| 270 |
+
//
|
| 271 |
+
|
| 272 |
+
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
| 273 |
+
// as the cache gets filled, the benefit from this heuristic disappears
|
| 274 |
+
int32_t n_kv;
|
| 275 |
+
|
| 276 |
+
// the beginning of the current slot in which the ubatch will be inserted
|
| 277 |
+
int32_t head;
|
| 278 |
+
};
|
examples/talk-llama/llama-kv-cache.cpp
CHANGED
|
@@ -1,2739 +1 @@
|
|
| 1 |
#include "llama-kv-cache.h"
|
| 2 |
-
|
| 3 |
-
#include "llama-impl.h"
|
| 4 |
-
#include "llama-batch.h"
|
| 5 |
-
#include "llama-cparams.h"
|
| 6 |
-
#include "llama-model.h"
|
| 7 |
-
#include "llama-context.h"
|
| 8 |
-
|
| 9 |
-
#include <algorithm>
|
| 10 |
-
#include <cassert>
|
| 11 |
-
#include <cmath>
|
| 12 |
-
#include <limits>
|
| 13 |
-
#include <map>
|
| 14 |
-
#include <stdexcept>
|
| 15 |
-
|
| 16 |
-
//
|
| 17 |
-
// llama_kv_cache_unified
|
| 18 |
-
//
|
| 19 |
-
|
| 20 |
-
uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
|
| 21 |
-
// the FA kernels require padding to avoid extra runtime boundary checks
|
| 22 |
-
return cparams.flash_attn ? 256u : 32u;
|
| 23 |
-
}
|
| 24 |
-
|
| 25 |
-
llama_kv_cache_unified::llama_kv_cache_unified(
|
| 26 |
-
const llama_model & model,
|
| 27 |
-
layer_filter_cb && filter,
|
| 28 |
-
ggml_type type_k,
|
| 29 |
-
ggml_type type_v,
|
| 30 |
-
bool v_trans,
|
| 31 |
-
bool offload,
|
| 32 |
-
uint32_t kv_size,
|
| 33 |
-
uint32_t n_seq_max,
|
| 34 |
-
uint32_t n_pad,
|
| 35 |
-
uint32_t n_swa,
|
| 36 |
-
llama_swa_type swa_type) :
|
| 37 |
-
model(model), hparams(model.hparams), v_trans(v_trans),
|
| 38 |
-
n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
|
| 39 |
-
|
| 40 |
-
GGML_ASSERT(kv_size % n_pad == 0);
|
| 41 |
-
|
| 42 |
-
// create a context for each buffer type
|
| 43 |
-
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
| 44 |
-
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
| 45 |
-
auto it = ctx_map.find(buft);
|
| 46 |
-
if (it == ctx_map.end()) {
|
| 47 |
-
ggml_init_params params = {
|
| 48 |
-
/*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
|
| 49 |
-
/*.mem_buffer =*/ NULL,
|
| 50 |
-
/*.no_alloc =*/ true,
|
| 51 |
-
};
|
| 52 |
-
|
| 53 |
-
ggml_context * ctx = ggml_init(params);
|
| 54 |
-
if (!ctx) {
|
| 55 |
-
return nullptr;
|
| 56 |
-
}
|
| 57 |
-
|
| 58 |
-
ctx_map[buft] = ctx;
|
| 59 |
-
ctxs.emplace_back(ctx);
|
| 60 |
-
|
| 61 |
-
return ctx;
|
| 62 |
-
}
|
| 63 |
-
|
| 64 |
-
return it->second;
|
| 65 |
-
};
|
| 66 |
-
|
| 67 |
-
head = 0;
|
| 68 |
-
|
| 69 |
-
cells.resize(kv_size);
|
| 70 |
-
|
| 71 |
-
for (uint32_t il = 0; il < hparams.n_layer; il++) {
|
| 72 |
-
if (filter && !filter(il)) {
|
| 73 |
-
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
|
| 74 |
-
continue;
|
| 75 |
-
}
|
| 76 |
-
|
| 77 |
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 78 |
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 79 |
-
|
| 80 |
-
const char * dev_name = "CPU";
|
| 81 |
-
|
| 82 |
-
ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
|
| 83 |
-
|
| 84 |
-
if (offload) {
|
| 85 |
-
auto * dev = model.dev_layer(il);
|
| 86 |
-
buft = ggml_backend_dev_buffer_type(dev);
|
| 87 |
-
|
| 88 |
-
dev_name = ggml_backend_dev_name(dev);
|
| 89 |
-
}
|
| 90 |
-
|
| 91 |
-
LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);
|
| 92 |
-
|
| 93 |
-
ggml_context * ctx = ctx_for_buft(buft);
|
| 94 |
-
if (!ctx) {
|
| 95 |
-
throw std::runtime_error("failed to create ggml context for kv cache");
|
| 96 |
-
}
|
| 97 |
-
|
| 98 |
-
ggml_tensor * k;
|
| 99 |
-
ggml_tensor * v;
|
| 100 |
-
|
| 101 |
-
k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
|
| 102 |
-
v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
|
| 103 |
-
|
| 104 |
-
ggml_format_name(k, "cache_k_l%d", il);
|
| 105 |
-
ggml_format_name(v, "cache_v_l%d", il);
|
| 106 |
-
|
| 107 |
-
map_layer_ids[il] = layers.size();
|
| 108 |
-
layers.push_back({ il, k, v });
|
| 109 |
-
}
|
| 110 |
-
|
| 111 |
-
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
| 112 |
-
for (auto it : ctx_map) {
|
| 113 |
-
auto * buft = it.first;
|
| 114 |
-
auto * ctx = it.second;
|
| 115 |
-
|
| 116 |
-
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
| 117 |
-
if (!buf) {
|
| 118 |
-
throw std::runtime_error("failed to allocate buffer for kv cache");
|
| 119 |
-
}
|
| 120 |
-
|
| 121 |
-
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
| 122 |
-
|
| 123 |
-
ggml_backend_buffer_clear(buf, 0);
|
| 124 |
-
bufs.emplace_back(buf);
|
| 125 |
-
}
|
| 126 |
-
|
| 127 |
-
{
|
| 128 |
-
const size_t memory_size_k = size_k_bytes();
|
| 129 |
-
const size_t memory_size_v = size_v_bytes();
|
| 130 |
-
|
| 131 |
-
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
| 132 |
-
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
|
| 133 |
-
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
| 134 |
-
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
| 135 |
-
}
|
| 136 |
-
}
|
| 137 |
-
|
| 138 |
-
void llama_kv_cache_unified::clear() {
|
| 139 |
-
cells.reset();
|
| 140 |
-
|
| 141 |
-
head = 0;
|
| 142 |
-
|
| 143 |
-
for (auto & buf : bufs) {
|
| 144 |
-
ggml_backend_buffer_clear(buf.get(), 0);
|
| 145 |
-
}
|
| 146 |
-
}
|
| 147 |
-
|
| 148 |
-
bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
| 149 |
-
uint32_t new_head = cells.size();
|
| 150 |
-
|
| 151 |
-
if (p0 < 0) {
|
| 152 |
-
p0 = 0;
|
| 153 |
-
}
|
| 154 |
-
|
| 155 |
-
if (p1 < 0) {
|
| 156 |
-
p1 = std::numeric_limits<llama_pos>::max();
|
| 157 |
-
}
|
| 158 |
-
|
| 159 |
-
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 160 |
-
if (!cells.pos_in(i, p0, p1)) {
|
| 161 |
-
continue;
|
| 162 |
-
}
|
| 163 |
-
|
| 164 |
-
if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
|
| 165 |
-
if (new_head == cells.size()) {
|
| 166 |
-
new_head = i;
|
| 167 |
-
}
|
| 168 |
-
}
|
| 169 |
-
}
|
| 170 |
-
|
| 171 |
-
// If we freed up a slot, set head to it so searching can start there.
|
| 172 |
-
if (new_head != cells.size() && new_head < head) {
|
| 173 |
-
head = new_head;
|
| 174 |
-
}
|
| 175 |
-
|
| 176 |
-
return true;
|
| 177 |
-
}
|
| 178 |
-
|
| 179 |
-
void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
| 180 |
-
if (seq_id_src == seq_id_dst) {
|
| 181 |
-
return;
|
| 182 |
-
}
|
| 183 |
-
|
| 184 |
-
if (p0 < 0) {
|
| 185 |
-
p0 = 0;
|
| 186 |
-
}
|
| 187 |
-
|
| 188 |
-
if (p1 < 0) {
|
| 189 |
-
p1 = std::numeric_limits<llama_pos>::max();
|
| 190 |
-
}
|
| 191 |
-
|
| 192 |
-
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 193 |
-
if (!cells.pos_in(i, p0, p1)) {
|
| 194 |
-
continue;
|
| 195 |
-
}
|
| 196 |
-
|
| 197 |
-
if (cells.seq_has(i, seq_id_src)) {
|
| 198 |
-
cells.seq_add(i, seq_id_dst);
|
| 199 |
-
}
|
| 200 |
-
}
|
| 201 |
-
}
|
| 202 |
-
|
| 203 |
-
void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
|
| 204 |
-
uint32_t new_head = cells.size();
|
| 205 |
-
|
| 206 |
-
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 207 |
-
if (cells.seq_keep(i, seq_id)) {
|
| 208 |
-
if (new_head == cells.size()) {
|
| 209 |
-
new_head = i;
|
| 210 |
-
}
|
| 211 |
-
}
|
| 212 |
-
}
|
| 213 |
-
|
| 214 |
-
// If we freed up a slot, set head to it so searching can start there.
|
| 215 |
-
if (new_head != cells.size() && new_head < head) {
|
| 216 |
-
head = new_head;
|
| 217 |
-
}
|
| 218 |
-
}
|
| 219 |
-
|
| 220 |
-
void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
| 221 |
-
if (shift == 0) {
|
| 222 |
-
return;
|
| 223 |
-
}
|
| 224 |
-
|
| 225 |
-
uint32_t new_head = cells.size();
|
| 226 |
-
|
| 227 |
-
if (p0 < 0) {
|
| 228 |
-
p0 = 0;
|
| 229 |
-
}
|
| 230 |
-
|
| 231 |
-
if (p1 < 0) {
|
| 232 |
-
p1 = std::numeric_limits<llama_pos>::max();
|
| 233 |
-
}
|
| 234 |
-
|
| 235 |
-
// If there is no range then return early to avoid looping over all cells.
|
| 236 |
-
if (p0 == p1) {
|
| 237 |
-
return;
|
| 238 |
-
}
|
| 239 |
-
|
| 240 |
-
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 241 |
-
if (!cells.pos_in(i, p0, p1)) {
|
| 242 |
-
continue;
|
| 243 |
-
}
|
| 244 |
-
|
| 245 |
-
if (cells.seq_has(i, seq_id)) {
|
| 246 |
-
if (cells.pos_add(i, shift)) {
|
| 247 |
-
if (new_head == cells.size()) {
|
| 248 |
-
new_head = i;
|
| 249 |
-
}
|
| 250 |
-
}
|
| 251 |
-
}
|
| 252 |
-
}
|
| 253 |
-
|
| 254 |
-
// If we freed up a slot, set head to it so searching can start there.
|
| 255 |
-
// Otherwise we just start the next search from the beginning.
|
| 256 |
-
head = new_head != cells.size() ? new_head : 0;
|
| 257 |
-
}
|
| 258 |
-
|
| 259 |
-
void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
| 260 |
-
if (d == 1) {
|
| 261 |
-
return;
|
| 262 |
-
}
|
| 263 |
-
|
| 264 |
-
if (p0 < 0) {
|
| 265 |
-
p0 = 0;
|
| 266 |
-
}
|
| 267 |
-
|
| 268 |
-
if (p1 < 0) {
|
| 269 |
-
p1 = std::numeric_limits<llama_pos>::max();
|
| 270 |
-
}
|
| 271 |
-
|
| 272 |
-
// If there is no range then return early to avoid looping over the cache.
|
| 273 |
-
if (p0 == p1) {
|
| 274 |
-
return;
|
| 275 |
-
}
|
| 276 |
-
|
| 277 |
-
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 278 |
-
if (!cells.pos_in(i, p0, p1)) {
|
| 279 |
-
continue;
|
| 280 |
-
}
|
| 281 |
-
|
| 282 |
-
if (cells.seq_has(i, seq_id)) {
|
| 283 |
-
cells.pos_div(i, d);
|
| 284 |
-
}
|
| 285 |
-
}
|
| 286 |
-
}
|
| 287 |
-
|
| 288 |
-
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
|
| 289 |
-
return cells.seq_pos_min(seq_id);
|
| 290 |
-
}
|
| 291 |
-
|
| 292 |
-
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
| 293 |
-
return cells.seq_pos_max(seq_id);
|
| 294 |
-
}
|
| 295 |
-
|
| 296 |
-
void llama_kv_cache_unified::restore() {
|
| 297 |
-
for (auto & state : recovery.states) {
|
| 298 |
-
cells.set(state.i, state.cells);
|
| 299 |
-
}
|
| 300 |
-
|
| 301 |
-
recovery.clear();
|
| 302 |
-
}
|
| 303 |
-
|
| 304 |
-
void llama_kv_cache_unified::commit() {
|
| 305 |
-
if (recovery.states.empty()) {
|
| 306 |
-
LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
|
| 307 |
-
__func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
|
| 308 |
-
return;
|
| 309 |
-
}
|
| 310 |
-
|
| 311 |
-
recovery.clear();
|
| 312 |
-
}
|
| 313 |
-
|
| 314 |
-
bool llama_kv_cache_unified::update(llama_context & lctx) {
|
| 315 |
-
bool need_reserve = false;
|
| 316 |
-
|
| 317 |
-
auto * sched = lctx.get_sched();
|
| 318 |
-
|
| 319 |
-
if (cells.get_has_shift()) {
|
| 320 |
-
if (!get_can_shift()) {
|
| 321 |
-
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
|
| 322 |
-
}
|
| 323 |
-
|
| 324 |
-
LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
|
| 325 |
-
|
| 326 |
-
// apply K-shift if needed
|
| 327 |
-
if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
|
| 328 |
-
ggml_backend_sched_reset(sched);
|
| 329 |
-
|
| 330 |
-
auto * gf = lctx.graph_init();
|
| 331 |
-
|
| 332 |
-
auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
|
| 333 |
-
|
| 334 |
-
ggml_backend_sched_alloc_graph(sched, gf);
|
| 335 |
-
|
| 336 |
-
res->set_inputs(nullptr);
|
| 337 |
-
|
| 338 |
-
lctx.graph_compute(gf, false);
|
| 339 |
-
|
| 340 |
-
need_reserve = true;
|
| 341 |
-
}
|
| 342 |
-
|
| 343 |
-
cells.reset_shift();
|
| 344 |
-
}
|
| 345 |
-
|
| 346 |
-
if (do_defrag) {
|
| 347 |
-
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
|
| 348 |
-
|
| 349 |
-
if (defrag_prepare(lctx.graph_max_nodes())) {
|
| 350 |
-
ggml_backend_sched_reset(sched);
|
| 351 |
-
|
| 352 |
-
auto * gf = lctx.graph_init();
|
| 353 |
-
|
| 354 |
-
auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
|
| 355 |
-
|
| 356 |
-
ggml_backend_sched_alloc_graph(sched, gf);
|
| 357 |
-
|
| 358 |
-
res->set_inputs(nullptr);
|
| 359 |
-
|
| 360 |
-
lctx.graph_compute(gf, false);
|
| 361 |
-
|
| 362 |
-
need_reserve = true;
|
| 363 |
-
}
|
| 364 |
-
|
| 365 |
-
do_defrag = false;
|
| 366 |
-
}
|
| 367 |
-
|
| 368 |
-
return need_reserve;
|
| 369 |
-
}
|
| 370 |
-
|
| 371 |
-
void llama_kv_cache_unified::defrag_sched(float thold) {
|
| 372 |
-
// - do not defrag small contexts (i.e. < 2048 tokens)
|
| 373 |
-
// - count the padding towards the number of used tokens
|
| 374 |
-
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f;
|
| 375 |
-
|
| 376 |
-
// queue defragmentation for next llama_kv_cache_update
|
| 377 |
-
if (fragmentation > thold) {
|
| 378 |
-
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
|
| 379 |
-
|
| 380 |
-
do_defrag = true;
|
| 381 |
-
}
|
| 382 |
-
}
|
| 383 |
-
|
| 384 |
-
void llama_kv_cache_unified::set_full() {
|
| 385 |
-
n = cells.size();
|
| 386 |
-
|
| 387 |
-
// when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
|
| 388 |
-
// affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
|
| 389 |
-
// we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so
|
| 390 |
-
// setting it to 0 is the simplest way to achieve that
|
| 391 |
-
// ref: https://github.com/ggml-org/llama.cpp/issues/13359
|
| 392 |
-
head = 0;
|
| 393 |
-
}
|
| 394 |
-
|
| 395 |
-
llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
|
| 396 |
-
return llama_sbatch(batch, hparams.n_embd, true, logits_all);
|
| 397 |
-
}
|
| 398 |
-
|
| 399 |
-
llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
|
| 400 |
-
GGML_UNUSED(embd_pooled);
|
| 401 |
-
return sbatch.split_simple(n_ubatch);
|
| 402 |
-
}
|
| 403 |
-
|
| 404 |
-
bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
|
| 405 |
-
const uint32_t n_tokens = ubatch.n_tokens;
|
| 406 |
-
|
| 407 |
-
// if we have enough unused cells before the current head ->
|
| 408 |
-
// better to start searching from the beginning of the cache, hoping to fill it
|
| 409 |
-
if (head > cells.get_used() + 2*ubatch.n_tokens) {
|
| 410 |
-
head = 0;
|
| 411 |
-
}
|
| 412 |
-
|
| 413 |
-
// otherwise, one cell per token.
|
| 414 |
-
|
| 415 |
-
if (n_tokens > cells.size()) {
|
| 416 |
-
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
|
| 417 |
-
return false;
|
| 418 |
-
}
|
| 419 |
-
|
| 420 |
-
//#define FIND_SLOT_DEBUG 1
|
| 421 |
-
#if FIND_SLOT_DEBUG
|
| 422 |
-
LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
|
| 423 |
-
|
| 424 |
-
// for debugging
|
| 425 |
-
{
|
| 426 |
-
std::string ss;
|
| 427 |
-
if (n_swa > 0) {
|
| 428 |
-
for (uint32_t i = 0; i < size; ++i) {
|
| 429 |
-
if (cells.is_empty(i)) {
|
| 430 |
-
ss += '.';
|
| 431 |
-
} else {
|
| 432 |
-
ss += 'x';
|
| 433 |
-
}
|
| 434 |
-
if (i%256 == 255) {
|
| 435 |
-
ss += '\n';
|
| 436 |
-
}
|
| 437 |
-
}
|
| 438 |
-
}
|
| 439 |
-
LLAMA_LOG_WARN("\n%s\n", ss.c_str());
|
| 440 |
-
}
|
| 441 |
-
#endif
|
| 442 |
-
|
| 443 |
-
uint32_t n_tested = 0;
|
| 444 |
-
|
| 445 |
-
while (true) {
|
| 446 |
-
if (head + n_tokens > cells.size()) {
|
| 447 |
-
n_tested += cells.size() - head;
|
| 448 |
-
head = 0;
|
| 449 |
-
continue;
|
| 450 |
-
}
|
| 451 |
-
|
| 452 |
-
bool found = true;
|
| 453 |
-
for (uint32_t i = 0; i < n_tokens; i++) {
|
| 454 |
-
// TODO: improve to accept cells that are masked by the SWA
|
| 455 |
-
if (!cells.is_empty(head + i)) {
|
| 456 |
-
found = false;
|
| 457 |
-
head += i + 1;
|
| 458 |
-
n_tested += i + 1;
|
| 459 |
-
break;
|
| 460 |
-
}
|
| 461 |
-
}
|
| 462 |
-
|
| 463 |
-
if (found) {
|
| 464 |
-
break;
|
| 465 |
-
}
|
| 466 |
-
|
| 467 |
-
if (n_tested >= cells.size()) {
|
| 468 |
-
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
| 469 |
-
return false;
|
| 470 |
-
}
|
| 471 |
-
}
|
| 472 |
-
|
| 473 |
-
// store the old state of the cells in the recovery stack
|
| 474 |
-
recovery.states.push_back({head, cells.cp(head, n_tokens)});
|
| 475 |
-
|
| 476 |
-
for (uint32_t i = 0; i < n_tokens; ++i) {
|
| 477 |
-
cells.pos_set(head + i, ubatch.pos[i]);
|
| 478 |
-
|
| 479 |
-
for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
|
| 480 |
-
cells.seq_add(head + i, ubatch.seq_id[i][j]);
|
| 481 |
-
}
|
| 482 |
-
}
|
| 483 |
-
|
| 484 |
-
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
| 485 |
-
// after enough generations, the benefit from this heuristic disappears
|
| 486 |
-
// if we start defragmenting the cache, the benefit from this will be more important
|
| 487 |
-
n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
|
| 488 |
-
|
| 489 |
-
#ifdef FIND_SLOT_DEBUG
|
| 490 |
-
LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
|
| 491 |
-
#endif
|
| 492 |
-
|
| 493 |
-
return true;
|
| 494 |
-
}
|
| 495 |
-
|
| 496 |
-
bool llama_kv_cache_unified::get_can_shift() const {
|
| 497 |
-
return true;
|
| 498 |
-
}
|
| 499 |
-
|
| 500 |
-
uint32_t llama_kv_cache_unified::get_n() const {
|
| 501 |
-
return n;
|
| 502 |
-
}
|
| 503 |
-
|
| 504 |
-
uint32_t llama_kv_cache_unified::get_size() const {
|
| 505 |
-
return cells.size();
|
| 506 |
-
}
|
| 507 |
-
|
| 508 |
-
ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
|
| 509 |
-
const int32_t ikv = map_layer_ids.at(il);
|
| 510 |
-
|
| 511 |
-
auto * k = layers[ikv].k;
|
| 512 |
-
|
| 513 |
-
return ggml_view_3d(ctx, k,
|
| 514 |
-
hparams.n_embd_head_k, hparams.n_head_kv(il), n,
|
| 515 |
-
ggml_row_size(k->type, hparams.n_embd_head_k),
|
| 516 |
-
ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
|
| 517 |
-
0);
|
| 518 |
-
}
|
| 519 |
-
|
| 520 |
-
ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) const {
|
| 521 |
-
const int32_t ikv = map_layer_ids.at(il);
|
| 522 |
-
|
| 523 |
-
auto * v = layers[ikv].v;
|
| 524 |
-
|
| 525 |
-
if (!v_trans) {
|
| 526 |
-
// note: v->nb[1] <= v->nb[2]
|
| 527 |
-
return ggml_view_3d(ctx, v,
|
| 528 |
-
hparams.n_embd_head_v, hparams.n_head_kv(il), n,
|
| 529 |
-
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
|
| 530 |
-
ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
|
| 531 |
-
0);
|
| 532 |
-
}
|
| 533 |
-
|
| 534 |
-
// note: v->nb[1] > v->nb[2]
|
| 535 |
-
return ggml_view_3d(ctx, v,
|
| 536 |
-
n, hparams.n_head_kv(il), hparams.n_embd_head_v,
|
| 537 |
-
ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
|
| 538 |
-
ggml_row_size(v->type, v->ne[1]), // v->nb[2]
|
| 539 |
-
0);
|
| 540 |
-
}
|
| 541 |
-
|
| 542 |
-
ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
|
| 543 |
-
const int32_t ikv = map_layer_ids.at(il);
|
| 544 |
-
|
| 545 |
-
auto * k = layers[ikv].k;
|
| 546 |
-
|
| 547 |
-
const int64_t n_tokens = k_cur->ne[2];
|
| 548 |
-
|
| 549 |
-
ggml_tensor * k_view = ggml_view_1d(ctx, k,
|
| 550 |
-
n_tokens*hparams.n_embd_k_gqa(il),
|
| 551 |
-
ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head);
|
| 552 |
-
|
| 553 |
-
return ggml_cpy(ctx, k_cur, k_view);
|
| 554 |
-
}
|
| 555 |
-
|
| 556 |
-
ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
|
| 557 |
-
const int32_t ikv = map_layer_ids.at(il);
|
| 558 |
-
|
| 559 |
-
auto * v = layers[ikv].v;
|
| 560 |
-
|
| 561 |
-
const int64_t n_tokens = v_cur->ne[2];
|
| 562 |
-
|
| 563 |
-
v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
|
| 564 |
-
|
| 565 |
-
ggml_tensor * v_view = nullptr;
|
| 566 |
-
|
| 567 |
-
if (!v_trans) {
|
| 568 |
-
v_view = ggml_view_1d(ctx, v,
|
| 569 |
-
n_tokens*hparams.n_embd_v_gqa(il),
|
| 570 |
-
ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head);
|
| 571 |
-
} else {
|
| 572 |
-
// note: the V cache is transposed when not using flash attention
|
| 573 |
-
v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
|
| 574 |
-
(v->ne[1])*ggml_element_size(v),
|
| 575 |
-
( head)*ggml_element_size(v));
|
| 576 |
-
|
| 577 |
-
v_cur = ggml_transpose(ctx, v_cur);
|
| 578 |
-
}
|
| 579 |
-
|
| 580 |
-
return ggml_cpy(ctx, v_cur, v_view);
|
| 581 |
-
}
|
| 582 |
-
|
| 583 |
-
void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax) {
|
| 584 |
-
// no pruning is needed when the cache does not use SWA
|
| 585 |
-
GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache");
|
| 586 |
-
|
| 587 |
-
int n_attended = 0;
|
| 588 |
-
|
| 589 |
-
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 590 |
-
if (!cells.seq_has(i, seq_id)) {
|
| 591 |
-
continue;
|
| 592 |
-
}
|
| 593 |
-
|
| 594 |
-
const llama_pos p0 = cells.pos_get(i);
|
| 595 |
-
|
| 596 |
-
if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
|
| 597 |
-
n_attended++;
|
| 598 |
-
}
|
| 599 |
-
|
| 600 |
-
if (is_masked_swa(p0, pmax)) {
|
| 601 |
-
cells.seq_rm(i, seq_id);
|
| 602 |
-
}
|
| 603 |
-
}
|
| 604 |
-
|
| 605 |
-
if (n_attended < std::min<int>(n_swa, pmin)) {
|
| 606 |
-
LLAMA_LOG_WARN("%s: partial SWA cache detected - possible loss of information, pmin = %d, n_attended = %d, n_swa = %d\n", __func__, pmin, n_attended, n_swa);
|
| 607 |
-
}
|
| 608 |
-
}
|
| 609 |
-
|
| 610 |
-
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
| 611 |
-
const int64_t n_tokens = ubatch->n_tokens;
|
| 612 |
-
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
| 613 |
-
const int64_t n_seqs = ubatch->n_seqs;
|
| 614 |
-
|
| 615 |
-
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
| 616 |
-
float * data = (float *) dst->data;
|
| 617 |
-
|
| 618 |
-
const int64_t n_kv = n;
|
| 619 |
-
|
| 620 |
-
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
| 621 |
-
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
| 622 |
-
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
|
| 623 |
-
// Causal mask:
|
| 624 |
-
// xxx-------
|
| 625 |
-
// xxxx------
|
| 626 |
-
// xxxxx-----
|
| 627 |
-
// Non-causal mask:
|
| 628 |
-
// xxxxx-----
|
| 629 |
-
// xxxxx-----
|
| 630 |
-
// xxxxx-----
|
| 631 |
-
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
| 632 |
-
for (int h = 0; h < 1; ++h) {
|
| 633 |
-
for (int s = 0; s < n_seqs; ++s) {
|
| 634 |
-
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
| 635 |
-
|
| 636 |
-
for (int j = 0; j < n_seq_tokens; ++j) {
|
| 637 |
-
const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
|
| 638 |
-
|
| 639 |
-
for (int i = 0; i < n_kv; ++i) {
|
| 640 |
-
float f = 0.0f;
|
| 641 |
-
|
| 642 |
-
bool masked = false;
|
| 643 |
-
|
| 644 |
-
if (cells.is_empty(i)) {
|
| 645 |
-
masked = true;
|
| 646 |
-
} else {
|
| 647 |
-
const llama_pos p0 = cells.pos_get(i);
|
| 648 |
-
|
| 649 |
-
// mask the token if not the same sequence
|
| 650 |
-
masked = masked || (!cells.seq_has(i, seq_id));
|
| 651 |
-
|
| 652 |
-
// mask future tokens
|
| 653 |
-
masked = masked || (causal_attn && p0 > p1);
|
| 654 |
-
|
| 655 |
-
// apply SWA if any
|
| 656 |
-
masked = masked || (is_masked_swa(p0, p1));
|
| 657 |
-
|
| 658 |
-
if (!masked && hparams.use_alibi) {
|
| 659 |
-
f = -std::abs(p0 - p1);
|
| 660 |
-
}
|
| 661 |
-
}
|
| 662 |
-
|
| 663 |
-
if (masked) {
|
| 664 |
-
f = -INFINITY;
|
| 665 |
-
}
|
| 666 |
-
|
| 667 |
-
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
| 668 |
-
}
|
| 669 |
-
}
|
| 670 |
-
}
|
| 671 |
-
|
| 672 |
-
// mask padded tokens
|
| 673 |
-
if (data) {
|
| 674 |
-
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
| 675 |
-
for (int j = 0; j < n_kv; ++j) {
|
| 676 |
-
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
| 677 |
-
}
|
| 678 |
-
}
|
| 679 |
-
}
|
| 680 |
-
}
|
| 681 |
-
}
|
| 682 |
-
|
| 683 |
-
void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
|
| 684 |
-
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
| 685 |
-
|
| 686 |
-
int32_t * data = (int32_t *) dst->data;
|
| 687 |
-
|
| 688 |
-
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 689 |
-
data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
|
| 690 |
-
}
|
| 691 |
-
}
|
| 692 |
-
|
| 693 |
-
void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
| 694 |
-
const int64_t n_tokens = ubatch->n_tokens;
|
| 695 |
-
|
| 696 |
-
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
| 697 |
-
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
|
| 698 |
-
|
| 699 |
-
int32_t * data = (int32_t *) dst->data;
|
| 700 |
-
|
| 701 |
-
const int64_t n_kv = n;
|
| 702 |
-
|
| 703 |
-
for (int h = 0; h < 1; ++h) {
|
| 704 |
-
for (int j = 0; j < n_tokens; ++j) {
|
| 705 |
-
for (int i = 0; i < n_kv; ++i) {
|
| 706 |
-
// the position when the cells is empty is irrelevant - it will be masked out later in the attention
|
| 707 |
-
const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
|
| 708 |
-
|
| 709 |
-
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
|
| 710 |
-
}
|
| 711 |
-
}
|
| 712 |
-
}
|
| 713 |
-
}
|
| 714 |
-
|
| 715 |
-
size_t llama_kv_cache_unified::total_size() const {
|
| 716 |
-
size_t size = 0;
|
| 717 |
-
|
| 718 |
-
for (const auto & buf : bufs) {
|
| 719 |
-
size += ggml_backend_buffer_get_size(buf.get());
|
| 720 |
-
}
|
| 721 |
-
|
| 722 |
-
return size;
|
| 723 |
-
}
|
| 724 |
-
|
| 725 |
-
size_t llama_kv_cache_unified::size_k_bytes() const {
|
| 726 |
-
size_t size_k_bytes = 0;
|
| 727 |
-
|
| 728 |
-
for (const auto & layer : layers) {
|
| 729 |
-
size_k_bytes += ggml_nbytes(layer.k);
|
| 730 |
-
}
|
| 731 |
-
|
| 732 |
-
return size_k_bytes;
|
| 733 |
-
}
|
| 734 |
-
|
| 735 |
-
size_t llama_kv_cache_unified::size_v_bytes() const {
|
| 736 |
-
size_t size_v_bytes = 0;
|
| 737 |
-
|
| 738 |
-
for (const auto & layer : layers) {
|
| 739 |
-
size_v_bytes += ggml_nbytes(layer.v);
|
| 740 |
-
}
|
| 741 |
-
|
| 742 |
-
return size_v_bytes;
|
| 743 |
-
}
|
| 744 |
-
|
| 745 |
-
ggml_tensor * llama_kv_cache_unified::build_rope_shift(
|
| 746 |
-
const llama_cparams & cparams,
|
| 747 |
-
ggml_context * ctx,
|
| 748 |
-
ggml_tensor * cur,
|
| 749 |
-
ggml_tensor * shift,
|
| 750 |
-
ggml_tensor * factors,
|
| 751 |
-
float freq_base,
|
| 752 |
-
float freq_scale) const {
|
| 753 |
-
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
|
| 754 |
-
|
| 755 |
-
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
|
| 756 |
-
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
|
| 757 |
-
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
|
| 758 |
-
|
| 759 |
-
const auto & n_rot = hparams.n_rot;
|
| 760 |
-
const auto & rope_type = hparams.rope_type;
|
| 761 |
-
|
| 762 |
-
// See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
|
| 763 |
-
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
|
| 764 |
-
const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
|
| 765 |
-
|
| 766 |
-
ggml_tensor * tmp;
|
| 767 |
-
|
| 768 |
-
if (ggml_is_quantized(cur->type)) {
|
| 769 |
-
// dequantize to f32 -> RoPE -> quantize back
|
| 770 |
-
tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
|
| 771 |
-
|
| 772 |
-
tmp = ggml_rope_ext(ctx, tmp,
|
| 773 |
-
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
| 774 |
-
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
|
| 775 |
-
|
| 776 |
-
tmp = ggml_cpy(ctx, tmp, cur);
|
| 777 |
-
} else {
|
| 778 |
-
// we rotate only the first n_rot dimensions
|
| 779 |
-
tmp = ggml_rope_ext_inplace(ctx, cur,
|
| 780 |
-
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
| 781 |
-
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
|
| 782 |
-
}
|
| 783 |
-
|
| 784 |
-
return tmp;
|
| 785 |
-
}
|
| 786 |
-
|
| 787 |
-
class llm_graph_input_k_shift : public llm_graph_input_i {
|
| 788 |
-
public:
|
| 789 |
-
llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
|
| 790 |
-
virtual ~llm_graph_input_k_shift() = default;
|
| 791 |
-
|
| 792 |
-
void set_input(const llama_ubatch * ubatch) override;
|
| 793 |
-
|
| 794 |
-
ggml_tensor * k_shift; // I32 [kv_size]
|
| 795 |
-
|
| 796 |
-
const llama_kv_cache_unified * kv_self;
|
| 797 |
-
};
|
| 798 |
-
|
| 799 |
-
void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
|
| 800 |
-
GGML_UNUSED(ubatch);
|
| 801 |
-
|
| 802 |
-
if (k_shift) {
|
| 803 |
-
kv_self->set_input_k_shift(k_shift);
|
| 804 |
-
}
|
| 805 |
-
}
|
| 806 |
-
|
| 807 |
-
llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
|
| 808 |
-
const llama_cparams & cparams,
|
| 809 |
-
ggml_context * ctx,
|
| 810 |
-
ggml_cgraph * gf) const {
|
| 811 |
-
auto res = std::make_unique<llm_graph_result>();
|
| 812 |
-
|
| 813 |
-
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
| 814 |
-
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
| 815 |
-
|
| 816 |
-
//GGML_ASSERT(kv_self->size == n_ctx);
|
| 817 |
-
|
| 818 |
-
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
|
| 819 |
-
|
| 820 |
-
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
|
| 821 |
-
ggml_set_input(inp->k_shift);
|
| 822 |
-
|
| 823 |
-
for (const auto & layer : layers) {
|
| 824 |
-
const uint32_t il = layer.il;
|
| 825 |
-
|
| 826 |
-
const int64_t n_head_kv = hparams.n_head_kv(il);
|
| 827 |
-
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
| 828 |
-
|
| 829 |
-
const float freq_base_l = model.get_rope_freq_base (cparams, il);
|
| 830 |
-
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
|
| 831 |
-
|
| 832 |
-
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
|
| 833 |
-
|
| 834 |
-
ggml_tensor * k =
|
| 835 |
-
ggml_view_3d(ctx, layer.k,
|
| 836 |
-
n_embd_head_k, n_head_kv, cells.size(),
|
| 837 |
-
ggml_row_size(layer.k->type, n_embd_head_k),
|
| 838 |
-
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
| 839 |
-
0);
|
| 840 |
-
|
| 841 |
-
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
|
| 842 |
-
|
| 843 |
-
ggml_build_forward_expand(gf, cur);
|
| 844 |
-
}
|
| 845 |
-
|
| 846 |
-
res->add_input(std::move(inp));
|
| 847 |
-
|
| 848 |
-
return res;
|
| 849 |
-
}
|
| 850 |
-
|
| 851 |
-
llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
| 852 |
-
const llama_cparams & cparams,
|
| 853 |
-
ggml_context * ctx,
|
| 854 |
-
ggml_cgraph * gf) const {
|
| 855 |
-
auto res = std::make_unique<llm_graph_result>();
|
| 856 |
-
|
| 857 |
-
const auto & ids = defrag_info.ids;
|
| 858 |
-
|
| 859 |
-
#if 0
|
| 860 |
-
// CPU defrag
|
| 861 |
-
//
|
| 862 |
-
// TODO: optimizations are possible:
|
| 863 |
-
// - multiple threads
|
| 864 |
-
// - avoid copying to the host memory when already there
|
| 865 |
-
//
|
| 866 |
-
// likely not worth the effort, as we have ggml_graph based defrag
|
| 867 |
-
//
|
| 868 |
-
|
| 869 |
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
| 870 |
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
| 871 |
-
|
| 872 |
-
const uint32_t kv_size = size;
|
| 873 |
-
|
| 874 |
-
std::vector<uint8_t> buf_k;
|
| 875 |
-
std::vector<uint8_t> buf_v;
|
| 876 |
-
|
| 877 |
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 878 |
-
const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
|
| 879 |
-
const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
|
| 880 |
-
|
| 881 |
-
const size_t v_size_el = ggml_type_size(v_l[il]->type);
|
| 882 |
-
const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
|
| 883 |
-
|
| 884 |
-
buf_k.resize(k_size);
|
| 885 |
-
buf_v.resize(v_size);
|
| 886 |
-
|
| 887 |
-
ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
|
| 888 |
-
ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
|
| 889 |
-
|
| 890 |
-
// batch move [i, i+nm) to [id, id+nm)
|
| 891 |
-
// note: cells can move only to a lower index
|
| 892 |
-
for (uint32_t i = 0; i < n_kv; ++i) {
|
| 893 |
-
const uint32_t id = ids[i];
|
| 894 |
-
|
| 895 |
-
if (i == id || id == n_kv) {
|
| 896 |
-
continue;
|
| 897 |
-
}
|
| 898 |
-
|
| 899 |
-
uint32_t nm = 1;
|
| 900 |
-
|
| 901 |
-
while (i + nm < n_kv && ids[i + nm] == id + nm) {
|
| 902 |
-
nm++;
|
| 903 |
-
}
|
| 904 |
-
|
| 905 |
-
// move keys
|
| 906 |
-
{
|
| 907 |
-
const int64_t os = i*k_size_row;
|
| 908 |
-
const int64_t od = id*k_size_row;
|
| 909 |
-
|
| 910 |
-
memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
|
| 911 |
-
}
|
| 912 |
-
|
| 913 |
-
// move values (note: they are transposed)
|
| 914 |
-
{
|
| 915 |
-
const int64_t os = i;
|
| 916 |
-
const int64_t od = id;
|
| 917 |
-
|
| 918 |
-
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
| 919 |
-
memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
|
| 920 |
-
}
|
| 921 |
-
}
|
| 922 |
-
|
| 923 |
-
i += nm - 1;
|
| 924 |
-
}
|
| 925 |
-
|
| 926 |
-
ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
|
| 927 |
-
ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
|
| 928 |
-
}
|
| 929 |
-
#else
|
| 930 |
-
for (uint32_t i = 0; i < ids.size(); ++i) {
|
| 931 |
-
const uint32_t id = ids[i];
|
| 932 |
-
|
| 933 |
-
if (i == id || id == ids.size()) {
|
| 934 |
-
continue;
|
| 935 |
-
}
|
| 936 |
-
|
| 937 |
-
uint32_t nm = 1;
|
| 938 |
-
|
| 939 |
-
while (i + nm < ids.size() && ids[i + nm] == id + nm) {
|
| 940 |
-
nm++;
|
| 941 |
-
}
|
| 942 |
-
|
| 943 |
-
for (const auto & layer : layers) {
|
| 944 |
-
const uint32_t il = layer.il;
|
| 945 |
-
|
| 946 |
-
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
| 947 |
-
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
| 948 |
-
|
| 949 |
-
ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k,
|
| 950 |
-
n_embd_k_gqa, nm,
|
| 951 |
-
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
| 952 |
-
ggml_row_size(layer.k->type, n_embd_k_gqa*i));
|
| 953 |
-
|
| 954 |
-
ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k,
|
| 955 |
-
n_embd_k_gqa, nm,
|
| 956 |
-
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
| 957 |
-
ggml_row_size(layer.k->type, n_embd_k_gqa*id));
|
| 958 |
-
|
| 959 |
-
ggml_tensor * view_v_src;
|
| 960 |
-
ggml_tensor * view_v_dst;
|
| 961 |
-
|
| 962 |
-
if (cparams.flash_attn) {
|
| 963 |
-
// NOTE: the V cache is not transposed when using flash attention
|
| 964 |
-
view_v_src = ggml_view_2d(ctx, layer.v,
|
| 965 |
-
n_embd_v_gqa, nm,
|
| 966 |
-
ggml_row_size(layer.v->type, n_embd_v_gqa),
|
| 967 |
-
ggml_row_size(layer.v->type, n_embd_v_gqa*i));
|
| 968 |
-
|
| 969 |
-
view_v_dst = ggml_view_2d(ctx, layer.v,
|
| 970 |
-
n_embd_v_gqa, nm,
|
| 971 |
-
ggml_row_size(layer.v->type, n_embd_v_gqa),
|
| 972 |
-
ggml_row_size(layer.v->type, n_embd_v_gqa*id));
|
| 973 |
-
} else {
|
| 974 |
-
view_v_src = ggml_view_2d(ctx, layer.v,
|
| 975 |
-
nm, n_embd_v_gqa,
|
| 976 |
-
ggml_row_size(layer.v->type, cells.size()),
|
| 977 |
-
ggml_row_size(layer.v->type, i));
|
| 978 |
-
|
| 979 |
-
view_v_dst = ggml_view_2d(ctx, layer.v,
|
| 980 |
-
nm, n_embd_v_gqa,
|
| 981 |
-
ggml_row_size(layer.v->type, cells.size()),
|
| 982 |
-
ggml_row_size(layer.v->type, id));
|
| 983 |
-
}
|
| 984 |
-
|
| 985 |
-
ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
|
| 986 |
-
ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
|
| 987 |
-
}
|
| 988 |
-
|
| 989 |
-
i += nm - 1;
|
| 990 |
-
}
|
| 991 |
-
|
| 992 |
-
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
|
| 993 |
-
#endif
|
| 994 |
-
|
| 995 |
-
return res;
|
| 996 |
-
}
|
| 997 |
-
|
| 998 |
-
bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
| 999 |
-
const uint32_t n_layer = layers.size();
|
| 1000 |
-
|
| 1001 |
-
const uint32_t n_kv = cells.used_max_p1();
|
| 1002 |
-
const uint32_t n_used = cells.get_used();
|
| 1003 |
-
|
| 1004 |
-
assert(n_used <= n_kv);
|
| 1005 |
-
|
| 1006 |
-
//const int64_t t_start = ggml_time_us();
|
| 1007 |
-
|
| 1008 |
-
// number of cells moved
|
| 1009 |
-
uint32_t n_moves = 0;
|
| 1010 |
-
|
| 1011 |
-
// each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
|
| 1012 |
-
// - source view, destination view, copy operation
|
| 1013 |
-
// - x2 for keys and values
|
| 1014 |
-
//const uint32_t max_moves = max_nodes()/(6*n_layer);
|
| 1015 |
-
// TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
|
| 1016 |
-
const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
|
| 1017 |
-
|
| 1018 |
-
// determine which KV cells to move where
|
| 1019 |
-
//
|
| 1020 |
-
// cell i moves to ids[i]
|
| 1021 |
-
//
|
| 1022 |
-
// if ids[i] == i || ids[i] == n_kv, then cell i is not moved
|
| 1023 |
-
//
|
| 1024 |
-
auto & ids = defrag_info.ids;
|
| 1025 |
-
|
| 1026 |
-
ids.clear();
|
| 1027 |
-
ids.resize(n_kv, n_kv);
|
| 1028 |
-
|
| 1029 |
-
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
| 1030 |
-
if (!cells.is_empty(i0)) {
|
| 1031 |
-
ids[i0] = i0;
|
| 1032 |
-
|
| 1033 |
-
continue;
|
| 1034 |
-
}
|
| 1035 |
-
|
| 1036 |
-
// found a hole - fill it with data from the end of the cache
|
| 1037 |
-
|
| 1038 |
-
uint32_t nh = 1;
|
| 1039 |
-
|
| 1040 |
-
// determine the size of the hole
|
| 1041 |
-
while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
|
| 1042 |
-
nh++;
|
| 1043 |
-
}
|
| 1044 |
-
|
| 1045 |
-
uint32_t nf = 0;
|
| 1046 |
-
uint32_t is = n_kv - 1;
|
| 1047 |
-
|
| 1048 |
-
// starting from the end, find nh non-empty cells
|
| 1049 |
-
for (; is > i0; --is) {
|
| 1050 |
-
if (cells.is_empty(is) || ids[is] != n_kv) {
|
| 1051 |
-
continue;
|
| 1052 |
-
}
|
| 1053 |
-
|
| 1054 |
-
// non-empty cell which is not yet moved
|
| 1055 |
-
nf++;
|
| 1056 |
-
|
| 1057 |
-
if (nf == nh) {
|
| 1058 |
-
break;
|
| 1059 |
-
}
|
| 1060 |
-
}
|
| 1061 |
-
|
| 1062 |
-
// this can only happen if `n_used` is not accurate, which would be a bug
|
| 1063 |
-
GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
|
| 1064 |
-
|
| 1065 |
-
nf = 0;
|
| 1066 |
-
|
| 1067 |
-
uint32_t i1 = is;
|
| 1068 |
-
|
| 1069 |
-
// are we moving a continuous block of memory?
|
| 1070 |
-
bool cont = false;
|
| 1071 |
-
|
| 1072 |
-
// should we stop searching for the next move?
|
| 1073 |
-
bool stop = false;
|
| 1074 |
-
|
| 1075 |
-
// go back and move the nf cells to the hole
|
| 1076 |
-
for (; i1 < n_kv; ++i1) {
|
| 1077 |
-
if (cells.is_empty(i1) || ids[i1] != n_kv) {
|
| 1078 |
-
if (n_moves == max_moves) {
|
| 1079 |
-
stop = true;
|
| 1080 |
-
break;
|
| 1081 |
-
}
|
| 1082 |
-
|
| 1083 |
-
cont = false;
|
| 1084 |
-
continue;
|
| 1085 |
-
}
|
| 1086 |
-
|
| 1087 |
-
// this cell goes to (i0 + nf)
|
| 1088 |
-
ids[i1] = i0 + nf;
|
| 1089 |
-
|
| 1090 |
-
// move the cell meta data
|
| 1091 |
-
cells.mv(i1, i0 + nf);
|
| 1092 |
-
|
| 1093 |
-
head = n_used;
|
| 1094 |
-
|
| 1095 |
-
if (!cont) {
|
| 1096 |
-
n_moves++;
|
| 1097 |
-
cont = true;
|
| 1098 |
-
}
|
| 1099 |
-
|
| 1100 |
-
nf++;
|
| 1101 |
-
|
| 1102 |
-
if (nf == nh) {
|
| 1103 |
-
break;
|
| 1104 |
-
}
|
| 1105 |
-
}
|
| 1106 |
-
|
| 1107 |
-
if (stop || n_moves == max_moves) {
|
| 1108 |
-
break;
|
| 1109 |
-
}
|
| 1110 |
-
|
| 1111 |
-
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
|
| 1112 |
-
|
| 1113 |
-
i0 += nh - 1;
|
| 1114 |
-
}
|
| 1115 |
-
|
| 1116 |
-
if (n_moves == 0) {
|
| 1117 |
-
return false;
|
| 1118 |
-
}
|
| 1119 |
-
|
| 1120 |
-
LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
|
| 1121 |
-
|
| 1122 |
-
LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
|
| 1123 |
-
|
| 1124 |
-
return true;
|
| 1125 |
-
}
|
| 1126 |
-
|
| 1127 |
-
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
| 1128 |
-
assert(p0 >= 0 && p1 >= 0);
|
| 1129 |
-
|
| 1130 |
-
switch (swa_type) {
|
| 1131 |
-
case LLAMA_SWA_TYPE_NONE:
|
| 1132 |
-
{
|
| 1133 |
-
} break;
|
| 1134 |
-
case LLAMA_SWA_TYPE_STANDARD:
|
| 1135 |
-
{
|
| 1136 |
-
if (p1 - p0 >= (int32_t) n_swa) {
|
| 1137 |
-
return true;
|
| 1138 |
-
}
|
| 1139 |
-
} break;
|
| 1140 |
-
case LLAMA_SWA_TYPE_CHUNKED:
|
| 1141 |
-
{
|
| 1142 |
-
const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
|
| 1143 |
-
|
| 1144 |
-
if (p0 < pos_chunk_start) {
|
| 1145 |
-
return true;
|
| 1146 |
-
}
|
| 1147 |
-
} break;
|
| 1148 |
-
}
|
| 1149 |
-
|
| 1150 |
-
return false;
|
| 1151 |
-
}
|
| 1152 |
-
|
| 1153 |
-
void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
| 1154 |
-
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
| 1155 |
-
uint32_t cell_count = 0;
|
| 1156 |
-
|
| 1157 |
-
// Count the number of cells with the specified seq_id
|
| 1158 |
-
// Find all the ranges of cells with this seq id (or all, when -1)
|
| 1159 |
-
uint32_t cell_range_begin = cells.size();
|
| 1160 |
-
|
| 1161 |
-
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 1162 |
-
if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
|
| 1163 |
-
++cell_count;
|
| 1164 |
-
if (cell_range_begin == cells.size()) {
|
| 1165 |
-
cell_range_begin = i;
|
| 1166 |
-
}
|
| 1167 |
-
} else {
|
| 1168 |
-
if (cell_range_begin != cells.size()) {
|
| 1169 |
-
cell_ranges.emplace_back(cell_range_begin, i);
|
| 1170 |
-
cell_range_begin = cells.size();
|
| 1171 |
-
}
|
| 1172 |
-
}
|
| 1173 |
-
}
|
| 1174 |
-
|
| 1175 |
-
if (cell_range_begin != cells.size()) {
|
| 1176 |
-
cell_ranges.emplace_back(cell_range_begin, cells.size());
|
| 1177 |
-
}
|
| 1178 |
-
|
| 1179 |
-
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
| 1180 |
-
uint32_t cell_count_check = 0;
|
| 1181 |
-
for (const auto & range : cell_ranges) {
|
| 1182 |
-
cell_count_check += range.second - range.first;
|
| 1183 |
-
}
|
| 1184 |
-
GGML_ASSERT(cell_count == cell_count_check);
|
| 1185 |
-
|
| 1186 |
-
io.write(&cell_count, sizeof(cell_count));
|
| 1187 |
-
|
| 1188 |
-
state_write_meta(io, cell_ranges, seq_id);
|
| 1189 |
-
state_write_data(io, cell_ranges);
|
| 1190 |
-
}
|
| 1191 |
-
|
| 1192 |
-
void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
| 1193 |
-
uint32_t cell_count;
|
| 1194 |
-
io.read_to(&cell_count, sizeof(cell_count));
|
| 1195 |
-
|
| 1196 |
-
bool res = true;
|
| 1197 |
-
res = res && state_read_meta(io, cell_count, seq_id);
|
| 1198 |
-
res = res && state_read_data(io, cell_count);
|
| 1199 |
-
|
| 1200 |
-
if (!res) {
|
| 1201 |
-
if (seq_id == -1) {
|
| 1202 |
-
clear();
|
| 1203 |
-
} else {
|
| 1204 |
-
seq_rm(seq_id, -1, -1);
|
| 1205 |
-
}
|
| 1206 |
-
throw std::runtime_error("failed to restore kv cache");
|
| 1207 |
-
}
|
| 1208 |
-
}
|
| 1209 |
-
|
| 1210 |
-
void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
|
| 1211 |
-
for (const auto & range : cell_ranges) {
|
| 1212 |
-
for (uint32_t i = range.first; i < range.second; ++i) {
|
| 1213 |
-
std::vector<llama_seq_id> seq_ids;
|
| 1214 |
-
|
| 1215 |
-
for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
|
| 1216 |
-
if (cur == seq_id || seq_id == -1) {
|
| 1217 |
-
if (cells.seq_has(i, cur)) {
|
| 1218 |
-
seq_ids.push_back(cur);
|
| 1219 |
-
}
|
| 1220 |
-
}
|
| 1221 |
-
}
|
| 1222 |
-
|
| 1223 |
-
const llama_pos pos = cells.pos_get(i);
|
| 1224 |
-
const uint32_t n_seq_id = seq_ids.size();
|
| 1225 |
-
|
| 1226 |
-
io.write(&pos, sizeof(pos));
|
| 1227 |
-
io.write(&n_seq_id, sizeof(n_seq_id));
|
| 1228 |
-
|
| 1229 |
-
for (const auto & seq_id : seq_ids) {
|
| 1230 |
-
io.write(&seq_id, sizeof(seq_id));
|
| 1231 |
-
}
|
| 1232 |
-
}
|
| 1233 |
-
}
|
| 1234 |
-
}
|
| 1235 |
-
|
| 1236 |
-
void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
|
| 1237 |
-
const uint32_t v_trans = this->v_trans ? 1 : 0;
|
| 1238 |
-
const uint32_t n_layer = layers.size();
|
| 1239 |
-
|
| 1240 |
-
io.write(&v_trans, sizeof(v_trans));
|
| 1241 |
-
io.write(&n_layer, sizeof(n_layer));
|
| 1242 |
-
|
| 1243 |
-
std::vector<uint8_t> tmp_buf;
|
| 1244 |
-
|
| 1245 |
-
// Iterate and write all the keys first, each row is a cell
|
| 1246 |
-
// Get whole range at a time
|
| 1247 |
-
for (const auto & layer : layers) {
|
| 1248 |
-
const uint32_t il = layer.il;
|
| 1249 |
-
|
| 1250 |
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 1251 |
-
|
| 1252 |
-
// Write key type
|
| 1253 |
-
const int32_t k_type_i = (int32_t)layer.k->type;
|
| 1254 |
-
io.write(&k_type_i, sizeof(k_type_i));
|
| 1255 |
-
|
| 1256 |
-
// Write row size of key
|
| 1257 |
-
const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
|
| 1258 |
-
io.write(&k_size_row, sizeof(k_size_row));
|
| 1259 |
-
|
| 1260 |
-
// Read each range of cells of k_size length each into tmp_buf and write out
|
| 1261 |
-
for (const auto & range : cell_ranges) {
|
| 1262 |
-
const size_t range_size = range.second - range.first;
|
| 1263 |
-
const size_t buf_size = range_size * k_size_row;
|
| 1264 |
-
io.write_tensor(layer.k, range.first * k_size_row, buf_size);
|
| 1265 |
-
}
|
| 1266 |
-
}
|
| 1267 |
-
|
| 1268 |
-
if (!v_trans) {
|
| 1269 |
-
for (const auto & layer : layers) {
|
| 1270 |
-
const uint32_t il = layer.il;
|
| 1271 |
-
|
| 1272 |
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1273 |
-
|
| 1274 |
-
// Write value type
|
| 1275 |
-
const int32_t v_type_i = (int32_t)layer.v->type;
|
| 1276 |
-
io.write(&v_type_i, sizeof(v_type_i));
|
| 1277 |
-
|
| 1278 |
-
// Write row size of value
|
| 1279 |
-
const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
|
| 1280 |
-
io.write(&v_size_row, sizeof(v_size_row));
|
| 1281 |
-
|
| 1282 |
-
// Read each range of cells of v_size length each into tmp_buf and write out
|
| 1283 |
-
for (const auto & range : cell_ranges) {
|
| 1284 |
-
const size_t range_size = range.second - range.first;
|
| 1285 |
-
const size_t buf_size = range_size * v_size_row;
|
| 1286 |
-
io.write_tensor(layer.v, range.first * v_size_row, buf_size);
|
| 1287 |
-
}
|
| 1288 |
-
}
|
| 1289 |
-
} else {
|
| 1290 |
-
// When v is transposed, we also need the element size and get the element ranges from each row
|
| 1291 |
-
const uint32_t kv_size = cells.size();
|
| 1292 |
-
|
| 1293 |
-
for (const auto & layer : layers) {
|
| 1294 |
-
const uint32_t il = layer.il;
|
| 1295 |
-
|
| 1296 |
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1297 |
-
|
| 1298 |
-
// Write value type
|
| 1299 |
-
const int32_t v_type_i = (int32_t)layer.v->type;
|
| 1300 |
-
io.write(&v_type_i, sizeof(v_type_i));
|
| 1301 |
-
|
| 1302 |
-
// Write element size
|
| 1303 |
-
const uint32_t v_size_el = ggml_type_size(layer.v->type);
|
| 1304 |
-
io.write(&v_size_el, sizeof(v_size_el));
|
| 1305 |
-
|
| 1306 |
-
// Write GQA embedding size
|
| 1307 |
-
io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
|
| 1308 |
-
|
| 1309 |
-
// For each row, we get the element values of each cell
|
| 1310 |
-
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
| 1311 |
-
// Read each range of cells of v_size_el length each into tmp_buf and write out
|
| 1312 |
-
for (const auto & range : cell_ranges) {
|
| 1313 |
-
const size_t range_size = range.second - range.first;
|
| 1314 |
-
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
|
| 1315 |
-
const size_t buf_size = range_size * v_size_el;
|
| 1316 |
-
io.write_tensor(layer.v, src_offset, buf_size);
|
| 1317 |
-
}
|
| 1318 |
-
}
|
| 1319 |
-
}
|
| 1320 |
-
}
|
| 1321 |
-
}
|
| 1322 |
-
|
| 1323 |
-
bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
|
| 1324 |
-
if (dest_seq_id != -1) {
|
| 1325 |
-
// single sequence
|
| 1326 |
-
|
| 1327 |
-
seq_rm(dest_seq_id, -1, -1);
|
| 1328 |
-
|
| 1329 |
-
llama_sbatch sbatch;
|
| 1330 |
-
llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
| 1331 |
-
|
| 1332 |
-
batch.n_tokens = cell_count;
|
| 1333 |
-
|
| 1334 |
-
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 1335 |
-
llama_pos pos;
|
| 1336 |
-
uint32_t n_seq_id;
|
| 1337 |
-
|
| 1338 |
-
io.read_to(&pos, sizeof(pos));
|
| 1339 |
-
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
| 1340 |
-
|
| 1341 |
-
if (n_seq_id != 1) {
|
| 1342 |
-
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
| 1343 |
-
return false;
|
| 1344 |
-
}
|
| 1345 |
-
|
| 1346 |
-
// read the sequence id, but directly discard it - we will use dest_seq_id instead
|
| 1347 |
-
{
|
| 1348 |
-
llama_seq_id seq_id;
|
| 1349 |
-
io.read_to(&seq_id, sizeof(seq_id));
|
| 1350 |
-
}
|
| 1351 |
-
|
| 1352 |
-
batch.pos[i] = pos;
|
| 1353 |
-
batch.n_seq_id[i] = n_seq_id;
|
| 1354 |
-
batch.seq_id[i] = &dest_seq_id;
|
| 1355 |
-
}
|
| 1356 |
-
|
| 1357 |
-
if (!find_slot(batch)) {
|
| 1358 |
-
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
| 1359 |
-
return false;
|
| 1360 |
-
}
|
| 1361 |
-
|
| 1362 |
-
commit();
|
| 1363 |
-
|
| 1364 |
-
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
| 1365 |
-
// Assume that this is one contiguous block of cells
|
| 1366 |
-
GGML_ASSERT(head + cell_count <= cells.size());
|
| 1367 |
-
GGML_ASSERT(cells.pos_get(head) == batch.pos[0]);
|
| 1368 |
-
GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]);
|
| 1369 |
-
GGML_ASSERT(cells.seq_has(head, dest_seq_id));
|
| 1370 |
-
GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id));
|
| 1371 |
-
} else {
|
| 1372 |
-
// whole KV cache restore
|
| 1373 |
-
|
| 1374 |
-
if (cell_count > cells.size()) {
|
| 1375 |
-
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
|
| 1376 |
-
return false;
|
| 1377 |
-
}
|
| 1378 |
-
|
| 1379 |
-
clear();
|
| 1380 |
-
|
| 1381 |
-
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 1382 |
-
llama_pos pos;
|
| 1383 |
-
uint32_t n_seq_id;
|
| 1384 |
-
|
| 1385 |
-
io.read_to(&pos, sizeof(pos));
|
| 1386 |
-
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
| 1387 |
-
|
| 1388 |
-
cells.pos_set(i, pos);
|
| 1389 |
-
|
| 1390 |
-
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
| 1391 |
-
llama_seq_id seq_id;
|
| 1392 |
-
io.read_to(&seq_id, sizeof(seq_id));
|
| 1393 |
-
|
| 1394 |
-
if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
|
| 1395 |
-
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
|
| 1396 |
-
return false;
|
| 1397 |
-
}
|
| 1398 |
-
|
| 1399 |
-
cells.seq_add(i, seq_id);
|
| 1400 |
-
}
|
| 1401 |
-
}
|
| 1402 |
-
|
| 1403 |
-
head = 0;
|
| 1404 |
-
}
|
| 1405 |
-
|
| 1406 |
-
return true;
|
| 1407 |
-
}
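A minimal sketch of the per-cell metadata record that the write and read paths above agree on: a position, an n_seq_id count, then that many sequence ids. The field order mirrors the reads above; the byte-buffer stand-in for the I/O interface and the concrete values are assumptions for illustration only.

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// append raw bytes to a byte buffer (stand-in for a state-write interface)
static void put(std::vector<uint8_t> & buf, const void * p, size_t n) {
    const uint8_t * b = (const uint8_t *) p;
    buf.insert(buf.end(), b, b + n);
}

int main() {
    std::vector<uint8_t> buf;

    // write one cell: pos = 42, owned by exactly one sequence (id 3)
    const int32_t  pos      = 42;
    const uint32_t n_seq_id = 1;
    const int32_t  seq_id   = 3;
    put(buf, &pos,      sizeof(pos));
    put(buf, &n_seq_id, sizeof(n_seq_id));
    put(buf, &seq_id,   sizeof(seq_id));

    // read it back the way the single-sequence restore path does:
    // the stored id is consumed and then replaced by the destination sequence
    size_t   off = 0;
    int32_t  r_pos;      std::memcpy(&r_pos,      buf.data() + off, sizeof(r_pos));      off += sizeof(r_pos);
    uint32_t r_n_seq_id; std::memcpy(&r_n_seq_id, buf.data() + off, sizeof(r_n_seq_id)); off += sizeof(r_n_seq_id);
    int32_t  r_seq_id;   std::memcpy(&r_seq_id,   buf.data() + off, sizeof(r_seq_id));   off += sizeof(r_seq_id);

    const int32_t dest_seq_id = 7; // chosen by the caller
    std::printf("pos=%d n_seq_id=%u stored_seq=%d -> restored into seq %d\n", r_pos, r_n_seq_id, r_seq_id, dest_seq_id);
    return 0;
}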
|
| 1408 |
-
|
| 1409 |
-
bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
|
| 1410 |
-
uint32_t v_trans;
|
| 1411 |
-
uint32_t n_layer;
|
| 1412 |
-
|
| 1413 |
-
io.read_to(&v_trans, sizeof(v_trans));
|
| 1414 |
-
io.read_to(&n_layer, sizeof(n_layer));
|
| 1415 |
-
|
| 1416 |
-
if (n_layer != layers.size()) {
|
| 1417 |
-
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
|
| 1418 |
-
return false;
|
| 1419 |
-
}
|
| 1420 |
-
if (cell_count > cells.size()) {
|
| 1421 |
-
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
|
| 1422 |
-
return false;
|
| 1423 |
-
}
|
| 1424 |
-
if (this->v_trans != (bool) v_trans) {
|
| 1425 |
-
LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
|
| 1426 |
-
return false;
|
| 1427 |
-
}
|
| 1428 |
-
|
| 1429 |
-
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
|
| 1430 |
-
for (const auto & layer : layers) {
|
| 1431 |
-
const uint32_t il = layer.il;
|
| 1432 |
-
|
| 1433 |
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 1434 |
-
|
| 1435 |
-
// Read type of key
|
| 1436 |
-
int32_t k_type_i_ref;
|
| 1437 |
-
io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
|
| 1438 |
-
const int32_t k_type_i = (int32_t) layer.k->type;
|
| 1439 |
-
if (k_type_i != k_type_i_ref) {
|
| 1440 |
-
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
|
| 1441 |
-
return false;
|
| 1442 |
-
}
|
| 1443 |
-
|
| 1444 |
-
// Read row size of key
|
| 1445 |
-
uint64_t k_size_row_ref;
|
| 1446 |
-
io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
|
| 1447 |
-
const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
|
| 1448 |
-
if (k_size_row != k_size_row_ref) {
|
| 1449 |
-
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
|
| 1450 |
-
return false;
|
| 1451 |
-
}
|
| 1452 |
-
|
| 1453 |
-
if (cell_count) {
|
| 1454 |
-
// Read and set the keys for the whole cell range
|
| 1455 |
-
ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
|
| 1456 |
-
}
|
| 1457 |
-
}
|
| 1458 |
-
|
| 1459 |
-
if (!this->v_trans) {
|
| 1460 |
-
for (const auto & layer : layers) {
|
| 1461 |
-
const uint32_t il = layer.il;
|
| 1462 |
-
|
| 1463 |
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1464 |
-
|
| 1465 |
-
// Read type of value
|
| 1466 |
-
int32_t v_type_i_ref;
|
| 1467 |
-
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
| 1468 |
-
const int32_t v_type_i = (int32_t)layer.v->type;
|
| 1469 |
-
if (v_type_i != v_type_i_ref) {
|
| 1470 |
-
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
| 1471 |
-
return false;
|
| 1472 |
-
}
|
| 1473 |
-
|
| 1474 |
-
// Read row size of value
|
| 1475 |
-
uint64_t v_size_row_ref;
|
| 1476 |
-
io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
|
| 1477 |
-
const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
|
| 1478 |
-
if (v_size_row != v_size_row_ref) {
|
| 1479 |
-
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
|
| 1480 |
-
return false;
|
| 1481 |
-
}
|
| 1482 |
-
|
| 1483 |
-
if (cell_count) {
|
| 1484 |
-
// Read and set the values for the whole cell range
|
| 1485 |
-
ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
|
| 1486 |
-
}
|
| 1487 |
-
}
|
| 1488 |
-
} else {
|
| 1489 |
-
// For each layer, read the values for each cell (transposed)
|
| 1490 |
-
for (const auto & layer : layers) {
|
| 1491 |
-
const uint32_t il = layer.il;
|
| 1492 |
-
|
| 1493 |
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1494 |
-
|
| 1495 |
-
// Read type of value
|
| 1496 |
-
int32_t v_type_i_ref;
|
| 1497 |
-
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
| 1498 |
-
const int32_t v_type_i = (int32_t)layer.v->type;
|
| 1499 |
-
if (v_type_i != v_type_i_ref) {
|
| 1500 |
-
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
| 1501 |
-
return false;
|
| 1502 |
-
}
|
| 1503 |
-
|
| 1504 |
-
// Read element size of value
|
| 1505 |
-
uint32_t v_size_el_ref;
|
| 1506 |
-
io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
|
| 1507 |
-
const size_t v_size_el = ggml_type_size(layer.v->type);
|
| 1508 |
-
if (v_size_el != v_size_el_ref) {
|
| 1509 |
-
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
|
| 1510 |
-
return false;
|
| 1511 |
-
}
|
| 1512 |
-
|
| 1513 |
-
// Read GQA embedding size
|
| 1514 |
-
uint32_t n_embd_v_gqa_ref;
|
| 1515 |
-
io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
|
| 1516 |
-
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
|
| 1517 |
-
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
|
| 1518 |
-
return false;
|
| 1519 |
-
}
|
| 1520 |
-
|
| 1521 |
-
if (cell_count) {
|
| 1522 |
-
// For each row in the transposed matrix, read the values for the whole cell range
|
| 1523 |
-
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
| 1524 |
-
const size_t dst_offset = (head + j * cells.size()) * v_size_el;
|
| 1525 |
-
ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
| 1526 |
-
}
|
| 1527 |
-
}
|
| 1528 |
-
}
|
| 1529 |
-
}
|
| 1530 |
-
|
| 1531 |
-
return true;
|
| 1532 |
-
}
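Every restore loop above follows the same defensive pattern: read the reference value that was serialized, recompute the expected value from the live tensors, and abort the whole restore on any mismatch. A tiny standalone sketch of that pattern; the names and numbers are hypothetical.

#include <cstdint>
#include <cstdio>

// compare a serialized reference value against the locally expected one;
// on mismatch, log and signal the caller to abort the restore
static bool check_u64(const char * what, uint64_t ref, uint64_t expected) {
    if (ref != expected) {
        std::fprintf(stderr, "mismatched %s (%llu != %llu)\n",
                     what, (unsigned long long) expected, (unsigned long long) ref);
        return false;
    }
    return true;
}

int main() {
    const uint64_t k_row_size_ref = 4096; // as if read from the state stream
    const uint64_t k_row_size     = 4096; // as if computed from the live K tensor
    return check_u64("key row size", k_row_size_ref, k_row_size) ? 0 : 1;
}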
|
| 1533 |
-
|
| 1534 |
-
//
|
| 1535 |
-
// llama_kv_cache_unified_iswa
|
| 1536 |
-
//
|
| 1537 |
-
|
| 1538 |
-
llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
| 1539 |
-
const llama_model & model,
|
| 1540 |
-
ggml_type type_k,
|
| 1541 |
-
ggml_type type_v,
|
| 1542 |
-
bool v_trans,
|
| 1543 |
-
bool offload,
|
| 1544 |
-
bool swa_full,
|
| 1545 |
-
uint32_t kv_size,
|
| 1546 |
-
uint32_t n_seq_max,
|
| 1547 |
-
uint32_t n_batch,
|
| 1548 |
-
uint32_t n_pad) : hparams(model.hparams) {
|
| 1549 |
-
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
|
| 1550 |
-
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
|
| 1551 |
-
|
| 1552 |
-
const uint32_t size_base = kv_size;
|
| 1553 |
-
|
| 1554 |
-
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));
|
| 1555 |
-
|
| 1556 |
-
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning
|
| 1557 |
-
if (swa_full) {
|
| 1558 |
-
LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
|
| 1559 |
-
__func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
|
| 1560 |
-
|
| 1561 |
-
size_swa = size_base;
|
| 1562 |
-
do_prune = false;
|
| 1563 |
-
}
|
| 1564 |
-
|
| 1565 |
-
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
|
| 1566 |
-
|
| 1567 |
-
kv_base = std::make_unique<llama_kv_cache_unified>(
|
| 1568 |
-
model, std::move(filter_base), type_k, type_v,
|
| 1569 |
-
v_trans, offload, size_base, n_seq_max, n_pad,
|
| 1570 |
-
0, LLAMA_SWA_TYPE_NONE);
|
| 1571 |
-
|
| 1572 |
-
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
|
| 1573 |
-
|
| 1574 |
-
kv_swa = std::make_unique<llama_kv_cache_unified>(
|
| 1575 |
-
model, std::move(filter_swa), type_k, type_v,
|
| 1576 |
-
v_trans, offload, size_swa, n_seq_max, n_pad,
|
| 1577 |
-
hparams.n_swa, hparams.swa_type);
|
| 1578 |
-
}
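The constructor above sizes the SWA cache from the window length rather than the full context: pad n_swa*n_seq_max + n_batch up to a multiple of n_pad, and never exceed the base cache size, unless swa_full forces the two caches to be the same size and disables pruning. A sketch of that computation with a stand-in for GGML_PAD and made-up numbers:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// round x up to a multiple of n (stand-in for the GGML_PAD macro)
static uint32_t pad_to(uint32_t x, uint32_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    const uint32_t kv_size   = 8192; // base (non-SWA) cache size, hypothetical
    const uint32_t n_swa     = 1024; // sliding window length
    const uint32_t n_seq_max = 4;
    const uint32_t n_batch   = 512;
    const uint32_t n_pad     = 256;
    const bool     swa_full  = false;

    uint32_t size_swa = std::min(kv_size, pad_to(n_swa*n_seq_max + n_batch, n_pad));
    if (swa_full) {
        size_swa = kv_size; // full-size SWA cache: no pruning, same size as the base cache
    }

    std::printf("base = %u cells, swa = %u cells\n", kv_size, size_swa);
    return 0;
}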
|
| 1579 |
-
|
| 1580 |
-
void llama_kv_cache_unified_iswa::clear() {
|
| 1581 |
-
kv_base->clear();
|
| 1582 |
-
kv_swa ->clear();
|
| 1583 |
-
}
|
| 1584 |
-
|
| 1585 |
-
bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
| 1586 |
-
bool res = true;
|
| 1587 |
-
|
| 1588 |
-
res = res & kv_base->seq_rm(seq_id, p0, p1);
|
| 1589 |
-
res = res & kv_swa ->seq_rm(seq_id, p0, p1);
|
| 1590 |
-
|
| 1591 |
-
return res;
|
| 1592 |
-
}
|
| 1593 |
-
|
| 1594 |
-
void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
| 1595 |
-
kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
| 1596 |
-
kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
| 1597 |
-
}
|
| 1598 |
-
|
| 1599 |
-
void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
|
| 1600 |
-
kv_base->seq_keep(seq_id);
|
| 1601 |
-
kv_swa ->seq_keep(seq_id);
|
| 1602 |
-
}
|
| 1603 |
-
|
| 1604 |
-
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
| 1605 |
-
kv_base->seq_add(seq_id, p0, p1, shift);
|
| 1606 |
-
kv_swa ->seq_add(seq_id, p0, p1, shift);
|
| 1607 |
-
}
|
| 1608 |
-
|
| 1609 |
-
void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
| 1610 |
-
kv_base->seq_div(seq_id, p0, p1, d);
|
| 1611 |
-
kv_swa ->seq_div(seq_id, p0, p1, d);
|
| 1612 |
-
}
|
| 1613 |
-
|
| 1614 |
-
llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
|
| 1615 |
-
// the base cache is a superset of the SWA cache, so we can just check the SWA cache
|
| 1616 |
-
return kv_swa->seq_pos_min(seq_id);
|
| 1617 |
-
}
|
| 1618 |
-
|
| 1619 |
-
llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
|
| 1620 |
-
return kv_swa->seq_pos_max(seq_id);
|
| 1621 |
-
}
|
| 1622 |
-
|
| 1623 |
-
void llama_kv_cache_unified_iswa::restore() {
|
| 1624 |
-
kv_base->restore();
|
| 1625 |
-
kv_swa ->restore();
|
| 1626 |
-
}
|
| 1627 |
-
|
| 1628 |
-
void llama_kv_cache_unified_iswa::commit() {
|
| 1629 |
-
kv_base->commit();
|
| 1630 |
-
kv_swa ->commit();
|
| 1631 |
-
|
| 1632 |
-
// slide the attention window, forgetting/pruning old tokens that are outside the window
|
| 1633 |
-
if (do_prune) {
|
| 1634 |
-
for (const auto & [seq_id, entry] : pending.pos) {
|
| 1635 |
-
kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax);
|
| 1636 |
-
}
|
| 1637 |
-
|
| 1638 |
-
}
|
| 1639 |
-
|
| 1640 |
-
pending.clear();
|
| 1641 |
-
}
|
| 1642 |
-
|
| 1643 |
-
bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
|
| 1644 |
-
bool res = true;
|
| 1645 |
-
|
| 1646 |
-
res = res & kv_base->update(lctx);
|
| 1647 |
-
res = res & kv_swa ->update(lctx);
|
| 1648 |
-
|
| 1649 |
-
return res;
|
| 1650 |
-
}
|
| 1651 |
-
|
| 1652 |
-
void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
|
| 1653 |
-
kv_base->defrag_sched(thold);
|
| 1654 |
-
kv_swa ->defrag_sched(thold);
|
| 1655 |
-
}
|
| 1656 |
-
|
| 1657 |
-
void llama_kv_cache_unified_iswa::set_full() {
|
| 1658 |
-
kv_base->set_full();
|
| 1659 |
-
kv_swa ->set_full();
|
| 1660 |
-
}
|
| 1661 |
-
|
| 1662 |
-
llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
|
| 1663 |
-
pending.clear();
|
| 1664 |
-
|
| 1665 |
-
if (do_prune) {
|
| 1666 |
-
for (int i = 0; i < batch.n_tokens; ++i) {
|
| 1667 |
-
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
|
| 1668 |
-
const llama_seq_id seq_id = batch.seq_id[i][s];
|
| 1669 |
-
const llama_pos pos = batch.pos[i];
|
| 1670 |
-
|
| 1671 |
-
if (pending.pos.find(seq_id) == pending.pos.end()) {
|
| 1672 |
-
pending.pos[seq_id].pmin = pos;
|
| 1673 |
-
pending.pos[seq_id].pmax = pos;
|
| 1674 |
-
} else {
|
| 1675 |
-
pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos);
|
| 1676 |
-
pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos);
|
| 1677 |
-
}
|
| 1678 |
-
}
|
| 1679 |
-
}
|
| 1680 |
-
}
|
| 1681 |
-
|
| 1682 |
-
return llama_sbatch(batch, hparams.n_embd, true, logits_all);
|
| 1683 |
-
}
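sbatch_init above records, for every sequence touched by the batch, the minimum and maximum position seen; commit() later uses that range to prune tokens that have slid out of the attention window. A self-contained sketch of the same bookkeeping over a toy batch; the parallel-vector batch layout here is a simplification for illustration only.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <map>
#include <vector>

struct pos_range { int32_t pmin, pmax; };

int main() {
    // toy batch: token i belongs to seq_ids[i] and sits at position pos[i]
    const std::vector<int32_t> seq_ids = { 0, 0, 0, 1, 1 };
    const std::vector<int32_t> pos     = { 7, 8, 9, 3, 4 };

    std::map<int32_t, pos_range> pending;
    for (size_t i = 0; i < seq_ids.size(); ++i) {
        auto it = pending.find(seq_ids[i]);
        if (it == pending.end()) {
            pending[seq_ids[i]] = { pos[i], pos[i] };
        } else {
            it->second.pmin = std::min(it->second.pmin, pos[i]);
            it->second.pmax = std::max(it->second.pmax, pos[i]);
        }
    }

    for (const auto & [seq_id, r] : pending) {
        // this is where the SWA cache would prune positions that fell out of the window
        std::printf("seq %d: pmin = %d, pmax = %d\n", seq_id, r.pmin, r.pmax);
    }
    return 0;
}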
|
| 1684 |
-
|
| 1685 |
-
llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
|
| 1686 |
-
GGML_UNUSED(embd_pooled);
|
| 1687 |
-
return sbatch.split_simple(n_ubatch);
|
| 1688 |
-
}
|
| 1689 |
-
|
| 1690 |
-
bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
|
| 1691 |
-
bool res = true;
|
| 1692 |
-
|
| 1693 |
-
res = res & kv_base->find_slot(batch);
|
| 1694 |
-
res = res & kv_swa ->find_slot(batch);
|
| 1695 |
-
|
| 1696 |
-
return res;
|
| 1697 |
-
}
|
| 1698 |
-
|
| 1699 |
-
bool llama_kv_cache_unified_iswa::get_can_shift() const {
|
| 1700 |
-
return kv_base->get_size() == kv_swa->get_size();
|
| 1701 |
-
}
|
| 1702 |
-
|
| 1703 |
-
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
| 1704 |
-
kv_base->state_write(io, seq_id);
|
| 1705 |
-
kv_swa ->state_write(io, seq_id);
|
| 1706 |
-
}
|
| 1707 |
-
|
| 1708 |
-
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
| 1709 |
-
kv_base->state_read(io, seq_id);
|
| 1710 |
-
kv_swa ->state_read(io, seq_id);
|
| 1711 |
-
}
|
| 1712 |
-
|
| 1713 |
-
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_base() const {
|
| 1714 |
-
return kv_base.get();
|
| 1715 |
-
}
|
| 1716 |
-
|
| 1717 |
-
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const {
|
| 1718 |
-
return kv_swa.get();
|
| 1719 |
-
}
|
| 1720 |
-
|
| 1721 |
-
//
|
| 1722 |
-
// llama_kv_cache_recurrent
|
| 1723 |
-
//
|
| 1724 |
-
|
| 1725 |
-
llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
| 1726 |
-
const llama_model & model,
|
| 1727 |
-
ggml_type type_k,
|
| 1728 |
-
ggml_type type_v,
|
| 1729 |
-
bool offload,
|
| 1730 |
-
uint32_t kv_size,
|
| 1731 |
-
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
|
| 1732 |
-
const int32_t n_layer = hparams.n_layer;
|
| 1733 |
-
|
| 1734 |
-
LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
|
| 1735 |
-
__func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
|
| 1736 |
-
|
| 1737 |
-
head = 0;
|
| 1738 |
-
size = kv_size;
|
| 1739 |
-
used = 0;
|
| 1740 |
-
|
| 1741 |
-
cells.clear();
|
| 1742 |
-
cells.resize(kv_size);
|
| 1743 |
-
|
| 1744 |
-
// create a context for each buffer type
|
| 1745 |
-
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
| 1746 |
-
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
| 1747 |
-
auto it = ctx_map.find(buft);
|
| 1748 |
-
if (it == ctx_map.end()) {
|
| 1749 |
-
ggml_init_params params = {
|
| 1750 |
-
/*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
|
| 1751 |
-
/*.mem_buffer =*/ NULL,
|
| 1752 |
-
/*.no_alloc =*/ true,
|
| 1753 |
-
};
|
| 1754 |
-
|
| 1755 |
-
ggml_context * ctx = ggml_init(params);
|
| 1756 |
-
if (!ctx) {
|
| 1757 |
-
return nullptr;
|
| 1758 |
-
}
|
| 1759 |
-
|
| 1760 |
-
ctx_map[buft] = ctx;
|
| 1761 |
-
ctxs.emplace_back(ctx);
|
| 1762 |
-
|
| 1763 |
-
return ctx;
|
| 1764 |
-
}
|
| 1765 |
-
|
| 1766 |
-
return it->second;
|
| 1767 |
-
};
|
| 1768 |
-
|
| 1769 |
-
k_l.reserve(n_layer);
|
| 1770 |
-
v_l.reserve(n_layer);
|
| 1771 |
-
|
| 1772 |
-
for (int i = 0; i < n_layer; i++) {
|
| 1773 |
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
|
| 1774 |
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
|
| 1775 |
-
|
| 1776 |
-
const char * dev_name = "CPU";
|
| 1777 |
-
|
| 1778 |
-
ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
|
| 1779 |
-
|
| 1780 |
-
if (offload) {
|
| 1781 |
-
auto * dev = model.dev_layer(i);
|
| 1782 |
-
buft = ggml_backend_dev_buffer_type(dev);
|
| 1783 |
-
|
| 1784 |
-
dev_name = ggml_backend_dev_name(dev);
|
| 1785 |
-
}
|
| 1786 |
-
|
| 1787 |
-
LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name);
|
| 1788 |
-
|
| 1789 |
-
ggml_context * ctx = ctx_for_buft(buft);
|
| 1790 |
-
if (!ctx) {
|
| 1791 |
-
throw std::runtime_error("failed to create ggml context for kv cache");
|
| 1792 |
-
}
|
| 1793 |
-
|
| 1794 |
-
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
|
| 1795 |
-
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
|
| 1796 |
-
ggml_format_name(k, "cache_k_l%d", i);
|
| 1797 |
-
ggml_format_name(v, "cache_v_l%d", i);
|
| 1798 |
-
k_l.push_back(k);
|
| 1799 |
-
v_l.push_back(v);
|
| 1800 |
-
}
|
| 1801 |
-
|
| 1802 |
-
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
| 1803 |
-
for (auto it : ctx_map) {
|
| 1804 |
-
auto * buft = it.first;
|
| 1805 |
-
auto * ctx = it.second;
|
| 1806 |
-
|
| 1807 |
-
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
| 1808 |
-
if (!buf) {
|
| 1809 |
-
throw std::runtime_error("failed to allocate buffer for kv cache");
|
| 1810 |
-
}
|
| 1811 |
-
ggml_backend_buffer_clear(buf, 0);
|
| 1812 |
-
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
| 1813 |
-
bufs.emplace_back(buf);
|
| 1814 |
-
}
|
| 1815 |
-
|
| 1816 |
-
{
|
| 1817 |
-
const size_t memory_size_k = size_k_bytes();
|
| 1818 |
-
const size_t memory_size_v = size_v_bytes();
|
| 1819 |
-
|
| 1820 |
-
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
| 1821 |
-
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
|
| 1822 |
-
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
| 1823 |
-
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
| 1824 |
-
}
|
| 1825 |
-
}
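The constructor above creates at most one ggml context per backend buffer type and reuses it for every layer whose tensors land on the same device; the ctx_for_buft lambda acts as a memoizing factory keyed on the buffer type. The ggml calls are specific to that library, but the caching pattern itself is easy to isolate; the sketch below reproduces it with a dummy resource type, and all names in it are hypothetical.

#include <cstdio>
#include <map>
#include <memory>
#include <string>
#include <vector>

// dummy stand-ins for ggml_backend_buffer_type_t / ggml_context
using buft_t = std::string;
struct context { buft_t buft; };

int main() {
    std::map<buft_t, context *> ctx_map;         // one context per buffer type
    std::vector<std::unique_ptr<context>> ctxs;  // owns them for the cache lifetime

    auto ctx_for_buft = [&](const buft_t & buft) -> context * {
        auto it = ctx_map.find(buft);
        if (it == ctx_map.end()) {
            ctxs.push_back(std::unique_ptr<context>(new context{buft}));
            ctx_map[buft] = ctxs.back().get();
            return ctxs.back().get();
        }
        return it->second;
    };

    // four layers offloaded to two devices plus the CPU: only three contexts get created
    const buft_t layer_buft[] = { "CUDA0", "CUDA0", "CUDA1", "CPU" };
    for (const auto & b : layer_buft) {
        std::printf("layer on %-5s -> ctx %p\n", b.c_str(), (void *) ctx_for_buft(b));
    }
    return 0;
}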
|
| 1826 |
-
|
| 1827 |
-
void llama_kv_cache_recurrent::clear() {
|
| 1828 |
-
for (int32_t i = 0; i < (int32_t) size; ++i) {
|
| 1829 |
-
cells[i].pos = -1;
|
| 1830 |
-
cells[i].seq_id.clear();
|
| 1831 |
-
cells[i].src = -1;
|
| 1832 |
-
cells[i].tail = -1;
|
| 1833 |
-
}
|
| 1834 |
-
head = 0;
|
| 1835 |
-
used = 0;
|
| 1836 |
-
|
| 1837 |
-
for (auto & buf : bufs) {
|
| 1838 |
-
ggml_backend_buffer_clear(buf.get(), 0);
|
| 1839 |
-
}
|
| 1840 |
-
}
|
| 1841 |
-
|
| 1842 |
-
bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
| 1843 |
-
uint32_t new_head = size;
|
| 1844 |
-
|
| 1845 |
-
if (p0 < 0) {
|
| 1846 |
-
p0 = 0;
|
| 1847 |
-
}
|
| 1848 |
-
|
| 1849 |
-
if (p1 < 0) {
|
| 1850 |
-
p1 = std::numeric_limits<llama_pos>::max();
|
| 1851 |
-
}
|
| 1852 |
-
|
| 1853 |
-
// models like Mamba or RWKV can't have a state partially erased
|
| 1854 |
-
if (seq_id >= (int64_t) size) {
|
| 1855 |
-
// could be fatal
|
| 1856 |
-
return false;
|
| 1857 |
-
}
|
| 1858 |
-
if (0 <= seq_id) {
|
| 1859 |
-
int32_t & tail_id = cells[seq_id].tail;
|
| 1860 |
-
if (tail_id >= 0) {
|
| 1861 |
-
const kv_cell & cell = cells[tail_id];
|
| 1862 |
-
// partial intersection is invalid
|
| 1863 |
-
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
|
| 1864 |
-
return false;
|
| 1865 |
-
}
|
| 1866 |
-
// invalidate tails which will be cleared
|
| 1867 |
-
if (p0 <= cell.pos && cell.pos < p1) {
|
| 1868 |
-
tail_id = -1;
|
| 1869 |
-
}
|
| 1870 |
-
}
|
| 1871 |
-
} else {
|
| 1872 |
-
// seq_id is negative, then the range should include everything or nothing
|
| 1873 |
-
if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
|
| 1874 |
-
return false;
|
| 1875 |
-
}
|
| 1876 |
-
}
|
| 1877 |
-
|
| 1878 |
-
for (uint32_t i = 0; i < size; ++i) {
|
| 1879 |
-
if (cells[i].pos >= p0 && cells[i].pos < p1) {
|
| 1880 |
-
if (seq_id < 0) {
|
| 1881 |
-
cells[i].seq_id.clear();
|
| 1882 |
-
} else if (cells[i].has_seq_id(seq_id)) {
|
| 1883 |
-
cells[i].seq_id.erase(seq_id);
|
| 1884 |
-
} else {
|
| 1885 |
-
continue;
|
| 1886 |
-
}
|
| 1887 |
-
if (cells[i].is_empty()) {
|
| 1888 |
-
// keep count of the number of used cells
|
| 1889 |
-
if (cells[i].pos >= 0) {
|
| 1890 |
-
used--;
|
| 1891 |
-
}
|
| 1892 |
-
cells[i].pos = -1;
|
| 1893 |
-
cells[i].src = -1;
|
| 1894 |
-
if (new_head == size) {
|
| 1895 |
-
new_head = i;
|
| 1896 |
-
}
|
| 1897 |
-
}
|
| 1898 |
-
}
|
| 1899 |
-
}
|
| 1900 |
-
|
| 1901 |
-
// If we freed up a slot, set head to it so searching can start there.
|
| 1902 |
-
if (new_head != size && new_head < head) {
|
| 1903 |
-
head = new_head;
|
| 1904 |
-
}
|
| 1905 |
-
|
| 1906 |
-
return true;
|
| 1907 |
-
}
|
| 1908 |
-
|
| 1909 |
-
void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
| 1910 |
-
if (seq_id_src == seq_id_dst) {
|
| 1911 |
-
return;
|
| 1912 |
-
}
|
| 1913 |
-
|
| 1914 |
-
if (p0 < 0) {
|
| 1915 |
-
p0 = 0;
|
| 1916 |
-
}
|
| 1917 |
-
|
| 1918 |
-
if (p1 < 0) {
|
| 1919 |
-
p1 = std::numeric_limits<llama_pos>::max();
|
| 1920 |
-
}
|
| 1921 |
-
|
| 1922 |
-
if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
|
| 1923 |
-
kv_cell & tail_src = cells[seq_id_src];
|
| 1924 |
-
kv_cell & tail_dst = cells[seq_id_dst];
|
| 1925 |
-
if (tail_dst.tail >= 0) {
|
| 1926 |
-
// clear destination seq_id if it wasn't empty
|
| 1927 |
-
kv_cell & cell_dst = cells[tail_dst.tail];
|
| 1928 |
-
|
| 1929 |
-
cell_dst.seq_id.erase(seq_id_dst);
|
| 1930 |
-
tail_dst.tail = -1;
|
| 1931 |
-
if (cell_dst.seq_id.empty()) {
|
| 1932 |
-
cell_dst.pos = -1;
|
| 1933 |
-
cell_dst.src = -1;
|
| 1934 |
-
used -= 1;
|
| 1935 |
-
}
|
| 1936 |
-
}
|
| 1937 |
-
if (tail_src.tail >= 0) {
|
| 1938 |
-
kv_cell & cell_src = cells[tail_src.tail];
|
| 1939 |
-
|
| 1940 |
-
cell_src.seq_id.insert(seq_id_dst);
|
| 1941 |
-
tail_dst.tail = tail_src.tail;
|
| 1942 |
-
}
|
| 1943 |
-
}
|
| 1944 |
-
}
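For recurrent caches, seq_cp above does not copy any state data: it detaches the destination sequence from its old tail cell (freeing the cell if it becomes unreferenced) and then makes the destination share the source's tail cell by inserting its id into that cell's seq_id set. A toy sketch of that tail-sharing bookkeeping, using plain structs in place of the real kv_cell; it is illustrative only and omits the used-cell counter.

#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

struct cell {
    int32_t pos  = -1;
    int32_t tail = -1;          // per-seq tail pointer (only meaningful at index == seq_id)
    std::set<int32_t> seq_id;   // sequences whose state lives in this cell
};

int main() {
    std::vector<cell> cells(8);

    // seq 0 owns cell 5 as its tail
    cells[5].pos = 41;
    cells[5].seq_id.insert(0);
    cells[0].tail = 5;

    // "copy" seq 0 -> seq 1: share the tail cell instead of duplicating the state
    const int32_t src = 0, dst = 1;
    if (cells[dst].tail >= 0) {                 // detach dst from its previous tail, if any
        cell & old = cells[cells[dst].tail];
        old.seq_id.erase(dst);
        cells[dst].tail = -1;
    }
    if (cells[src].tail >= 0) {
        cells[cells[src].tail].seq_id.insert(dst);
        cells[dst].tail = cells[src].tail;
    }

    std::printf("cell 5 is now shared by %zu sequences; tail(seq 1) = %d\n",
                cells[5].seq_id.size(), cells[1].tail);
    return 0;
}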
|
| 1945 |
-
|
| 1946 |
-
void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
|
| 1947 |
-
uint32_t new_head = size;
|
| 1948 |
-
|
| 1949 |
-
for (uint32_t i = 0; i < size; ++i) {
|
| 1950 |
-
if ((llama_seq_id) i != seq_id) {
|
| 1951 |
-
cells[i].tail = -1;
|
| 1952 |
-
}
|
| 1953 |
-
|
| 1954 |
-
if (!cells[i].has_seq_id(seq_id)) {
|
| 1955 |
-
if (cells[i].pos >= 0) {
|
| 1956 |
-
used--;
|
| 1957 |
-
}
|
| 1958 |
-
|
| 1959 |
-
cells[i].pos = -1;
|
| 1960 |
-
cells[i].src = -1;
|
| 1961 |
-
cells[i].seq_id.clear();
|
| 1962 |
-
|
| 1963 |
-
if (new_head == size){
|
| 1964 |
-
new_head = i;
|
| 1965 |
-
}
|
| 1966 |
-
} else {
|
| 1967 |
-
cells[i].seq_id.clear();
|
| 1968 |
-
cells[i].seq_id.insert(seq_id);
|
| 1969 |
-
}
|
| 1970 |
-
}
|
| 1971 |
-
|
| 1972 |
-
// If we freed up a slot, set head to it so searching can start there.
|
| 1973 |
-
if (new_head != size && new_head < head) {
|
| 1974 |
-
head = new_head;
|
| 1975 |
-
}
|
| 1976 |
-
}
|
| 1977 |
-
|
| 1978 |
-
void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
| 1979 |
-
if (shift == 0) {
|
| 1980 |
-
return;
|
| 1981 |
-
}
|
| 1982 |
-
|
| 1983 |
-
if (p0 < 0) {
|
| 1984 |
-
p0 = 0;
|
| 1985 |
-
}
|
| 1986 |
-
|
| 1987 |
-
if (p1 < 0) {
|
| 1988 |
-
p1 = std::numeric_limits<llama_pos>::max();
|
| 1989 |
-
}
|
| 1990 |
-
|
| 1991 |
-
// If there is no range then return early to avoid looping over the cache.
|
| 1992 |
-
if (p0 == p1) {
|
| 1993 |
-
return;
|
| 1994 |
-
}
|
| 1995 |
-
|
| 1996 |
-
// for Mamba-like or RWKV models, only the pos needs to be shifted
|
| 1997 |
-
if (0 <= seq_id && seq_id < (int64_t) size) {
|
| 1998 |
-
const int32_t tail_id = cells[seq_id].tail;
|
| 1999 |
-
if (tail_id >= 0) {
|
| 2000 |
-
kv_cell & cell = cells[tail_id];
|
| 2001 |
-
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
| 2002 |
-
cell.pos += shift;
|
| 2003 |
-
}
|
| 2004 |
-
}
|
| 2005 |
-
}
|
| 2006 |
-
}
|
| 2007 |
-
|
| 2008 |
-
void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
| 2009 |
-
if (d == 1) {
|
| 2010 |
-
return;
|
| 2011 |
-
}
|
| 2012 |
-
|
| 2013 |
-
if (p0 < 0) {
|
| 2014 |
-
p0 = 0;
|
| 2015 |
-
}
|
| 2016 |
-
|
| 2017 |
-
if (p1 < 0) {
|
| 2018 |
-
p1 = std::numeric_limits<llama_pos>::max();
|
| 2019 |
-
}
|
| 2020 |
-
|
| 2021 |
-
// If there is no range then return early to avoid looping over the cache.
|
| 2022 |
-
if (p0 == p1) {
|
| 2023 |
-
return;
|
| 2024 |
-
}
|
| 2025 |
-
|
| 2026 |
-
// for Mamba-like or RWKV models, only the pos needs to be changed
|
| 2027 |
-
if (0 <= seq_id && seq_id < (int64_t) size) {
|
| 2028 |
-
const int32_t tail_id = cells[seq_id].tail;
|
| 2029 |
-
if (tail_id >= 0) {
|
| 2030 |
-
kv_cell & cell = cells[tail_id];
|
| 2031 |
-
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
| 2032 |
-
cell.pos /= d;
|
| 2033 |
-
}
|
| 2034 |
-
}
|
| 2035 |
-
}
|
| 2036 |
-
}
|
| 2037 |
-
|
| 2038 |
-
llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
|
| 2039 |
-
llama_pos result = std::numeric_limits<llama_pos>::max();
|
| 2040 |
-
|
| 2041 |
-
for (uint32_t i = 0; i < size; ++i) {
|
| 2042 |
-
if (cells[i].has_seq_id(seq_id)) {
|
| 2043 |
-
result = std::min(result, cells[i].pos);
|
| 2044 |
-
}
|
| 2045 |
-
}
|
| 2046 |
-
|
| 2047 |
-
if (result == std::numeric_limits<llama_pos>::max()) {
|
| 2048 |
-
result = -1;
|
| 2049 |
-
}
|
| 2050 |
-
|
| 2051 |
-
return result;
|
| 2052 |
-
}
|
| 2053 |
-
|
| 2054 |
-
llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
| 2055 |
-
llama_pos result = -1;
|
| 2056 |
-
|
| 2057 |
-
for (uint32_t i = 0; i < size; ++i) {
|
| 2058 |
-
if (cells[i].has_seq_id(seq_id)) {
|
| 2059 |
-
result = std::max(result, cells[i].pos);
|
| 2060 |
-
}
|
| 2061 |
-
}
|
| 2062 |
-
|
| 2063 |
-
return result;
|
| 2064 |
-
}
|
| 2065 |
-
|
| 2066 |
-
void llama_kv_cache_recurrent::restore() {
|
| 2067 |
-
if (pending.ranges.empty()) {
|
| 2068 |
-
return;
|
| 2069 |
-
}
|
| 2070 |
-
|
| 2071 |
-
seq_rm(-1, -1, -1);
|
| 2072 |
-
}
|
| 2073 |
-
|
| 2074 |
-
void llama_kv_cache_recurrent::commit() {
|
| 2075 |
-
pending.ranges.clear();
|
| 2076 |
-
}
|
| 2077 |
-
|
| 2078 |
-
bool llama_kv_cache_recurrent::update(llama_context & ctx) {
|
| 2079 |
-
GGML_UNUSED(ctx);
|
| 2080 |
-
return false;
|
| 2081 |
-
}
|
| 2082 |
-
|
| 2083 |
-
void llama_kv_cache_recurrent::defrag_sched(float thold) {
|
| 2084 |
-
GGML_UNUSED(thold);
|
| 2085 |
-
// noop
|
| 2086 |
-
}
|
| 2087 |
-
|
| 2088 |
-
void llama_kv_cache_recurrent::set_full() {
|
| 2089 |
-
n = size;
|
| 2090 |
-
head = 0;
|
| 2091 |
-
}
|
| 2092 |
-
|
| 2093 |
-
llama_sbatch llama_kv_cache_recurrent::sbatch_init(
|
| 2094 |
-
const llama_batch & batch,
|
| 2095 |
-
bool logits_all) {
|
| 2096 |
-
return llama_sbatch(batch, hparams.n_embd, false, logits_all);
|
| 2097 |
-
}
|
| 2098 |
-
|
| 2099 |
-
llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
|
| 2100 |
-
if (embd_pooled) {
|
| 2101 |
-
// Pooled embeddings cannot be split across ubatches (yet)
|
| 2102 |
-
return sbatch.split_seq(n_ubatch);
|
| 2103 |
-
}
|
| 2104 |
-
|
| 2105 |
-
return sbatch.split_equal(n_ubatch);
|
| 2106 |
-
}
|
| 2107 |
-
|
| 2108 |
-
bool llama_kv_cache_recurrent::find_slot(
|
| 2109 |
-
const llama_ubatch & ubatch) {
|
| 2110 |
-
const uint32_t n_tokens = ubatch.n_tokens;
|
| 2111 |
-
const uint32_t n_seqs = ubatch.n_seqs;
|
| 2112 |
-
|
| 2113 |
-
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
| 2114 |
-
|
| 2115 |
-
// if we have enough unused cells before the current head ->
|
| 2116 |
-
// better to start searching from the beginning of the cache, hoping to fill it
|
| 2117 |
-
if (head > used + 2*n_tokens) {
|
| 2118 |
-
head = 0;
|
| 2119 |
-
}
|
| 2120 |
-
|
| 2121 |
-
// For recurrent state architectures (like Mamba or RWKV),
|
| 2122 |
-
// each cache cell can store the state for a whole sequence.
|
| 2123 |
-
// A slot should be always be contiguous.
|
| 2124 |
-
|
| 2125 |
-
// can only process batches with an equal number of new tokens in each sequence
|
| 2126 |
-
GGML_ASSERT(ubatch.equal_seqs);
|
| 2127 |
-
|
| 2128 |
-
int32_t min = size - 1;
|
| 2129 |
-
int32_t max = 0;
|
| 2130 |
-
|
| 2131 |
-
// everything should fit if all seq_ids are smaller than the max
|
| 2132 |
-
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 2133 |
-
const uint32_t n_seq_id = ubatch.n_seq_id[s];
|
| 2134 |
-
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
| 2135 |
-
const llama_seq_id seq_id = ubatch.seq_id[s][j];
|
| 2136 |
-
|
| 2137 |
-
if (seq_id < 0 || (uint32_t) seq_id >= size) {
|
| 2138 |
-
// too big seq_id
|
| 2139 |
-
// TODO: would it be possible to resize the cache instead?
|
| 2140 |
-
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
|
| 2141 |
-
return false;
|
| 2142 |
-
}
|
| 2143 |
-
if (j > 0) {
|
| 2144 |
-
kv_cell & seq = cells[seq_id];
|
| 2145 |
-
if (seq.tail >= 0) {
|
| 2146 |
-
kv_cell & cell = cells[seq.tail];
|
| 2147 |
-
// clear cells from seq_ids that become shared
|
| 2148 |
-
// (should not normally happen, but let's handle it anyway)
|
| 2149 |
-
cell.seq_id.erase(seq_id);
|
| 2150 |
-
seq.tail = -1;
|
| 2151 |
-
if (cell.seq_id.empty()) {
|
| 2152 |
-
cell.pos = -1;
|
| 2153 |
-
cell.src = -1;
|
| 2154 |
-
used -= 1;
|
| 2155 |
-
}
|
| 2156 |
-
}
|
| 2157 |
-
}
|
| 2158 |
-
}
|
| 2159 |
-
}
|
| 2160 |
-
|
| 2161 |
-
#ifndef NDEBUG
|
| 2162 |
-
{
|
| 2163 |
-
std::vector<int32_t> tails_verif;
|
| 2164 |
-
tails_verif.assign(size, -1);
|
| 2165 |
-
for (uint32_t i = 0; i < size; ++i) {
|
| 2166 |
-
kv_cell & cell = cells[i];
|
| 2167 |
-
for (llama_seq_id seq_id : cell.seq_id) {
|
| 2168 |
-
if (tails_verif[seq_id] != -1) {
|
| 2169 |
-
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
|
| 2170 |
-
}
|
| 2171 |
-
tails_verif[seq_id] = i;
|
| 2172 |
-
}
|
| 2173 |
-
}
|
| 2174 |
-
for (uint32_t i = 0; i < size; ++i) {
|
| 2175 |
-
if (tails_verif[i] != cells[i].tail) {
|
| 2176 |
-
LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
|
| 2177 |
-
}
|
| 2178 |
-
}
|
| 2179 |
-
}
|
| 2180 |
-
#endif
|
| 2181 |
-
|
| 2182 |
-
// find next empty cell
|
| 2183 |
-
uint32_t next_empty_cell = head;
|
| 2184 |
-
|
| 2185 |
-
for (uint32_t i = 0; i < size; ++i) {
|
| 2186 |
-
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
| 2187 |
-
kv_cell & cell = cells[next_empty_cell];
|
| 2188 |
-
if (cell.is_empty()) { break; }
|
| 2189 |
-
next_empty_cell += 1;
|
| 2190 |
-
}
|
| 2191 |
-
|
| 2192 |
-
// find usable cell range
|
| 2193 |
-
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 2194 |
-
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
| 2195 |
-
kv_cell & seq_meta = cells[seq_id];
|
| 2196 |
-
bool has_cell = false;
|
| 2197 |
-
if (seq_meta.tail >= 0) {
|
| 2198 |
-
kv_cell & cell = cells[seq_meta.tail];
|
| 2199 |
-
GGML_ASSERT(cell.has_seq_id(seq_id));
|
| 2200 |
-
// does this seq_id "own" the cell?
|
| 2201 |
-
if (cell.seq_id.size() == 1) { has_cell = true; }
|
| 2202 |
-
}
|
| 2203 |
-
if (!has_cell) {
|
| 2204 |
-
kv_cell & empty_cell = cells[next_empty_cell];
|
| 2205 |
-
GGML_ASSERT(empty_cell.is_empty());
|
| 2206 |
-
// copy old tail into the empty cell
|
| 2207 |
-
if (seq_meta.tail >= 0) {
|
| 2208 |
-
kv_cell & orig_cell = cells[seq_meta.tail];
|
| 2209 |
-
empty_cell.pos = orig_cell.pos;
|
| 2210 |
-
empty_cell.src = orig_cell.src;
|
| 2211 |
-
orig_cell.seq_id.erase(seq_id);
|
| 2212 |
-
empty_cell.seq_id.insert(seq_id); // will be overwritten
|
| 2213 |
-
}
|
| 2214 |
-
seq_meta.tail = next_empty_cell;
|
| 2215 |
-
// find next empty cell
|
| 2216 |
-
if (s + 1 < n_seqs) {
|
| 2217 |
-
next_empty_cell += 1;
|
| 2218 |
-
for (uint32_t i = 0; i < size; ++i) {
|
| 2219 |
-
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
| 2220 |
-
kv_cell & cell = cells[next_empty_cell];
|
| 2221 |
-
if (cell.is_empty()) { break; }
|
| 2222 |
-
next_empty_cell += 1;
|
| 2223 |
-
}
|
| 2224 |
-
}
|
| 2225 |
-
}
|
| 2226 |
-
if (min > seq_meta.tail) { min = seq_meta.tail; }
|
| 2227 |
-
if (max < seq_meta.tail) { max = seq_meta.tail; }
|
| 2228 |
-
}
|
| 2229 |
-
|
| 2230 |
-
// gather and re-order
|
| 2231 |
-
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 2232 |
-
int32_t dst_id = s + min;
|
| 2233 |
-
int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
|
| 2234 |
-
if (dst_id != src_id) {
|
| 2235 |
-
kv_cell & dst_cell = cells[dst_id];
|
| 2236 |
-
kv_cell & src_cell = cells[src_id];
|
| 2237 |
-
|
| 2238 |
-
std::swap(dst_cell.pos, src_cell.pos);
|
| 2239 |
-
std::swap(dst_cell.src, src_cell.src);
|
| 2240 |
-
std::swap(dst_cell.seq_id, src_cell.seq_id);
|
| 2241 |
-
|
| 2242 |
-
// swap tails (assuming they NEVER overlap)
|
| 2243 |
-
for (const llama_seq_id seq_id : src_cell.seq_id) {
|
| 2244 |
-
cells[seq_id].tail = src_id;
|
| 2245 |
-
}
|
| 2246 |
-
for (const llama_seq_id seq_id : dst_cell.seq_id) {
|
| 2247 |
-
cells[seq_id].tail = dst_id;
|
| 2248 |
-
}
|
| 2249 |
-
}
|
| 2250 |
-
}
|
| 2251 |
-
|
| 2252 |
-
// update the pos of the used seqs
|
| 2253 |
-
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 2254 |
-
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
|
| 2255 |
-
int32_t cell_id = s + min;
|
| 2256 |
-
kv_cell & cell = cells[cell_id];
|
| 2257 |
-
|
| 2258 |
-
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
|
| 2259 |
-
// What should happen when the pos backtracks or skips a value?
|
| 2260 |
-
// Clearing the state mid-batch would require special-casing which isn't done.
|
| 2261 |
-
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
|
| 2262 |
-
__func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
|
| 2263 |
-
}
|
| 2264 |
-
cell.pos = last_pos;
|
| 2265 |
-
cell.seq_id.clear();
|
| 2266 |
-
for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
|
| 2267 |
-
const llama_seq_id seq_id = ubatch.seq_id[s][j];
|
| 2268 |
-
cell.seq_id.insert(seq_id);
|
| 2269 |
-
cells[seq_id].tail = cell_id;
|
| 2270 |
-
}
|
| 2271 |
-
}
|
| 2272 |
-
|
| 2273 |
-
// allow getting the range of used cells, from head to head + n
|
| 2274 |
-
head = min;
|
| 2275 |
-
n = max - min + 1;
|
| 2276 |
-
used = std::count_if(cells.begin(), cells.end(),
|
| 2277 |
-
[](const kv_cell & cell){ return !cell.is_empty(); });
|
| 2278 |
-
|
| 2279 |
-
// sanity check
|
| 2280 |
-
return n >= n_seqs;
|
| 2281 |
-
}
|
| 2282 |
-
|
| 2283 |
-
bool llama_kv_cache_recurrent::get_can_shift() const {
|
| 2284 |
-
return false;
|
| 2285 |
-
}
|
| 2286 |
-
|
| 2287 |
-
int32_t llama_kv_cache_recurrent::s_copy(int i) const {
|
| 2288 |
-
const uint32_t cell_id = i + head;
|
| 2289 |
-
|
| 2290 |
-
//////////////////////////////////////////////
|
| 2291 |
-
// TODO: this should not mutate the KV cache !
|
| 2292 |
-
kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
|
| 2293 |
-
|
| 2294 |
-
// prevent out-of-bound sources
|
| 2295 |
-
if (cell.src < 0 || (uint32_t) cell.src >= size) {
|
| 2296 |
-
cell.src = cell_id;
|
| 2297 |
-
}
|
| 2298 |
-
|
| 2299 |
-
int32_t res = cell.src;
|
| 2300 |
-
|
| 2301 |
-
// TODO: do not mutate the KV cache
|
| 2302 |
-
// ensure copy only happens once
|
| 2303 |
-
if (cell.src != (int32_t) cell_id) {
|
| 2304 |
-
cell.src = cell_id;
|
| 2305 |
-
}
|
| 2306 |
-
|
| 2307 |
-
return res;
|
| 2308 |
-
}
|
| 2309 |
-
|
| 2310 |
-
float llama_kv_cache_recurrent::s_mask(int i) const {
|
| 2311 |
-
const uint32_t cell_id = i + head;
|
| 2312 |
-
|
| 2313 |
-
//////////////////////////////////////////////
|
| 2314 |
-
// TODO: this should not mutate the KV cache !
|
| 2315 |
-
kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
|
| 2316 |
-
|
| 2317 |
-
float res = (float) (cell.src >= 0);
|
| 2318 |
-
|
| 2319 |
-
// only clear once
|
| 2320 |
-
if (cell.src < 0) {
|
| 2321 |
-
cell.src = cell_id;
|
| 2322 |
-
}
|
| 2323 |
-
|
| 2324 |
-
return res;
|
| 2325 |
-
}
|
| 2326 |
-
|
| 2327 |
-
uint32_t llama_kv_cache_recurrent::cell_max() const {
|
| 2328 |
-
for (uint32_t i = size; i > 0; --i) {
|
| 2329 |
-
const kv_cell & cell = cells[i - 1];
|
| 2330 |
-
|
| 2331 |
-
if (cell.pos >= 0 && !cell.is_empty()) {
|
| 2332 |
-
return i;
|
| 2333 |
-
}
|
| 2334 |
-
}
|
| 2335 |
-
|
| 2336 |
-
return 0;
|
| 2337 |
-
}
|
| 2338 |
-
|
| 2339 |
-
size_t llama_kv_cache_recurrent::total_size() const {
|
| 2340 |
-
size_t size = 0;
|
| 2341 |
-
for (const auto & buf : bufs) {
|
| 2342 |
-
size += ggml_backend_buffer_get_size(buf.get());
|
| 2343 |
-
}
|
| 2344 |
-
|
| 2345 |
-
return size;
|
| 2346 |
-
}
|
| 2347 |
-
|
| 2348 |
-
size_t llama_kv_cache_recurrent::size_k_bytes() const {
|
| 2349 |
-
size_t size_k_bytes = 0;
|
| 2350 |
-
|
| 2351 |
-
for (const auto & k : k_l) {
|
| 2352 |
-
size_k_bytes += ggml_nbytes(k);
|
| 2353 |
-
}
|
| 2354 |
-
|
| 2355 |
-
return size_k_bytes;
|
| 2356 |
-
}
|
| 2357 |
-
|
| 2358 |
-
size_t llama_kv_cache_recurrent::size_v_bytes() const {
|
| 2359 |
-
size_t size_v_bytes = 0;
|
| 2360 |
-
|
| 2361 |
-
for (const auto & v : v_l) {
|
| 2362 |
-
size_v_bytes += ggml_nbytes(v);
|
| 2363 |
-
}
|
| 2364 |
-
|
| 2365 |
-
return size_v_bytes;
|
| 2366 |
-
}
|
| 2367 |
-
|
| 2368 |
-
void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
| 2369 |
-
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
| 2370 |
-
uint32_t cell_count = 0;
|
| 2371 |
-
|
| 2372 |
-
// Count the number of cells with the specified seq_id
|
| 2373 |
-
// Find all the ranges of cells with this seq id (or all, when -1)
|
| 2374 |
-
uint32_t cell_range_begin = size;
|
| 2375 |
-
for (uint32_t i = 0; i < size; ++i) {
|
| 2376 |
-
const auto & cell = cells[i];
|
| 2377 |
-
if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
|
| 2378 |
-
++cell_count;
|
| 2379 |
-
if (cell_range_begin == size) {
|
| 2380 |
-
cell_range_begin = i;
|
| 2381 |
-
}
|
| 2382 |
-
} else {
|
| 2383 |
-
if (cell_range_begin != size) {
|
| 2384 |
-
cell_ranges.emplace_back(cell_range_begin, i);
|
| 2385 |
-
cell_range_begin = size;
|
| 2386 |
-
}
|
| 2387 |
-
}
|
| 2388 |
-
}
|
| 2389 |
-
if (cell_range_begin != size) {
|
| 2390 |
-
cell_ranges.emplace_back(cell_range_begin, size);
|
| 2391 |
-
}
|
| 2392 |
-
|
| 2393 |
-
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
| 2394 |
-
uint32_t cell_count_check = 0;
|
| 2395 |
-
for (const auto & range : cell_ranges) {
|
| 2396 |
-
cell_count_check += range.second - range.first;
|
| 2397 |
-
}
|
| 2398 |
-
GGML_ASSERT(cell_count == cell_count_check);
|
| 2399 |
-
|
| 2400 |
-
io.write(&cell_count, sizeof(cell_count));
|
| 2401 |
-
|
| 2402 |
-
state_write_meta(io, cell_ranges, seq_id);
|
| 2403 |
-
state_write_data(io, cell_ranges);
|
| 2404 |
-
}
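state_write above first turns the scattered set of cells that belong to a sequence into a list of contiguous [first, second) index ranges, so the tensor data can later be written range by range instead of cell by cell. The range-building pass in isolation, over a toy membership vector; the data here is illustrative only.

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    // toy occupancy: true where the cell belongs to the sequence being saved
    const std::vector<bool> belongs = { false, true, true, false, true, true, true, false };
    const uint32_t size = (uint32_t) belongs.size();

    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // [first, second)
    uint32_t cell_count  = 0;
    uint32_t range_begin = size; // "no open range" sentinel

    for (uint32_t i = 0; i < size; ++i) {
        if (belongs[i]) {
            ++cell_count;
            if (range_begin == size) {
                range_begin = i;
            }
        } else if (range_begin != size) {
            cell_ranges.emplace_back(range_begin, i);
            range_begin = size;
        }
    }
    if (range_begin != size) {
        cell_ranges.emplace_back(range_begin, size);
    }

    uint32_t check = 0;
    for (const auto & r : cell_ranges) {
        check += r.second - r.first;
        std::printf("range [%u, %u)\n", r.first, r.second);
    }
    std::printf("cell_count = %u, sum of ranges = %u\n", cell_count, check);
    return 0;
}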
|
| 2405 |
-
|
| 2406 |
-
void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
| 2407 |
-
uint32_t cell_count;
|
| 2408 |
-
io.read_to(&cell_count, sizeof(cell_count));
|
| 2409 |
-
|
| 2410 |
-
bool res = true;
|
| 2411 |
-
|
| 2412 |
-
res = res && state_read_meta(io, cell_count, seq_id);
|
| 2413 |
-
res = res && state_read_data(io, cell_count);
|
| 2414 |
-
|
| 2415 |
-
if (!res) {
|
| 2416 |
-
if (seq_id == -1) {
|
| 2417 |
-
clear();
|
| 2418 |
-
} else {
|
| 2419 |
-
seq_rm(seq_id, -1, -1);
|
| 2420 |
-
}
|
| 2421 |
-
throw std::runtime_error("failed to restore kv cache");
|
| 2422 |
-
}
|
| 2423 |
-
}
|
| 2424 |
-
|
| 2425 |
-
void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
|
| 2426 |
-
for (const auto & range : cell_ranges) {
|
| 2427 |
-
for (uint32_t i = range.first; i < range.second; ++i) {
|
| 2428 |
-
const auto & cell = cells[i];
|
| 2429 |
-
const llama_pos pos = cell.pos;
|
| 2430 |
-
const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
|
| 2431 |
-
|
| 2432 |
-
io.write(&pos, sizeof(pos));
|
| 2433 |
-
io.write(&n_seq_id, sizeof(n_seq_id));
|
| 2434 |
-
|
| 2435 |
-
if (n_seq_id) {
|
| 2436 |
-
for (auto seq_id : cell.seq_id) {
|
| 2437 |
-
io.write(&seq_id, sizeof(seq_id));
|
| 2438 |
-
}
|
| 2439 |
-
}
|
| 2440 |
-
}
|
| 2441 |
-
}
|
| 2442 |
-
}
|
| 2443 |
-
|
| 2444 |
-
void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
|
| 2445 |
-
const uint32_t v_trans = 0;
|
| 2446 |
-
const uint32_t n_layer = hparams.n_layer;
|
| 2447 |
-
|
| 2448 |
-
io.write(&v_trans, sizeof(v_trans));
|
| 2449 |
-
io.write(&n_layer, sizeof(n_layer));
|
| 2450 |
-
|
| 2451 |
-
std::vector<uint8_t> tmp_buf;
|
| 2452 |
-
|
| 2453 |
-
// Iterate and write all the keys first, each row is a cell
|
| 2454 |
-
// Get whole range at a time
|
| 2455 |
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 2456 |
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 2457 |
-
|
| 2458 |
-
// Write key type
|
| 2459 |
-
const int32_t k_type_i = (int32_t)k_l[il]->type;
|
| 2460 |
-
io.write(&k_type_i, sizeof(k_type_i));
|
| 2461 |
-
|
| 2462 |
-
// Write row size of key
|
| 2463 |
-
const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
|
| 2464 |
-
io.write(&k_size_row, sizeof(k_size_row));
|
| 2465 |
-
|
| 2466 |
-
// Read each range of cells of k_size length each into tmp_buf and write out
|
| 2467 |
-
for (const auto & range : cell_ranges) {
|
| 2468 |
-
const size_t range_size = range.second - range.first;
|
| 2469 |
-
const size_t buf_size = range_size * k_size_row;
|
| 2470 |
-
io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
|
| 2471 |
-
}
|
| 2472 |
-
}
|
| 2473 |
-
|
| 2474 |
-
if (!v_trans) {
|
| 2475 |
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 2476 |
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 2477 |
-
|
| 2478 |
-
// Write value type
|
| 2479 |
-
const int32_t v_type_i = (int32_t)v_l[il]->type;
|
| 2480 |
-
io.write(&v_type_i, sizeof(v_type_i));
|
| 2481 |
-
|
| 2482 |
-
// Write row size of value
|
| 2483 |
-
const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
|
| 2484 |
-
io.write(&v_size_row, sizeof(v_size_row));
|
| 2485 |
-
|
| 2486 |
-
// Read each range of cells of v_size length each into tmp_buf and write out
|
| 2487 |
-
for (const auto & range : cell_ranges) {
|
| 2488 |
-
const size_t range_size = range.second - range.first;
|
| 2489 |
-
const size_t buf_size = range_size * v_size_row;
|
| 2490 |
-
io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
|
| 2491 |
-
}
|
| 2492 |
-
}
|
| 2493 |
-
} else {
|
| 2494 |
-
// When v is transposed, we also need the element size and get the element ranges from each row
|
| 2495 |
-
const uint32_t kv_size = size;
|
| 2496 |
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 2497 |
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 2498 |
-
|
| 2499 |
-
// Write value type
|
| 2500 |
-
const int32_t v_type_i = (int32_t)v_l[il]->type;
|
| 2501 |
-
io.write(&v_type_i, sizeof(v_type_i));
|
| 2502 |
-
|
| 2503 |
-
// Write element size
|
| 2504 |
-
const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
|
| 2505 |
-
io.write(&v_size_el, sizeof(v_size_el));
|
| 2506 |
-
|
| 2507 |
-
// Write GQA embedding size
|
| 2508 |
-
io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
|
| 2509 |
-
|
| 2510 |
-
// For each row, we get the element values of each cell
|
| 2511 |
-
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
| 2512 |
-
// Read each range of cells of v_size_el length each into tmp_buf and write out
|
| 2513 |
-
for (const auto & range : cell_ranges) {
|
| 2514 |
-
const size_t range_size = range.second - range.first;
|
| 2515 |
-
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
|
| 2516 |
-
const size_t buf_size = range_size * v_size_el;
|
| 2517 |
-
io.write_tensor(v_l[il], src_offset, buf_size);
|
| 2518 |
-
}
|
| 2519 |
-
}
|
| 2520 |
-
}
|
| 2521 |
-
}
|
| 2522 |
-
}
|
| 2523 |
-
|
| 2524 |
-
bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
|
| 2525 |
-
if (dest_seq_id != -1) {
|
| 2526 |
-
// single sequence
|
| 2527 |
-
|
| 2528 |
-
seq_rm(dest_seq_id, -1, -1);
|
| 2529 |
-
|
| 2530 |
-
llama_sbatch sbatch;
|
| 2531 |
-
llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
| 2532 |
-
|
| 2533 |
-
batch.n_tokens = cell_count;
|
| 2534 |
-
batch.n_seq_tokens = cell_count;
|
| 2535 |
-
batch.n_seqs = 1;
|
| 2536 |
-
|
| 2537 |
-
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 2538 |
-
llama_pos pos;
|
| 2539 |
-
uint32_t n_seq_id;
|
| 2540 |
-
|
| 2541 |
-
io.read_to(&pos, sizeof(pos));
|
| 2542 |
-
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
| 2543 |
-
|
| 2544 |
-
if (n_seq_id != 0) {
|
| 2545 |
-
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
| 2546 |
-
return false;
|
| 2547 |
-
}
|
| 2548 |
-
|
| 2549 |
-
batch.pos[i] = pos;
|
| 2550 |
-
}
|
| 2551 |
-
batch.n_seq_id[0] = 1;
|
| 2552 |
-
batch.seq_id[0] = &dest_seq_id;
|
| 2553 |
-
if (!find_slot(batch)) {
|
| 2554 |
-
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
| 2555 |
-
return false;
|
| 2556 |
-
}
|
| 2557 |
-
commit();
|
| 2558 |
-
|
| 2559 |
-
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
| 2560 |
-
// Assume that this is one contiguous block of cells
|
| 2561 |
-
GGML_ASSERT(head + cell_count <= size);
|
| 2562 |
-
GGML_ASSERT(cells[head].pos == batch.pos[0]);
|
| 2563 |
-
GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
|
| 2564 |
-
GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
|
| 2565 |
-
GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
|
| 2566 |
-
} else {
|
| 2567 |
-
// whole KV cache restore
|
| 2568 |
-
|
| 2569 |
-
if (cell_count > size) {
|
| 2570 |
-
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
|
| 2571 |
-
return false;
|
| 2572 |
-
}
|
| 2573 |
-
|
| 2574 |
-
clear();
|
| 2575 |
-
|
| 2576 |
-
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 2577 |
-
kv_cell & cell = cells[i];
|
| 2578 |
-
|
| 2579 |
-
llama_pos pos;
|
| 2580 |
-
uint32_t n_seq_id;
|
| 2581 |
-
|
| 2582 |
-
io.read_to(&pos, sizeof(pos));
|
| 2583 |
-
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
| 2584 |
-
|
| 2585 |
-
cell.pos = pos;
|
| 2586 |
-
|
| 2587 |
-
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
| 2588 |
-
llama_seq_id seq_id;
|
| 2589 |
-
io.read_to(&seq_id, sizeof(seq_id));
|
| 2590 |
-
|
| 2591 |
-
// TODO: llama_kv_cache_recurrent should have a notion of max sequences
|
| 2592 |
-
//if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
|
| 2593 |
-
if (seq_id < 0) {
|
| 2594 |
-
//LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
|
| 2595 |
-
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
|
| 2596 |
-
return false;
|
| 2597 |
-
}
|
| 2598 |
-
|
| 2599 |
-
cell.seq_id.insert(seq_id);
|
| 2600 |
-
|
| 2601 |
-
int32_t & tail = cells[seq_id].tail;
|
| 2602 |
-
if (tail != -1) {
|
| 2603 |
-
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
|
| 2604 |
-
return false;
|
| 2605 |
-
}
|
| 2606 |
-
tail = i;
|
| 2607 |
-
}
|
| 2608 |
-
}
|
| 2609 |
-
|
| 2610 |
-
head = 0;
|
| 2611 |
-
used = cell_count;
|
| 2612 |
-
}
|
| 2613 |
-
|
| 2614 |
-
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 2615 |
-
uint32_t cell_id = head + i;
|
| 2616 |
-
// make sure the recurrent states will keep their restored state
|
| 2617 |
-
cells[cell_id].src = cell_id;
|
| 2618 |
-
}
|
| 2619 |
-
|
| 2620 |
-
return true;
|
| 2621 |
-
}
|
| 2622 |
-
|
| 2623 |
-
bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
|
| 2624 |
-
uint32_t v_trans;
|
| 2625 |
-
uint32_t n_layer;
|
| 2626 |
-
io.read_to(&v_trans, sizeof(v_trans));
|
| 2627 |
-
io.read_to(&n_layer, sizeof(n_layer));
|
| 2628 |
-
|
| 2629 |
-
if (n_layer != hparams.n_layer) {
|
| 2630 |
-
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
|
| 2631 |
-
return false;
|
| 2632 |
-
}
|
| 2633 |
-
if (cell_count > size) {
|
| 2634 |
-
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
|
| 2635 |
-
return false;
|
| 2636 |
-
}
|
| 2637 |
-
if (false != (bool) v_trans) {
|
| 2638 |
-
LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
|
| 2639 |
-
return false;
|
| 2640 |
-
}
|
| 2641 |
-
|
| 2642 |
-
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
|
| 2643 |
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 2644 |
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 2645 |
-
|
| 2646 |
-
// Read type of key
|
| 2647 |
-
int32_t k_type_i_ref;
|
| 2648 |
-
io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
|
| 2649 |
-
const int32_t k_type_i = (int32_t) k_l[il]->type;
|
| 2650 |
-
if (k_type_i != k_type_i_ref) {
|
| 2651 |
-
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
|
| 2652 |
-
return false;
|
| 2653 |
-
}
|
| 2654 |
-
|
| 2655 |
-
// Read row size of key
|
| 2656 |
-
uint64_t k_size_row_ref;
|
| 2657 |
-
io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
|
| 2658 |
-
const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
|
| 2659 |
-
if (k_size_row != k_size_row_ref) {
|
| 2660 |
-
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
|
| 2661 |
-
return false;
|
| 2662 |
-
}
|
| 2663 |
-
|
| 2664 |
-
if (cell_count) {
|
| 2665 |
-
// Read and set the keys for the whole cell range
|
| 2666 |
-
ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
|
| 2667 |
-
}
|
| 2668 |
-
}
|
| 2669 |
-
|
| 2670 |
-
if (!v_trans) {
|
| 2671 |
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 2672 |
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 2673 |
-
|
| 2674 |
-
// Read type of value
|
| 2675 |
-
int32_t v_type_i_ref;
|
| 2676 |
-
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
| 2677 |
-
const int32_t v_type_i = (int32_t)v_l[il]->type;
|
| 2678 |
-
if (v_type_i != v_type_i_ref) {
|
| 2679 |
-
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
| 2680 |
-
return false;
|
| 2681 |
-
}
|
| 2682 |
-
|
| 2683 |
-
// Read row size of value
|
| 2684 |
-
uint64_t v_size_row_ref;
|
| 2685 |
-
io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
|
| 2686 |
-
const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
|
| 2687 |
-
if (v_size_row != v_size_row_ref) {
|
| 2688 |
-
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
|
| 2689 |
-
return false;
|
| 2690 |
-
}
|
| 2691 |
-
|
| 2692 |
-
if (cell_count) {
|
| 2693 |
-
// Read and set the values for the whole cell range
|
| 2694 |
-
ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
|
| 2695 |
-
}
|
| 2696 |
-
}
|
| 2697 |
-
} else {
|
| 2698 |
-
// For each layer, read the values for each cell (transposed)
|
| 2699 |
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 2700 |
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 2701 |
-
|
| 2702 |
-
// Read type of value
|
| 2703 |
-
int32_t v_type_i_ref;
|
| 2704 |
-
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
| 2705 |
-
const int32_t v_type_i = (int32_t)v_l[il]->type;
|
| 2706 |
-
if (v_type_i != v_type_i_ref) {
|
| 2707 |
-
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
| 2708 |
-
return false;
|
| 2709 |
-
}
|
| 2710 |
-
|
| 2711 |
-
// Read element size of value
|
| 2712 |
-
uint32_t v_size_el_ref;
|
| 2713 |
-
io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
|
| 2714 |
-
const size_t v_size_el = ggml_type_size(v_l[il]->type);
|
| 2715 |
-
if (v_size_el != v_size_el_ref) {
|
| 2716 |
-
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
|
| 2717 |
-
return false;
|
| 2718 |
-
}
|
| 2719 |
-
|
| 2720 |
-
// Read GQA embedding size
|
| 2721 |
-
uint32_t n_embd_v_gqa_ref;
|
| 2722 |
-
io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
|
| 2723 |
-
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
|
| 2724 |
-
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
|
| 2725 |
-
return false;
|
| 2726 |
-
}
|
| 2727 |
-
|
| 2728 |
-
if (cell_count) {
|
| 2729 |
-
// For each row in the transposed matrix, read the values for the whole cell range
|
| 2730 |
-
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
| 2731 |
-
const size_t dst_offset = (head + j * size) * v_size_el;
|
| 2732 |
-
ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
| 2733 |
-
}
|
| 2734 |
-
}
|
| 2735 |
-
}
|
| 2736 |
-
}
|
| 2737 |
-
|
| 2738 |
-
return true;
|
| 2739 |
-
}
|
|
|
|
| 1 |
#include "llama-kv-cache.h"
|
examples/talk-llama/llama-kv-cache.h
CHANGED
|
@@ -2,59 +2,33 @@
|
|
| 2 |
|
| 3 |
#include "llama.h"
|
| 4 |
#include "llama-io.h"
|
| 5 |
-
#include "llama-graph.h"
|
| 6 |
#include "llama-memory.h"
|
| 7 |
-
#include "llama-kv-cells.h"
|
| 8 |
-
|
| 9 |
-
#include "ggml-cpp.h"
|
| 10 |
-
|
| 11 |
-
#include <set>
|
| 12 |
-
#include <unordered_map>
|
| 13 |
-
#include <vector>
|
| 14 |
-
|
| 15 |
-
struct llama_cparams;
|
| 16 |
-
struct llama_hparams;
|
| 17 |
-
struct llama_ubatch;
|
| 18 |
-
struct llama_sbatch;
|
| 19 |
-
struct llama_model;
|
| 20 |
-
struct llama_context;
|
| 21 |
|
| 22 |
struct llama_kv_cache : public llama_memory_i {
|
| 23 |
virtual ~llama_kv_cache() = default;
|
| 24 |
|
| 25 |
-
//
|
| 26 |
-
|
|
| 27 |
|
| 28 |
-
//
|
| 29 |
-
virtual
|
| 30 |
|
| 31 |
// process any pending defrag/shift/etc. operations
|
| 32 |
// optionally call once before processing a new batch
|
|
|
|
| 33 |
virtual bool update(llama_context & lctx) = 0;
|
| 34 |
|
| 35 |
// schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
// simulate full cache, used for allocating worst-case compute buffers
|
| 39 |
-
// TODO: remove
|
| 40 |
-
virtual void set_full() = 0;
|
| 41 |
-
|
| 42 |
//
|
| 43 |
-
|
| 44 |
-
//
|
| 45 |
-
|
| 46 |
-
// =============================================================================================================
|
| 47 |
-
// TODO: refactor and simplify this [TAG: KV_API]
|
| 48 |
-
|
| 49 |
-
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
|
| 50 |
-
|
| 51 |
-
// different KV caches require different batch splitting strategies
|
| 52 |
-
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
|
| 53 |
-
|
| 54 |
-
// find an empty slot of size "n_tokens" in the cache
|
| 55 |
-
virtual bool find_slot(const llama_ubatch & batch) = 0;
|
| 56 |
-
|
| 57 |
-
// =============================================================================================================
|
| 58 |
|
| 59 |
// getters
|
| 60 |
virtual bool get_can_shift() const = 0;
|
|
@@ -68,435 +42,3 @@ struct llama_kv_cache : public llama_memory_i {
|
|
| 68 |
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
|
| 69 |
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
|
| 70 |
};
|
| 71 |
-
|
| 72 |
-
//
|
| 73 |
-
// llama_kv_cache_guard
|
| 74 |
-
//
|
| 75 |
-
|
| 76 |
-
struct llama_kv_cache_guard {
|
| 77 |
-
llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
|
| 78 |
-
|
| 79 |
-
~llama_kv_cache_guard() {
|
| 80 |
-
kv->restore();
|
| 81 |
-
}
|
| 82 |
-
|
| 83 |
-
void commit() {
|
| 84 |
-
kv->commit();
|
| 85 |
-
}
|
| 86 |
-
|
| 87 |
-
private:
|
| 88 |
-
llama_kv_cache * kv;
|
| 89 |
-
};
|
| 90 |
-
|
| 91 |
-
//
|
| 92 |
-
// llama_kv_cache_unified
|
| 93 |
-
//
|
| 94 |
-
|
| 95 |
-
class llama_kv_cache_unified : public llama_kv_cache {
|
| 96 |
-
public:
|
| 97 |
-
static uint32_t get_padding(const llama_cparams & cparams);
|
| 98 |
-
|
| 99 |
-
// this callback is used to filter out layers that should not be included in the cache
|
| 100 |
-
using layer_filter_cb = std::function<bool(int32_t il)>;
|
| 101 |
-
|
| 102 |
-
llama_kv_cache_unified(
|
| 103 |
-
const llama_model & model,
|
| 104 |
-
layer_filter_cb && filter,
|
| 105 |
-
ggml_type type_k,
|
| 106 |
-
ggml_type type_v,
|
| 107 |
-
bool v_trans,
|
| 108 |
-
bool offload,
|
| 109 |
-
uint32_t kv_size,
|
| 110 |
-
uint32_t n_seq_max,
|
| 111 |
-
uint32_t n_pad,
|
| 112 |
-
uint32_t n_swa,
|
| 113 |
-
llama_swa_type swa_type);
|
| 114 |
-
|
| 115 |
-
~llama_kv_cache_unified() = default;
|
| 116 |
-
|
| 117 |
-
//
|
| 118 |
-
// llama_memory_i
|
| 119 |
-
//
|
| 120 |
-
|
| 121 |
-
void clear() override;
|
| 122 |
-
|
| 123 |
-
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
| 124 |
-
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
| 125 |
-
void seq_keep(llama_seq_id seq_id) override;
|
| 126 |
-
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
| 127 |
-
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
| 128 |
-
|
| 129 |
-
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
| 130 |
-
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
| 131 |
-
|
| 132 |
-
//
|
| 133 |
-
// llama_kv_cache
|
| 134 |
-
//
|
| 135 |
-
|
| 136 |
-
void restore() override;
|
| 137 |
-
void commit() override;
|
| 138 |
-
|
| 139 |
-
bool update(llama_context & ctx) override;
|
| 140 |
-
|
| 141 |
-
void defrag_sched(float thold) override;
|
| 142 |
-
|
| 143 |
-
void set_full() override;
|
| 144 |
-
|
| 145 |
-
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
|
| 146 |
-
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
|
| 147 |
-
|
| 148 |
-
// updates the cache head
|
| 149 |
-
// Note: On success, it's important that cache.head points
|
| 150 |
-
// to the first cell of the slot.
|
| 151 |
-
bool find_slot(const llama_ubatch & batch) override;
|
| 152 |
-
|
| 153 |
-
bool get_can_shift() const override;
|
| 154 |
-
|
| 155 |
-
// state write/load
|
| 156 |
-
|
| 157 |
-
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
| 158 |
-
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
| 159 |
-
|
| 160 |
-
//
|
| 161 |
-
// llama_kv_cache_unified specific API
|
| 162 |
-
//
|
| 163 |
-
|
| 164 |
-
uint32_t get_n() const;
|
| 165 |
-
uint32_t get_size() const;
|
| 166 |
-
|
| 167 |
-
// get views of the current state of the cache
|
| 168 |
-
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
|
| 169 |
-
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
|
| 170 |
-
|
| 171 |
-
// store k_cur and v_cur in the cache based on the current head location
|
| 172 |
-
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
|
| 173 |
-
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
|
| 174 |
-
|
| 175 |
-
void prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax);
|
| 176 |
-
|
| 177 |
-
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
| 178 |
-
void set_input_k_shift (ggml_tensor * dst) const;
|
| 179 |
-
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
| 180 |
-
|
| 181 |
-
private:
|
| 182 |
-
const llama_model & model;
|
| 183 |
-
const llama_hparams & hparams;
|
| 184 |
-
|
| 185 |
-
struct kv_layer {
|
| 186 |
-
// layer index in the model
|
| 187 |
-
// note: can be different from the layer index in the KV cache
|
| 188 |
-
uint32_t il;
|
| 189 |
-
|
| 190 |
-
ggml_tensor * k;
|
| 191 |
-
ggml_tensor * v;
|
| 192 |
-
};
|
| 193 |
-
|
| 194 |
-
bool do_defrag = false;
|
| 195 |
-
bool v_trans = true; // the value tensor is transposed
|
| 196 |
-
|
| 197 |
-
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
| 198 |
-
|
| 199 |
-
// computed before each graph build
|
| 200 |
-
// TODO: cells should start to maintain this value dynamically based on the edits
|
| 201 |
-
uint32_t n = 0;
|
| 202 |
-
|
| 203 |
-
const uint32_t n_seq_max = 1;
|
| 204 |
-
|
| 205 |
-
// required padding
|
| 206 |
-
const uint32_t n_pad = 1;
|
| 207 |
-
|
| 208 |
-
// SWA
|
| 209 |
-
const uint32_t n_swa = 0;
|
| 210 |
-
|
| 211 |
-
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
| 212 |
-
|
| 213 |
-
std::vector<ggml_context_ptr> ctxs;
|
| 214 |
-
std::vector<ggml_backend_buffer_ptr> bufs;
|
| 215 |
-
|
| 216 |
-
llama_kv_cells_unified cells;
|
| 217 |
-
|
| 218 |
-
std::vector<kv_layer> layers;
|
| 219 |
-
|
| 220 |
-
// model layer id -> KV cache layer id
|
| 221 |
-
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
| 222 |
-
|
| 223 |
-
// recovery information used to restore the KV cells to their original state in case of a failure
|
| 224 |
-
// TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation
|
| 225 |
-
// to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API]
|
| 226 |
-
struct {
|
| 227 |
-
void clear() {
|
| 228 |
-
states.clear();
|
| 229 |
-
}
|
| 230 |
-
|
| 231 |
-
struct state {
|
| 232 |
-
uint32_t i;
|
| 233 |
-
|
| 234 |
-
llama_kv_cells_unified cells;
|
| 235 |
-
};
|
| 236 |
-
|
| 237 |
-
// stack with the partial states before each ubatch
|
| 238 |
-
std::vector<state> states;
|
| 239 |
-
} recovery;
|
| 240 |
-
|
| 241 |
-
// defrag
|
| 242 |
-
struct {
|
| 243 |
-
std::vector<uint32_t> ids;
|
| 244 |
-
} defrag_info;
|
| 245 |
-
|
| 246 |
-
// return true if cells have been moved
|
| 247 |
-
bool defrag_prepare(int32_t n_max_nodes);
|
| 248 |
-
|
| 249 |
-
size_t total_size() const;
|
| 250 |
-
|
| 251 |
-
size_t size_k_bytes() const;
|
| 252 |
-
size_t size_v_bytes() const;
|
| 253 |
-
|
| 254 |
-
bool is_masked_swa(llama_pos p0, llama_pos p1) const;
|
| 255 |
-
|
| 256 |
-
ggml_tensor * build_rope_shift(
|
| 257 |
-
const llama_cparams & cparams,
|
| 258 |
-
ggml_context * ctx,
|
| 259 |
-
ggml_tensor * cur,
|
| 260 |
-
ggml_tensor * shift,
|
| 261 |
-
ggml_tensor * factors,
|
| 262 |
-
float freq_base,
|
| 263 |
-
float freq_scale) const;
|
| 264 |
-
|
| 265 |
-
llm_graph_result_ptr build_graph_shift(
|
| 266 |
-
const llama_cparams & cparams,
|
| 267 |
-
ggml_context * ctx,
|
| 268 |
-
ggml_cgraph * gf) const;
|
| 269 |
-
|
| 270 |
-
llm_graph_result_ptr build_graph_defrag(
|
| 271 |
-
const llama_cparams & cparams,
|
| 272 |
-
ggml_context * ctx,
|
| 273 |
-
ggml_cgraph * gf) const;
|
| 274 |
-
|
| 275 |
-
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
| 276 |
-
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
| 277 |
-
|
| 278 |
-
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
| 279 |
-
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
| 280 |
-
};
|
| 281 |
-
|
| 282 |
-
//
|
| 283 |
-
// llama_kv_cache_unified_iswa
|
| 284 |
-
//
|
| 285 |
-
|
| 286 |
-
// utilizes two instances of llama_kv_cache_unified
|
| 287 |
-
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
|
| 288 |
-
// upon successful commit, the SWA cache removes old tokens outside the n_swa window
|
| 289 |
-
|
| 290 |
-
class llama_kv_cache_unified_iswa : public llama_kv_cache {
|
| 291 |
-
public:
|
| 292 |
-
llama_kv_cache_unified_iswa(
|
| 293 |
-
const llama_model & model,
|
| 294 |
-
ggml_type type_k,
|
| 295 |
-
ggml_type type_v,
|
| 296 |
-
bool v_trans,
|
| 297 |
-
bool offload,
|
| 298 |
-
bool swa_full,
|
| 299 |
-
uint32_t kv_size,
|
| 300 |
-
uint32_t n_seq_max,
|
| 301 |
-
uint32_t n_batch,
|
| 302 |
-
uint32_t n_pad);
|
| 303 |
-
|
| 304 |
-
~llama_kv_cache_unified_iswa() = default;
|
| 305 |
-
|
| 306 |
-
//
|
| 307 |
-
// llama_memory_i
|
| 308 |
-
//
|
| 309 |
-
|
| 310 |
-
void clear() override;
|
| 311 |
-
|
| 312 |
-
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
| 313 |
-
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
| 314 |
-
void seq_keep(llama_seq_id seq_id) override;
|
| 315 |
-
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
| 316 |
-
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
| 317 |
-
|
| 318 |
-
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
| 319 |
-
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
| 320 |
-
|
| 321 |
-
//
|
| 322 |
-
// llama_kv_cache
|
| 323 |
-
//
|
| 324 |
-
|
| 325 |
-
void restore() override;
|
| 326 |
-
void commit() override;
|
| 327 |
-
|
| 328 |
-
bool update(llama_context & ctx) override;
|
| 329 |
-
|
| 330 |
-
void defrag_sched(float thold) override;
|
| 331 |
-
|
| 332 |
-
void set_full() override;
|
| 333 |
-
|
| 334 |
-
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
|
| 335 |
-
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
|
| 336 |
-
|
| 337 |
-
bool find_slot(const llama_ubatch & batch) override;
|
| 338 |
-
|
| 339 |
-
bool get_can_shift() const override;
|
| 340 |
-
|
| 341 |
-
// state write/load
|
| 342 |
-
|
| 343 |
-
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
| 344 |
-
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
| 345 |
-
|
| 346 |
-
//
|
| 347 |
-
// llama_kv_cache_unified_iswa specific API
|
| 348 |
-
//
|
| 349 |
-
|
| 350 |
-
llama_kv_cache_unified * get_kv_base() const;
|
| 351 |
-
llama_kv_cache_unified * get_kv_swa () const;
|
| 352 |
-
|
| 353 |
-
private:
|
| 354 |
-
const llama_hparams & hparams;
|
| 355 |
-
|
| 356 |
-
bool do_prune = true;
|
| 357 |
-
|
| 358 |
-
struct {
|
| 359 |
-
struct entry {
|
| 360 |
-
llama_pos pmin;
|
| 361 |
-
llama_pos pmax;
|
| 362 |
-
};
|
| 363 |
-
|
| 364 |
-
void clear() {
|
| 365 |
-
pos.clear();
|
| 366 |
-
}
|
| 367 |
-
|
| 368 |
-
// used to perform SWA pruning of old tokens
|
| 369 |
-
std::unordered_map<llama_seq_id, entry> pos;
|
| 370 |
-
} pending;
|
| 371 |
-
|
| 372 |
-
std::unique_ptr<llama_kv_cache_unified> kv_base;
|
| 373 |
-
std::unique_ptr<llama_kv_cache_unified> kv_swa;
|
| 374 |
-
};
|
| 375 |
-
|
| 376 |
-
//
|
| 377 |
-
// llama_kv_cache_recurrent
|
| 378 |
-
//
|
| 379 |
-
|
| 380 |
-
class llama_kv_cache_recurrent : public llama_kv_cache {
|
| 381 |
-
public:
|
| 382 |
-
struct kv_cell {
|
| 383 |
-
llama_pos pos = -1;
|
| 384 |
-
int32_t src = -1; // used to copy states
|
| 385 |
-
int32_t tail = -1;
|
| 386 |
-
|
| 387 |
-
std::set<llama_seq_id> seq_id;
|
| 388 |
-
|
| 389 |
-
bool has_seq_id(const llama_seq_id & id) const {
|
| 390 |
-
return seq_id.find(id) != seq_id.end();
|
| 391 |
-
}
|
| 392 |
-
|
| 393 |
-
bool is_empty() const {
|
| 394 |
-
return seq_id.empty();
|
| 395 |
-
}
|
| 396 |
-
|
| 397 |
-
bool is_same_seq(const kv_cell & other) const {
|
| 398 |
-
return seq_id == other.seq_id;
|
| 399 |
-
}
|
| 400 |
-
};
|
| 401 |
-
|
| 402 |
-
llama_kv_cache_recurrent(
|
| 403 |
-
const llama_model & model,
|
| 404 |
-
ggml_type type_k,
|
| 405 |
-
ggml_type type_v,
|
| 406 |
-
bool offload,
|
| 407 |
-
uint32_t kv_size,
|
| 408 |
-
uint32_t n_seq_max);
|
| 409 |
-
|
| 410 |
-
~llama_kv_cache_recurrent() = default;
|
| 411 |
-
|
| 412 |
-
//
|
| 413 |
-
// llama_memory_i
|
| 414 |
-
//
|
| 415 |
-
|
| 416 |
-
void clear() override;
|
| 417 |
-
|
| 418 |
-
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
| 419 |
-
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
| 420 |
-
void seq_keep(llama_seq_id seq_id) override;
|
| 421 |
-
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
| 422 |
-
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
| 423 |
-
|
| 424 |
-
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
| 425 |
-
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
| 426 |
-
|
| 427 |
-
//
|
| 428 |
-
// llama_kv_cache
|
| 429 |
-
//
|
| 430 |
-
|
| 431 |
-
void restore() override;
|
| 432 |
-
void commit() override;
|
| 433 |
-
|
| 434 |
-
bool update(llama_context & ctx) override;
|
| 435 |
-
|
| 436 |
-
void defrag_sched(float thold) override;
|
| 437 |
-
|
| 438 |
-
void set_full() override;
|
| 439 |
-
|
| 440 |
-
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
|
| 441 |
-
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
|
| 442 |
-
|
| 443 |
-
bool find_slot(const llama_ubatch & batch) override;
|
| 444 |
-
|
| 445 |
-
bool get_can_shift() const override;
|
| 446 |
-
|
| 447 |
-
// TODO: temporary methods - they are not really const as they do const_cast<>, fix this
|
| 448 |
-
int32_t s_copy(int i) const;
|
| 449 |
-
float s_mask(int i) const;
|
| 450 |
-
|
| 451 |
-
// state write/load
|
| 452 |
-
|
| 453 |
-
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
| 454 |
-
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
| 455 |
-
|
| 456 |
-
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
| 457 |
-
uint32_t size = 0; // total number of cells, shared across all sequences
|
| 458 |
-
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
| 459 |
-
|
| 460 |
-
// computed before each graph build
|
| 461 |
-
uint32_t n = 0;
|
| 462 |
-
|
| 463 |
-
std::vector<kv_cell> cells;
|
| 464 |
-
|
| 465 |
-
std::vector<ggml_tensor *> k_l; // per layer
|
| 466 |
-
std::vector<ggml_tensor *> v_l;
|
| 467 |
-
|
| 468 |
-
private:
|
| 469 |
-
//const llama_model & model;
|
| 470 |
-
const llama_hparams & hparams;
|
| 471 |
-
|
| 472 |
-
// commit/restore cache
|
| 473 |
-
// TODO: rework for recurrent cache
|
| 474 |
-
struct slot_range {
|
| 475 |
-
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
|
| 476 |
-
uint32_t c1 = 0;
|
| 477 |
-
};
|
| 478 |
-
|
| 479 |
-
// pending cell updates that are not yet committed
|
| 480 |
-
struct {
|
| 481 |
-
std::vector<slot_range> ranges;
|
| 482 |
-
} pending;
|
| 483 |
-
|
| 484 |
-
const uint32_t n_seq_max = 1;
|
| 485 |
-
|
| 486 |
-
std::vector<ggml_context_ptr> ctxs;
|
| 487 |
-
std::vector<ggml_backend_buffer_ptr> bufs;
|
| 488 |
-
|
| 489 |
-
// find how many cells are currently in use
|
| 490 |
-
uint32_t cell_max() const;
|
| 491 |
-
|
| 492 |
-
size_t total_size() const;
|
| 493 |
-
|
| 494 |
-
size_t size_k_bytes() const;
|
| 495 |
-
size_t size_v_bytes() const;
|
| 496 |
-
|
| 497 |
-
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
| 498 |
-
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
| 499 |
-
|
| 500 |
-
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
| 501 |
-
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
| 502 |
-
};
|
|
|
|
| 2 |
|
| 3 |
#include "llama.h"
|
| 4 |
#include "llama-io.h"
|
| 6 |
|
| 7 |
struct llama_kv_cache : public llama_memory_i {
|
| 8 |
virtual ~llama_kv_cache() = default;
|
| 9 |
|
| 10 |
+
// split the input batch into a set of ubatches and verify that they can fit into the cache
|
| 11 |
+
// return a state object containing the ubatches and KV cache state required to process them
|
| 12 |
+
// check the llama_memory_state_i::get_status() for the result
|
| 13 |
+
virtual llama_memory_state_ptr init_batch(
|
| 14 |
+
const llama_batch & batch,
|
| 15 |
+
uint32_t n_ubatch,
|
| 16 |
+
bool embd_pooled,
|
| 17 |
+
bool logits_all) = 0;
|
| 18 |
|
| 19 |
+
// simulate full cache, used for allocating worst-case compute buffers
|
| 20 |
+
virtual llama_memory_state_ptr init_full() = 0;
|
| 21 |
|
| 22 |
// process any pending defrag/shift/etc. operations
|
| 23 |
// optionally call once before processing a new batch
|
| 24 |
+
// return true if any operations were performed
|
| 25 |
virtual bool update(llama_context & lctx) = 0;
|
| 26 |
|
| 27 |
// schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
|
| 28 |
+
// TODO: change to
|
| 29 |
+
// llama_memory_state_ptr init_defrag(float thold) = 0;
|
|
| 30 |
//
|
| 31 |
+
virtual void defrag_sched(float thold) = 0;
|
| 32 |
|
| 33 |
// getters
|
| 34 |
virtual bool get_can_shift() const = 0;
|
|
|
|
| 42 |
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
|
| 43 |
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
|
| 44 |
};
|
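Editor's note: the hunk above replaces the old guard/commit/sbatch plumbing with a batch-preparation API. The sketch below is a minimal illustration of how a caller might drive the new interface; it is not code from the sync. process_ubatch() and the parameter values (ubatch size 512, defrag threshold 0.1) are hypothetical, while init_batch, update, defrag_sched and the llama_memory_state_ptr contract come from the header shown above and from llama-memory.h further down in this diff.

// Minimal sketch (editor's illustration, not part of the sync): driving the new
// llama_kv_cache batch API. process_ubatch() is a hypothetical helper; the state
// loop follows the llama_memory_state_i contract (next()/apply()/get_status()).
static bool decode_with_cache(llama_kv_cache & cache, llama_context & lctx, const llama_batch & batch) {
    cache.defrag_sched(0.1f);  // schedule a defrag if fragmentation exceeds the threshold
    cache.update(lctx);        // flush any pending shift/defrag work before the new batch

    llama_memory_state_ptr state = cache.init_batch(batch, /*n_ubatch*/ 512, /*embd_pooled*/ false, /*logits_all*/ false);
    if (!state || state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        return false;          // the batch could not be placed into the cache
    }

    do {
        if (!state->apply()) { // write the planned cells for the current ubatch into the cache
            return false;
        }
        process_ubatch(lctx, state->get_ubatch()); // hypothetical: build + compute the graph
    } while (state->next());   // advance to the next ubatch; next() returns false when done

    return true;
}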
examples/talk-llama/llama-kv-cells.h
CHANGED
|
@@ -68,12 +68,6 @@ public:
|
|
| 68 |
// the index of the last cell that is used + 1
|
| 69 |
// return 0 if no cells are used
|
| 70 |
uint32_t used_max_p1() const {
|
| 71 |
-
#if 0
|
| 72 |
-
if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin());
|
| 73 |
-
if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin());
|
| 74 |
-
if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin());
|
| 75 |
-
#endif
|
| 76 |
-
|
| 77 |
return used.empty() ? 0 : *used.rbegin() + 1;
|
| 78 |
}
|
| 79 |
|
|
@@ -144,6 +138,19 @@ public:
|
|
| 144 |
}
|
| 145 |
}
|
| 146 |
|
|
| 147 |
// note: call only if the cell has seq_id
|
| 148 |
// return true if the cell becomes empty
|
| 149 |
bool seq_rm(uint32_t i, llama_seq_id seq_id) {
|
|
@@ -196,6 +203,15 @@ public:
|
|
| 196 |
return false;
|
| 197 |
}
|
| 198 |
|
|
| 199 |
bool seq_has(uint32_t i, llama_seq_id seq_id) const {
|
| 200 |
assert(i < pos.size());
|
| 201 |
assert(seq_id >= 0);
|
|
@@ -213,6 +229,20 @@ public:
|
|
| 213 |
seq_pos[seq_id].insert(pos[i]);
|
| 214 |
}
|
| 215 |
|
|
| 216 |
// the minimum position of sequence seq_id currently present in any of the cells
|
| 217 |
// return -1 if the sequence is not present
|
| 218 |
llama_pos seq_pos_min(llama_seq_id seq_id) const {
|
|
@@ -268,6 +298,7 @@ public:
|
|
| 268 |
void pos_set(uint32_t i, llama_pos p) {
|
| 269 |
assert(i < pos.size());
|
| 270 |
assert(pos[i] == -1);
|
|
|
|
| 271 |
|
| 272 |
pos[i] = p;
|
| 273 |
|
|
|
|
| 68 |
// the index of the last cell that is used + 1
|
| 69 |
// return 0 if no cells are used
|
| 70 |
uint32_t used_max_p1() const {
|
| 71 |
return used.empty() ? 0 : *used.rbegin() + 1;
|
| 72 |
}
|
| 73 |
|
|
|
|
| 138 |
}
|
| 139 |
}
|
| 140 |
|
| 141 |
+
// clear a non-empty cell
|
| 142 |
+
void rm(uint32_t i) {
|
| 143 |
+
assert(i < pos.size());
|
| 144 |
+
assert(pos[i] != -1);
|
| 145 |
+
|
| 146 |
+
seq_pos_rm(i);
|
| 147 |
+
|
| 148 |
+
pos[i] = -1;
|
| 149 |
+
seq[i].reset();
|
| 150 |
+
|
| 151 |
+
used.erase(i);
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
// note: call only if the cell has seq_id
|
| 155 |
// return true if the cell becomes empty
|
| 156 |
bool seq_rm(uint32_t i, llama_seq_id seq_id) {
|
|
|
|
| 203 |
return false;
|
| 204 |
}
|
| 205 |
|
| 206 |
+
// number of different sequences in the cell
|
| 207 |
+
int seq_count(uint32_t i) const {
|
| 208 |
+
assert(i < pos.size());
|
| 209 |
+
assert(pos[i] != -1);
|
| 210 |
+
|
| 211 |
+
return seq[i].count();
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
// check if the cell contains seq_id
|
| 215 |
bool seq_has(uint32_t i, llama_seq_id seq_id) const {
|
| 216 |
assert(i < pos.size());
|
| 217 |
assert(seq_id >= 0);
|
|
|
|
| 229 |
seq_pos[seq_id].insert(pos[i]);
|
| 230 |
}
|
| 231 |
|
| 232 |
+
// return the sequence id of this cell
|
| 233 |
+
// note: call only for cells with exactly one sequence
|
| 234 |
+
llama_seq_id seq_get(uint32_t i) const {
|
| 235 |
+
assert(seq[i].count() == 1);
|
| 236 |
+
|
| 237 |
+
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
| 238 |
+
if (seq[i].test(s)) {
|
| 239 |
+
return s;
|
| 240 |
+
}
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
return -1;
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
// the minimum position of sequence seq_id currently present in any of the cells
|
| 247 |
// return -1 if the sequence is not present
|
| 248 |
llama_pos seq_pos_min(llama_seq_id seq_id) const {
|
|
|
|
| 298 |
void pos_set(uint32_t i, llama_pos p) {
|
| 299 |
assert(i < pos.size());
|
| 300 |
assert(pos[i] == -1);
|
| 301 |
+
assert(seq[i].none());
|
| 302 |
|
| 303 |
pos[i] = p;
|
| 304 |
|
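Editor's note: the new cell helpers above (rm, seq_count, seq_get) together with the stricter assertion in pos_set() define a small lifecycle for a cell: it must be fully cleared before its position can be reused. A minimal sketch of that usage, assuming a populated llama_kv_cells_unified instance named cells (the function name is illustrative, not from the diff):

// Sketch (editor's illustration): clearing a cell that belongs to exactly one sequence.
// seq_get() is only valid when seq_count() == 1; rm() resets pos, the seq bits and the
// "used" bookkeeping, which is what the new assert in pos_set() relies on.
void clear_single_seq_cell(llama_kv_cells_unified & cells, uint32_t i) {
    if (cells.seq_count(i) == 1) {
        const llama_seq_id owner = cells.seq_get(i); // the only sequence stored in cell i
        (void) owner;                                // e.g. record which sequence lost the cell
        cells.rm(i);                                 // cell is now empty and safe for pos_set()
    }
}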
examples/talk-llama/llama-memory.h
CHANGED
|
@@ -2,6 +2,11 @@
|
|
| 2 |
|
| 3 |
#include "llama.h"
|
| 4 |
|
|
| 5 |
struct llama_memory_params {
|
| 6 |
// kv cache
|
| 7 |
ggml_type type_k;
|
|
@@ -30,3 +35,42 @@ public:
|
|
| 30 |
|
| 31 |
virtual bool get_can_edit() const = 0;
|
| 32 |
};
|
| 2 |
|
| 3 |
#include "llama.h"
|
| 4 |
|
| 5 |
+
#include <memory>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
struct llama_ubatch;
|
| 9 |
+
|
| 10 |
struct llama_memory_params {
|
| 11 |
// kv cache
|
| 12 |
ggml_type type_k;
|
|
|
|
| 35 |
|
| 36 |
virtual bool get_can_edit() const = 0;
|
| 37 |
};
|
| 38 |
+
|
| 39 |
+
enum llama_memory_status {
|
| 40 |
+
LLAMA_MEMORY_STATUS_SUCCESS = 0,
|
| 41 |
+
LLAMA_MEMORY_STATUS_FAILED_PREPARE,
|
| 42 |
+
LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
|
| 43 |
+
};
|
| 44 |
+
|
| 45 |
+
// the interface for managing the memory state during batch processing
|
| 46 |
+
// this interface is implemented per memory type. see:
|
| 47 |
+
// - llama_kv_cache_unified_state
|
| 48 |
+
// - llama_kv_cache_unified_iswa_state
|
| 49 |
+
// ...
|
| 50 |
+
//
|
| 51 |
+
// the only method that can mutate the memory and the memory state is llama_memory_i::apply()
|
| 52 |
+
//
|
| 53 |
+
// TODO: rename to llama_memory_context_i ?
|
| 54 |
+
class llama_memory_state_i {
|
| 55 |
+
public:
|
| 56 |
+
virtual ~llama_memory_state_i() = default;
|
| 57 |
+
|
| 58 |
+
// consume the current ubatch from the state and proceed to the next one
|
| 59 |
+
// return false if we are done
|
| 60 |
+
virtual bool next() = 0;
|
| 61 |
+
|
| 62 |
+
// apply the memory state for the current ubatch to the memory object
|
| 63 |
+
// return false on failure
|
| 64 |
+
virtual bool apply() = 0;
|
| 65 |
+
|
| 66 |
+
// TODO: this might get reworked in the future when refactoring llama_batch
|
| 67 |
+
virtual std::vector<int64_t> & out_ids() = 0;
|
| 68 |
+
|
| 69 |
+
// get the current ubatch
|
| 70 |
+
virtual const llama_ubatch & get_ubatch() const = 0;
|
| 71 |
+
|
| 72 |
+
// get the status of the memory state
|
| 73 |
+
virtual llama_memory_status get_status() const = 0;
|
| 74 |
+
};
|
| 75 |
+
|
| 76 |
+
using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
|
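Editor's note: llama_memory_state_i is the contract that the new per-memory-type state classes implement. To make the shape of that contract concrete, here is a toy single-ubatch implementation; it is purely illustrative (the class name is made up) and omits the cache bookkeeping that the real states carry.

// Toy implementation (editor's illustration, not part of the sync): wraps one pre-built
// ubatch so the next()/apply()/get_status() contract is visible in one place.
class single_ubatch_state : public llama_memory_state_i {
public:
    single_ubatch_state(const llama_ubatch & ub, llama_memory_status st) : ubatch(ub), status(st) {}

    bool next()  override { return false; }  // only one ubatch, so we are done after it
    bool apply() override { return status == LLAMA_MEMORY_STATUS_SUCCESS; }

    std::vector<int64_t> & out_ids() override { return ids; }

    const llama_ubatch & get_ubatch() const override { return ubatch; }
    llama_memory_status  get_status() const override { return status; }

private:
    llama_ubatch         ubatch;
    llama_memory_status  status;
    std::vector<int64_t> ids;
};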
examples/talk-llama/llama-model.cpp
CHANGED
|
@@ -5,7 +5,10 @@
|
|
| 5 |
#include "llama-batch.h"
|
| 6 |
#include "llama-cparams.h"
|
| 7 |
#include "llama-model-loader.h"
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
#include "ggml-cpp.h"
|
| 11 |
|
|
@@ -683,6 +686,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|
| 683 |
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
| 684 |
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
|
| 685 |
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
|
|
|
|
| 686 |
|
| 687 |
switch (hparams.n_layer) {
|
| 688 |
case 3:
|
|
@@ -2113,7 +2117,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|
| 2113 |
case LLM_ARCH_NOMIC_BERT_MOE:
|
| 2114 |
{
|
| 2115 |
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
| 2116 |
-
type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types},
|
| 2117 |
|
| 2118 |
if (arch == LLM_ARCH_BERT) {
|
| 2119 |
pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0);
|
|
@@ -2121,8 +2125,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|
| 2121 |
cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
|
| 2122 |
cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED);
|
| 2123 |
|
| 2124 |
-
cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd,
|
| 2125 |
-
cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {
|
| 2126 |
}
|
| 2127 |
|
| 2128 |
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
|
|
@@ -2131,7 +2135,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|
| 2131 |
for (int i = 0; i < n_layer; ++i) {
|
| 2132 |
auto & layer = layers[i];
|
| 2133 |
|
| 2134 |
-
|
|
|
|
|
|
|
|
|
|
| 2135 |
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
| 2136 |
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
|
| 2137 |
|
|
@@ -2140,12 +2147,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|
| 2140 |
|
| 2141 |
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
| 2142 |
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
|
| 2143 |
-
} else {
|
| 2144 |
-
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
|
| 2145 |
-
}
|
| 2146 |
-
|
| 2147 |
-
if (arch == LLM_ARCH_NOMIC_BERT_MOE) {
|
| 2148 |
-
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
|
| 2149 |
}
|
| 2150 |
|
| 2151 |
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
@@ -5887,8 +5888,10 @@ struct llm_build_bert : public llm_graph_context {
|
|
| 5887 |
inpL = build_inp_embd(model.tok_embd);
|
| 5888 |
|
| 5889 |
// token types are hardcoded to zero ("Sentence A")
|
| 5890 |
-
|
| 5891 |
-
|
|
|
|
|
|
|
| 5892 |
if (model.arch == LLM_ARCH_BERT) {
|
| 5893 |
inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
|
| 5894 |
}
|
|
@@ -5909,36 +5912,11 @@ struct llm_build_bert : public llm_graph_context {
|
|
| 5909 |
ggml_tensor * Vcur;
|
| 5910 |
|
| 5911 |
// self-attention
|
| 5912 |
-
if (model.
|
| 5913 |
-
Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
|
| 5914 |
-
|
| 5915 |
-
if (model.layers[il].attn_q_norm) {
|
| 5916 |
-
Qcur = build_norm(Qcur,
|
| 5917 |
-
model.layers[il].attn_q_norm,
|
| 5918 |
-
model.layers[il].attn_q_norm_b,
|
| 5919 |
-
LLM_NORM, il);
|
| 5920 |
-
}
|
| 5921 |
-
|
| 5922 |
-
Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
|
| 5923 |
-
|
| 5924 |
-
if (model.layers[il].attn_k_norm) {
|
| 5925 |
-
Kcur = build_norm(Kcur,
|
| 5926 |
-
model.layers[il].attn_k_norm,
|
| 5927 |
-
model.layers[il].attn_k_norm_b,
|
| 5928 |
-
LLM_NORM, il);
|
| 5929 |
-
}
|
| 5930 |
-
|
| 5931 |
-
Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
|
| 5932 |
-
|
| 5933 |
-
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
| 5934 |
-
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
| 5935 |
-
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
| 5936 |
-
} else {
|
| 5937 |
-
// compute Q and K and RoPE them
|
| 5938 |
cur = build_lora_mm(model.layers[il].wqkv, cur);
|
| 5939 |
cb(cur, "wqkv", il);
|
| 5940 |
|
| 5941 |
-
if (model.
|
| 5942 |
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
| 5943 |
cb(cur, "bqkv", il);
|
| 5944 |
}
|
|
@@ -5946,11 +5924,32 @@ struct llm_build_bert : public llm_graph_context {
|
|
| 5946 |
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
| 5947 |
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
| 5948 |
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
|
| 5949 |
|
| 5950 |
-
|
| 5951 |
-
|
| 5952 |
-
|
|
|
|
|
|
|
|
|
|
| 5953 |
|
|
| 5954 |
Qcur = ggml_rope_ext(
|
| 5955 |
ctx0, Qcur, inp_pos, nullptr,
|
| 5956 |
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
@@ -8896,9 +8895,9 @@ struct llm_build_mamba : public llm_graph_context {
|
|
| 8896 |
ggml_tensor * state_mask,
|
| 8897 |
const llama_ubatch & ubatch,
|
| 8898 |
int il) const {
|
| 8899 |
-
const
|
| 8900 |
|
| 8901 |
-
const auto kv_head =
|
| 8902 |
|
| 8903 |
const int64_t d_conv = hparams.ssm_d_conv;
|
| 8904 |
const int64_t d_inner = hparams.ssm_d_inner;
|
|
@@ -8916,8 +8915,8 @@ struct llm_build_mamba : public llm_graph_context {
|
|
| 8916 |
GGML_ASSERT(ubatch.equal_seqs);
|
| 8917 |
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
| 8918 |
|
| 8919 |
-
ggml_tensor * conv_states_all =
|
| 8920 |
-
ggml_tensor * ssm_states_all =
|
| 8921 |
|
| 8922 |
// (ab)using the KV cache to store the states
|
| 8923 |
ggml_tensor * conv = build_copy_mask_state(
|
|
@@ -11644,7 +11643,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|
| 11644 |
ggml_tensor * state_mask,
|
| 11645 |
const llama_ubatch & ubatch,
|
| 11646 |
int il) const {
|
| 11647 |
-
const
|
| 11648 |
|
| 11649 |
const auto n_tokens = ubatch.n_tokens;
|
| 11650 |
const auto n_seqs = ubatch.n_seqs;
|
|
@@ -11654,7 +11653,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|
| 11654 |
const auto n_head = n_embd / head_size;
|
| 11655 |
const auto n_head_kv = hparams.n_head_kv(il);
|
| 11656 |
|
| 11657 |
-
const auto kv_head =
|
| 11658 |
|
| 11659 |
const auto & layer = model.layers[il];
|
| 11660 |
|
|
@@ -11766,7 +11765,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|
| 11766 |
}
|
| 11767 |
|
| 11768 |
ggml_tensor * wkv_state = build_copy_mask_state(
|
| 11769 |
-
gf,
|
| 11770 |
hparams.n_embd_v_s(), n_seqs);
|
| 11771 |
|
| 11772 |
ggml_tensor * wkv_output;
|
|
@@ -11785,9 +11784,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|
| 11785 |
wkv_state,
|
| 11786 |
ggml_view_1d(
|
| 11787 |
ctx0,
|
| 11788 |
-
|
| 11789 |
hparams.n_embd_v_s() * n_seqs,
|
| 11790 |
-
hparams.n_embd_v_s() * kv_head * ggml_element_size(
|
| 11791 |
)
|
| 11792 |
)
|
| 11793 |
);
|
|
@@ -12040,7 +12039,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|
| 12040 |
ggml_tensor *& first_layer_value,
|
| 12041 |
const llama_ubatch & ubatch,
|
| 12042 |
int il) const {
|
| 12043 |
-
const
|
| 12044 |
|
| 12045 |
const auto n_tokens = ubatch.n_tokens;
|
| 12046 |
const auto n_seqs = ubatch.n_seqs;
|
|
@@ -12049,7 +12048,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|
| 12049 |
const auto head_count = n_embd / head_size;
|
| 12050 |
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
| 12051 |
|
| 12052 |
-
const auto kv_head =
|
| 12053 |
|
| 12054 |
const auto & layer = model.layers[il];
|
| 12055 |
|
|
@@ -12120,7 +12119,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|
| 12120 |
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
|
| 12121 |
|
| 12122 |
ggml_tensor * wkv_state = build_copy_mask_state(
|
| 12123 |
-
gf,
|
| 12124 |
hparams.n_embd_v_s(), n_seqs);
|
| 12125 |
|
| 12126 |
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
|
|
@@ -12134,9 +12133,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|
| 12134 |
wkv_state,
|
| 12135 |
ggml_view_1d(
|
| 12136 |
ctx0,
|
| 12137 |
-
|
| 12138 |
hparams.n_embd_v_s() * n_seqs,
|
| 12139 |
-
hparams.n_embd_v_s() * kv_head * ggml_element_size(
|
| 12140 |
)
|
| 12141 |
)
|
| 12142 |
);
|
|
@@ -13234,7 +13233,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|
| 13234 |
params.swa_full,
|
| 13235 |
cparams.n_ctx,
|
| 13236 |
cparams.n_seq_max,
|
| 13237 |
-
cparams.
|
| 13238 |
padding);
|
| 13239 |
} else {
|
| 13240 |
GGML_ASSERT(!hparams.is_swa_any());
|
|
@@ -13266,7 +13265,6 @@ llm_graph_result_ptr llama_model::build_graph(
|
|
| 13266 |
|
| 13267 |
switch (arch) {
|
| 13268 |
case LLM_ARCH_LLAMA:
|
| 13269 |
-
case LLM_ARCH_MINICPM:
|
| 13270 |
{
|
| 13271 |
llm = std::make_unique<llm_build_llama>(*this, params, gf);
|
| 13272 |
} break;
|
|
@@ -13507,6 +13505,7 @@ llm_graph_result_ptr llama_model::build_graph(
|
|
| 13507 |
} break;
|
| 13508 |
case LLM_ARCH_GRANITE:
|
| 13509 |
case LLM_ARCH_GRANITE_MOE:
|
|
|
|
| 13510 |
{
|
| 13511 |
llm = std::make_unique<llm_build_granite>(*this, params, gf);
|
| 13512 |
} break;
|
|
@@ -13597,6 +13596,10 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
|
|
| 13597 |
return model->hparams.n_head_kv();
|
| 13598 |
}
|
| 13599 |
|
|
| 13600 |
// deprecated
|
| 13601 |
int32_t llama_n_ctx_train(const llama_model * model) {
|
| 13602 |
return llama_model_n_ctx_train(model);
|
|
|
|
| 5 |
#include "llama-batch.h"
|
| 6 |
#include "llama-cparams.h"
|
| 7 |
#include "llama-model-loader.h"
|
| 8 |
+
|
| 9 |
+
#include "llama-kv-cache-unified.h"
|
| 10 |
+
#include "llama-kv-cache-unified-iswa.h"
|
| 11 |
+
#include "llama-kv-cache-recurrent.h"
|
| 12 |
|
| 13 |
#include "ggml-cpp.h"
|
| 14 |
|
|
|
|
| 686 |
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
| 687 |
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
|
| 688 |
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
|
| 689 |
+
ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false);
|
| 690 |
|
| 691 |
switch (hparams.n_layer) {
|
| 692 |
case 3:
|
|
|
|
| 2117 |
case LLM_ARCH_NOMIC_BERT_MOE:
|
| 2118 |
{
|
| 2119 |
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
| 2120 |
+
type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED);
|
| 2121 |
|
| 2122 |
if (arch == LLM_ARCH_BERT) {
|
| 2123 |
pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0);
|
|
|
|
| 2125 |
cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
|
| 2126 |
cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED);
|
| 2127 |
|
| 2128 |
+
cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
|
| 2129 |
+
cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
|
| 2130 |
}
|
| 2131 |
|
| 2132 |
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
|
|
|
|
| 2135 |
for (int i = 0; i < n_layer; ++i) {
|
| 2136 |
auto & layer = layers[i];
|
| 2137 |
|
| 2138 |
+
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
|
| 2139 |
+
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
|
| 2140 |
+
|
| 2141 |
+
if (!layer.wqkv) {
|
| 2142 |
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
| 2143 |
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
|
| 2144 |
|
|
|
|
| 2147 |
|
| 2148 |
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
| 2149 |
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
|
|
| 2150 |
}
|
| 2151 |
|
| 2152 |
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
|
|
|
| 5888 |         inpL = build_inp_embd(model.tok_embd);
| 5889 |
| 5890 |         // token types are hardcoded to zero ("Sentence A")
| 5891 | +       if (model.type_embd) {
| 5892 | +           ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
| 5893 | +           inpL = ggml_add(ctx0, inpL, type_row0);
| 5894 | +       }
| 5895 |         if (model.arch == LLM_ARCH_BERT) {
| 5896 |             inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
| 5897 |         }

| 5912 |             ggml_tensor * Vcur;
| 5913 |
| 5914 |             // self-attention
| 5915 | +           if (model.layers[il].wqkv) {
| 5916 |                 cur = build_lora_mm(model.layers[il].wqkv, cur);
| 5917 |                 cb(cur, "wqkv", il);
| 5918 |
| 5919 | +               if (model.layers[il].bqkv) {
| 5920 |                     cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
| 5921 |                     cb(cur, "bqkv", il);
| 5922 |                 }
| 5924 |                 Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
| 5925 |                 Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
| 5926 |                 Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
| 5927 | +           } else {
| 5928 | +               Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
| 5929 | +               Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
| 5930 | +               Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
| 5931 | +           }
| 5932 |
| 5933 | +           if (model.layers[il].attn_q_norm) {
| 5934 | +               Qcur = build_norm(Qcur,
| 5935 | +                       model.layers[il].attn_q_norm,
| 5936 | +                       model.layers[il].attn_q_norm_b,
| 5937 | +                       LLM_NORM, il);
| 5938 | +           }
| 5939 |
| 5940 | +           if (model.layers[il].attn_k_norm) {
| 5941 | +               Kcur = build_norm(Kcur,
| 5942 | +                       model.layers[il].attn_k_norm,
| 5943 | +                       model.layers[il].attn_k_norm_b,
| 5944 | +                       LLM_NORM, il);
| 5945 | +           }
| 5946 | +
| 5947 | +           Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
| 5948 | +           Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
| 5949 | +           Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
| 5950 | +
| 5951 | +           // RoPE
| 5952 | +           if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
| 5953 |                 Qcur = ggml_rope_ext(
| 5954 |                         ctx0, Qcur, inp_pos, nullptr,
| 5955 |                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,

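The three ggml_view_2d calls above slice the fused wqkv output row-wise: Q occupies the first n_embd floats of each row, K the next n_embd_gqa, and V the last n_embd_gqa. A minimal, self-contained sketch of that offset arithmetic; the concrete sizes are made up for illustration and do not come from any particular model:

    #include <stdio.h>

    int main(void) {
        // hypothetical sizes, for illustration only
        const size_t n_embd     = 768;   // hidden size
        const size_t n_embd_gqa = 256;   // n_embd_head_k * n_head_kv (grouped-query K/V)

        printf("fused QKV row : %zu floats\n", n_embd + 2*n_embd_gqa);
        printf("Q view offset : %zu bytes\n", (size_t) 0);
        printf("K view offset : %zu bytes\n", sizeof(float) *  n_embd);
        printf("V view offset : %zu bytes\n", sizeof(float) * (n_embd + n_embd_gqa));
        return 0;
    }
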
| 8895 |         ggml_tensor * state_mask,
| 8896 |         const llama_ubatch & ubatch,
| 8897 |         int il) const {
| 8898 | +   const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
| 8899 |
| 8900 | +   const auto kv_head = kv_state->get_head();
| 8901 |
| 8902 |     const int64_t d_conv = hparams.ssm_d_conv;
| 8903 |     const int64_t d_inner = hparams.ssm_d_inner;

| 8915 |     GGML_ASSERT(ubatch.equal_seqs);
| 8916 |     GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
| 8917 |
| 8918 | +   ggml_tensor * conv_states_all = kv_state->get_k_l(il);
| 8919 | +   ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
| 8920 |
| 8921 |     // (ab)using the KV cache to store the states
| 8922 |     ggml_tensor * conv = build_copy_mask_state(

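The graph code now reaches the recurrent state through a dedicated state object instead of touching the cache directly. A rough sketch of the accessor surface implied by the call sites above; the names come from the diff, but the class layout and exact signatures here are illustrative, not the real definition in llama-kv-cache-recurrent.h:

    #include <cstdint>

    struct ggml_tensor;  // from ggml; only the pointer type is needed here

    // illustrative only -- the real llama_kv_cache_recurrent_state carries more state
    class kv_cache_recurrent_state_sketch {
    public:
        uint32_t      get_head()      const;  // first cache cell used by the current ubatch
        ggml_tensor * get_k_l(int il) const;  // per-layer conv state ("k" buffer)
        ggml_tensor * get_v_l(int il) const;  // per-layer ssm/wkv state ("v" buffer)
    };
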
| 11643 |         ggml_tensor * state_mask,
| 11644 |         const llama_ubatch & ubatch,
| 11645 |         int il) const {
| 11646 | +   const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
| 11647 |
| 11648 |     const auto n_tokens = ubatch.n_tokens;
| 11649 |     const auto n_seqs = ubatch.n_seqs;

| 11653 |     const auto n_head = n_embd / head_size;
| 11654 |     const auto n_head_kv = hparams.n_head_kv(il);
| 11655 |
| 11656 | +   const auto kv_head = kv_state->get_head();
| 11657 |
| 11658 |     const auto & layer = model.layers[il];
| 11659 |

| 11765 |     }
| 11766 |
| 11767 |     ggml_tensor * wkv_state = build_copy_mask_state(
| 11768 | +           gf, kv_state->get_v_l(il), state_copy, state_mask,
| 11769 |             hparams.n_embd_v_s(), n_seqs);
| 11770 |
| 11771 |     ggml_tensor * wkv_output;

| 11784 |                     wkv_state,
| 11785 |                     ggml_view_1d(
| 11786 |                         ctx0,
| 11787 | +                       kv_state->get_v_l(il),
| 11788 |                         hparams.n_embd_v_s() * n_seqs,
| 11789 | +                       hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
| 11790 |                     )
| 11791 |                 )
| 11792 |             );

| 12039 |         ggml_tensor *& first_layer_value,
| 12040 |         const llama_ubatch & ubatch,
| 12041 |         int il) const {
| 12042 | +   const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
| 12043 |
| 12044 |     const auto n_tokens = ubatch.n_tokens;
| 12045 |     const auto n_seqs = ubatch.n_seqs;

| 12048 |     const auto head_count = n_embd / head_size;
| 12049 |     const auto n_seq_tokens = ubatch.n_seq_tokens;
| 12050 |
| 12051 | +   const auto kv_head = kv_state->get_head();
| 12052 |
| 12053 |     const auto & layer = model.layers[il];
| 12054 |

| 12119 |     a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
| 12120 |
| 12121 |     ggml_tensor * wkv_state = build_copy_mask_state(
| 12122 | +           gf, kv_state->get_v_l(il), state_copy, state_mask,
| 12123 |             hparams.n_embd_v_s(), n_seqs);
| 12124 |
| 12125 |     ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);

| 12133 |                     wkv_state,
| 12134 |                     ggml_view_1d(
| 12135 |                         ctx0,
| 12136 | +                       kv_state->get_v_l(il),
| 12137 |                         hparams.n_embd_v_s() * n_seqs,
| 12138 | +                       hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
| 12139 |                     )
| 12140 |                 )
| 12141 |             );

| 13233 |                     params.swa_full,
| 13234 |                     cparams.n_ctx,
| 13235 |                     cparams.n_seq_max,
| 13236 | +                   cparams.n_ubatch,
| 13237 |                     padding);
| 13238 |         } else {
| 13239 |             GGML_ASSERT(!hparams.is_swa_any());

| 13265 |
| 13266 |     switch (arch) {
| 13267 |         case LLM_ARCH_LLAMA:
| 13268 |             {
| 13269 |                 llm = std::make_unique<llm_build_llama>(*this, params, gf);
| 13270 |             } break;

| 13505 |             } break;
| 13506 |         case LLM_ARCH_GRANITE:
| 13507 |         case LLM_ARCH_GRANITE_MOE:
| 13508 | +       case LLM_ARCH_MINICPM:
| 13509 |             {
| 13510 |                 llm = std::make_unique<llm_build_granite>(*this, params, gf);
| 13511 |             } break;

| 13596 |     return model->hparams.n_head_kv();
| 13597 | }
| 13598 |
| 13599 | + int32_t llama_model_n_swa(const llama_model * model) {
| 13600 | +     return model->hparams.n_swa;
| 13601 | + }
| 13602 | +
| 13603 | // deprecated
| 13604 | int32_t llama_n_ctx_train(const llama_model * model) {
| 13605 |     return llama_model_n_ctx_train(model);

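A small usage sketch for the new getter, assuming (as the implementation above suggests) that a model without sliding-window attention reports n_swa == 0:

    #include "llama.h"

    // returns true when the loaded model uses sliding-window attention
    static bool model_uses_swa(const struct llama_model * model) {
        return llama_model_n_swa(model) > 0;
    }
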
examples/talk-llama/llama.h
CHANGED
@@ -259,9 +259,9 @@ extern "C" {
| 259 |         llama_token * token;
| 260 |         float * embd;
| 261 |         llama_pos * pos;
| 262 | -       int32_t * n_seq_id;
| 263 | -       llama_seq_id ** seq_id;
| 264 | -       int8_t * logits;
| 265 |     } llama_batch;
| 266 |
| 267 |     enum llama_model_kv_override_type {

@@ -366,6 +366,8 @@ extern "C" {
| 366 |         bool no_perf;    // measure performance timings
| 367 |         bool op_offload; // offload host tensor operations to device
| 368 |         bool swa_full;   // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
| 369 |     };
| 370 |
| 371 |     // model quantization parameters

@@ -502,6 +504,7 @@ extern "C" {
| 502 |     LLAMA_API int32_t llama_model_n_layer   (const struct llama_model * model);
| 503 |     LLAMA_API int32_t llama_model_n_head    (const struct llama_model * model);
| 504 |     LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
| 505 |
| 506 |     // Get the model's RoPE frequency scaling factor
| 507 |     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);

@@ -652,7 +655,6 @@ extern "C" {
| 652 |     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
| 653 |     // If the KV cache is RoPEd, the KV data is updated accordingly:
| 654 |     //   - lazily on next llama_decode()
| 655 | -   //   - explicitly with llama_kv_self_update()
| 656 |     // p0 < 0 : [0, p1]
| 657 |     // p1 < 0 : [p0, inf)
| 658 |     LLAMA_API void llama_kv_self_seq_add(

@@ -665,7 +667,6 @@ extern "C" {
| 665 |     // Integer division of the positions by factor of `d > 1`
| 666 |     // If the KV cache is RoPEd, the KV data is updated accordingly:
| 667 |     //   - lazily on next llama_decode()
| 668 | -   //   - explicitly with llama_kv_self_update()
| 669 |     // p0 < 0 : [0, p1]
| 670 |     // p1 < 0 : [p0, inf)
| 671 |     LLAMA_API void llama_kv_self_seq_div(

@@ -677,12 +678,14 @@ extern "C" {
| 677 |
| 678 |     // Returns the smallest position present in the KV cache for the specified sequence
| 679 |     // This is typically non-zero only for SWA caches
| 680 |     // Return -1 if the sequence is empty
| 681 |     LLAMA_API llama_pos llama_kv_self_seq_pos_min(
| 682 |             struct llama_context * ctx,
| 683 |             llama_seq_id seq_id);
| 684 |
| 685 |     // Returns the largest position present in the KV cache for the specified sequence
| 686 |     // Return -1 if the sequence is empty
| 687 |     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
| 688 |             struct llama_context * ctx,

@@ -691,14 +694,15 @@ extern "C" {
| 691 |     // Defragment the KV cache
| 692 |     // This will be applied:
| 693 |     //   - lazily on next llama_decode()
| 694 | -   //   - explicitly with llama_kv_self_update()
| 695 | -   LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
| 696 |
| 697 |     // Check if the context supports KV cache shifting
| 698 |     LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
| 699 |
| 700 |     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
| 701 | -   LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
| 702 |
| 703 |     //
| 704 |     // State / sessions

| 259 |         llama_token * token;
| 260 |         float * embd;
| 261 |         llama_pos * pos;
| 262 | +       int32_t * n_seq_id;      // TODO: remove, should belong to only 1 sequence
| 263 | +       llama_seq_id ** seq_id;  // TODO: become llama_seq_id * seq_id;
| 264 | +       int8_t * logits;         // TODO: rename this to "output"
| 265 |     } llama_batch;
| 266 |
| 267 |     enum llama_model_kv_override_type {

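A hedged sketch of filling these per-token arrays by hand for a single prompt on sequence 0. It assumes the usual llama_batch_init(n_tokens, embd, n_seq_max) / llama_batch_free() helpers and requests output (the "logits" flag that the TODO says will be renamed) only for the last token:

    #include "llama.h"

    static struct llama_batch make_prompt_batch(const llama_token * tokens, int32_t n) {
        struct llama_batch batch = llama_batch_init(n, /*embd =*/ 0, /*n_seq_max =*/ 1);

        for (int32_t i = 0; i < n; ++i) {
            batch.token   [i]    = tokens[i];
            batch.pos     [i]    = i;
            batch.n_seq_id[i]    = 1;             // per the TODO, this may go away
            batch.seq_id  [i][0] = 0;
            batch.logits  [i]    = (i == n - 1);  // only the last token produces output
        }
        batch.n_tokens = n;

        return batch;  // release later with llama_batch_free()
    }
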
| 366 |         bool no_perf;    // measure performance timings
| 367 |         bool op_offload; // offload host tensor operations to device
| 368 |         bool swa_full;   // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
| 369 | +                        // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
| 370 | +                        // ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
| 371 |     };
| 372 |
| 373 |     // model quantization parameters

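A minimal sketch of opting into the full-size SWA cache at context creation, assuming llama_context_default_params() and llama_init_from_model() as the entry points; the NOTE above is the reason to keep swa_full enabled when several sequences share one context:

    #include "llama.h"

    static struct llama_context * make_ctx(struct llama_model * model) {
        struct llama_context_params cparams = llama_context_default_params();

        cparams.n_seq_max = 4;     // several parallel sequences ...
        cparams.swa_full  = true;  // ... so keep the full-size SWA cache (see NOTE above)

        return llama_init_from_model(model, cparams);
    }
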
| 504 |     LLAMA_API int32_t llama_model_n_layer   (const struct llama_model * model);
| 505 |     LLAMA_API int32_t llama_model_n_head    (const struct llama_model * model);
| 506 |     LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
| 507 | +   LLAMA_API int32_t llama_model_n_swa     (const struct llama_model * model);
| 508 |
| 509 |     // Get the model's RoPE frequency scaling factor
| 510 |     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);

| 655 |     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
| 656 |     // If the KV cache is RoPEd, the KV data is updated accordingly:
| 657 |     //   - lazily on next llama_decode()
| 658 |     // p0 < 0 : [0, p1]
| 659 |     // p1 < 0 : [p0, inf)
| 660 |     LLAMA_API void llama_kv_self_seq_add(

| 667 |     // Integer division of the positions by factor of `d > 1`
| 668 |     // If the KV cache is RoPEd, the KV data is updated accordingly:
| 669 |     //   - lazily on next llama_decode()
| 670 |     // p0 < 0 : [0, p1]
| 671 |     // p1 < 0 : [p0, inf)
| 672 |     LLAMA_API void llama_kv_self_seq_div(

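A usage sketch for the two position edits documented above, assuming the (ctx, seq_id, p0, p1, delta-or-d) parameter order used upstream (the full parameter lists are truncated in this hunk):

    #include "llama.h"

    static void edit_positions(struct llama_context * ctx) {
        // shift positions [32, end) of sequence 0 back by 32,
        // e.g. after discarding the first 32 tokens of that sequence
        llama_kv_self_seq_add(ctx, /*seq_id =*/ 0, /*p0 =*/ 32, /*p1 =*/ -1, /*delta =*/ -32);

        // divide positions [0, end) of sequence 1 by 2 (grouped/self-extend style)
        llama_kv_self_seq_div(ctx, /*seq_id =*/ 1, /*p0 =*/ 0, /*p1 =*/ -1, /*d =*/ 2);

        // for a RoPEd cache, the corresponding K data is updated lazily on the next llama_decode()
    }
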
| 678 |
| 679 |     // Returns the smallest position present in the KV cache for the specified sequence
| 680 |     // This is typically non-zero only for SWA caches
| 681 | +   // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
| 682 |     // Return -1 if the sequence is empty
| 683 |     LLAMA_API llama_pos llama_kv_self_seq_pos_min(
| 684 |             struct llama_context * ctx,
| 685 |             llama_seq_id seq_id);
| 686 |
| 687 |     // Returns the largest position present in the KV cache for the specified sequence
| 688 | +   // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
| 689 |     // Return -1 if the sequence is empty
| 690 |     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
| 691 |             struct llama_context * ctx,

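A small sketch built on the two getters above: with a SWA cache, pos_min can be greater than zero, so code that wants to reuse a cached prefix should check the available range first (the reuse policy itself is just an example):

    #include "llama.h"

    static bool can_reuse_prefix(struct llama_context * ctx, llama_seq_id seq, llama_pos n_keep) {
        const llama_pos pmin = llama_kv_self_seq_pos_min(ctx, seq);
        const llama_pos pmax = llama_kv_self_seq_pos_max(ctx, seq);

        if (pmin < 0 || pmax < 0) {
            return false;  // the sequence is empty
        }

        // everything in [pmin, pmax] is guaranteed to be present,
        // so the prefix [0, n_keep) is only reusable if that whole range is covered
        return pmin == 0 && pmax >= n_keep - 1;
    }
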
| 694 |     // Defragment the KV cache
| 695 |     // This will be applied:
| 696 |     //   - lazily on next llama_decode()
| 697 | +   LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx),
| 698 | +       "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
| 699 |
| 700 |     // Check if the context supports KV cache shifting
| 701 |     LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
| 702 |
| 703 |     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
| 704 | +   LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx),
| 705 | +       "simply remove this call, updates are applied lazily on the next llama_decode()");
| 706 |
| 707 |     //
| 708 |     // State / sessions

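Migration sketch for the two deprecations: callers simply drop the explicit maintenance calls, since K-shifts and defragmentation are now applied lazily inside llama_decode(). The surrounding context-shift logic is illustrative, and llama_kv_self_seq_rm is assumed from the same llama_kv_self_* API family:

    #include "llama.h"

    static void shift_context(struct llama_context * ctx, llama_pos n_discard) {
        // drop the oldest n_discard positions of sequence 0 and shift the rest back
        llama_kv_self_seq_rm (ctx, 0, 0,         n_discard);
        llama_kv_self_seq_add(ctx, 0, n_discard, -1, -n_discard);

        // previously: llama_kv_self_update(ctx);  // no longer needed
        // previously: llama_kv_self_defrag(ctx);  // no longer needed; 'defrag_thold' decides
        // the pending K-shift is applied on the next llama_decode()
    }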