Commit 58220b6 by ggerganov
Parent(s): 337f4d9

talk-llama : sync llama.cpp

examples/talk-llama/CMakeLists.txt CHANGED
@@ -17,6 +17,9 @@ if (WHISPER_SDL2)
     llama-impl.cpp
     llama-io.cpp
     llama-kv-cache.cpp
+    llama-kv-cache-unified.cpp
+    llama-kv-cache-unified-iswa.cpp
+    llama-kv-cache-recurrent.cpp
     llama-memory.cpp
     llama-mmap.cpp
     llama-model-loader.cpp
examples/talk-llama/llama-arch.cpp CHANGED
@@ -174,6 +174,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" },
     { LLM_KV_CONVNEXT_BLOCK_COUNT, "%s.convnext.block_count" },
 
+    { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },
+
     { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
     { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
     { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
@@ -448,6 +450,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
     { LLM_TENSOR_TOKEN_TYPES, "token_types" },
     { LLM_TENSOR_POS_EMBD, "position_embd" },
     { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
+    { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
     { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
     { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
     { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
examples/talk-llama/llama-arch.h CHANGED
@@ -213,6 +213,8 @@ enum llm_kv {
     LLM_KV_CONVNEXT_EMBEDDING_LENGTH,
     LLM_KV_CONVNEXT_BLOCK_COUNT,
 
+    LLM_KV_CLASSIFIER_OUTPUT_LABELS,
+
     // deprecated:
     LLM_KV_TOKENIZER_PREFIX_ID,
     LLM_KV_TOKENIZER_SUFFIX_ID,
examples/talk-llama/llama-batch.cpp CHANGED
@@ -15,24 +15,31 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
             break;
         }
     }
-    ubatch_token.resize(!has_embd ? n_ubatch : 0);
-    ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
-    ubatch_pos.resize(n_ubatch);
-    ubatch_n_seq_id.resize(n_ubatch);
-    ubatch_seq_id.resize(n_ubatch);
-    ubatch_output.resize(n_ubatch);
+
+    udatas.push_back({});
+
+    auto & udata = udatas.back();
+
+    udata.token.resize(!has_embd ? n_ubatch : 0);
+    udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
+    udata.pos.resize(n_ubatch);
+    udata.n_seq_id.resize(n_ubatch);
+    udata.seq_id.resize(n_ubatch);
+    udata.output.resize(n_ubatch);
+
     llama_ubatch ubatch = {
         /*equal_seqs =*/ true,
         /*n_tokens =*/ 0,
         /*n_seq_tokens =*/ 0,
         /*n_seqs =*/ 0,
-        /*token =*/ !has_embd ? ubatch_token.data() : nullptr,
-        /*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
-        /*pos =*/ ubatch_pos.data(),
-        /*n_seq_id =*/ ubatch_n_seq_id.data(),
-        /*seq_id =*/ ubatch_seq_id.data(),
-        /*output =*/ ubatch_output.data(),
+        /*token =*/ !has_embd ? udata.token.data() : nullptr,
+        /*embd =*/ has_embd ? udata.embd.data() : nullptr,
+        /*pos =*/ udata.pos.data(),
+        /*n_seq_id =*/ udata.n_seq_id.data(),
+        /*seq_id =*/ udata.seq_id.data(),
+        /*output =*/ udata.output.data(),
     };
+
     return ubatch;
 }
 
examples/talk-llama/llama-batch.h CHANGED
@@ -11,15 +11,15 @@ struct llama_ubatch {
     bool equal_seqs;
     // TODO: whole_seqs for embeddings?
 
-    uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
+    uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
     uint32_t n_seq_tokens; // tokens per sequence
     uint32_t n_seqs;
 
     llama_token * token; // [n_tokens]
     float * embd; // [n_embd, n_tokens]
     llama_pos * pos; // [n_tokens]
-    int32_t * n_seq_id; // [n_seqs]
-    llama_seq_id ** seq_id; // [n_seqs]
+    int32_t * n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
+    llama_seq_id ** seq_id; // [n_seqs] // TODO: become llama_seq_id * seq_id;
     int8_t * output; // [n_tokens]
 };
 
@@ -49,13 +49,18 @@ struct llama_sbatch {
 
     const llama_batch * batch = nullptr;
 
-    // buffers for the ubatch
-    std::vector<llama_token> ubatch_token;
-    std::vector<float> ubatch_embd;
-    std::vector<llama_pos> ubatch_pos;
-    std::vector<int32_t> ubatch_n_seq_id;
-    std::vector<llama_seq_id *> ubatch_seq_id;
-    std::vector<int8_t> ubatch_output;
+    // buffers for the ubatches
+    // TODO: very hacky, this needs a complete rework
+    struct ubatch_data {
+        std::vector<llama_token> token;
+        std::vector<float> embd;
+        std::vector<llama_pos> pos;
+        std::vector<int32_t> n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<int8_t> output;
+    };
+
+    std::vector<ubatch_data> udatas;
 
     llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
 
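The llama_sbatch change above swaps a single set of reusable ubatch_* buffers for a vector of per-ubatch ubatch_data entries, so the raw pointers stored in each returned llama_ubatch stay valid across later reserve_ubatch() calls. Below is a minimal standalone sketch of that ownership pattern, with toy types invented for illustration; it is not the llama.cpp API.

// buffer_views.cpp - each reserve() call gets its own backing storage kept in a
// member vector, so views handed out earlier are not invalidated by later calls.
#include <cstddef>
#include <cstdio>
#include <vector>

struct view {
    int        * data;
    std::size_t  n;
};

struct owner {
    struct storage {
        std::vector<int> buf;
    };

    std::vector<storage> storages; // analogous to llama_sbatch::udatas

    view reserve(std::size_t n) {
        storages.push_back({});
        auto & s = storages.back();
        s.buf.resize(n, 0);
        // moving a `storage` (e.g. when `storages` grows) keeps the heap buffer
        // of `buf`, so this pointer remains valid for the lifetime of `owner`
        return { s.buf.data(), n };
    }
};

int main() {
    owner o;
    const view a = o.reserve(4);
    const view b = o.reserve(8); // does not invalidate `a`, unlike one reused buffer
    std::printf("%zu %zu\n", a.n, b.n);
    return 0;
}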
examples/talk-llama/llama-context.cpp CHANGED
@@ -6,9 +6,10 @@
 #include "llama-model.h"
 #include "llama-kv-cache.h"
 
+#include <cinttypes>
 #include <cstring>
+#include <limits>
 #include <stdexcept>
-#include <cinttypes>
 
 //
 // llama_context
@@ -122,6 +123,11 @@ llama_context::llama_context(
                 __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }
 
+    if (!params.swa_full && cparams.n_seq_max > 1) {
+        LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
+                __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
+    }
+
     if (!hparams.vocab_only) {
         // GPU backends
         for (auto * dev : model.devices) {
@@ -259,15 +265,9 @@ llama_context::llama_context(
 
     // reserve worst-case graph
     if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+        const uint32_t n_seqs = cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-
-        // restore later
-        // TODO: something cleaner
-        const auto n_outputs_save = n_outputs;
-
         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
         int n_splits_pp = -1;
@@ -279,23 +279,17 @@ llama_context::llama_context(
         // simulate full KV cache
         llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-        kv_self->set_full();
+        const auto kv_state = kv_self->init_full();
+        if (!kv_state) {
+            throw std::runtime_error("failed to initialize KV cache");
+        }
 
         cross.v_embd.clear();
 
         // reserve pp graph first so that buffers are only allocated once
         {
-            llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
-            // max number of outputs
-            n_outputs = ubatch_pp.n_tokens;
-
-            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
-
-            auto * gf = graph_init();
-            graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
-
-            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+            if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
 
@@ -305,16 +299,8 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
        {
-            llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
-            n_outputs = ubatch_tg.n_tokens;
-
-            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
-
-            auto * gf = graph_init();
-            graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
-
-            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+            auto * gf = graph_reserve(1, 1, 1, kv_state.get());
+            if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
 
@@ -324,22 +310,12 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
        {
-            llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
-            n_outputs = ubatch_pp.n_tokens;
-
-            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
-
-            auto * gf = graph_init();
-            graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
-
-            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+            if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
         }
 
-        n_outputs = n_outputs_save;
-
         for (size_t i = 0; i < backend_ptrs.size(); ++i) {
             ggml_backend_t backend = backend_ptrs[i];
             ggml_backend_buffer_type_t buft = backend_buft[i];
@@ -453,36 +429,33 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-void llama_context::kv_self_update() {
-    bool need_reserve = false;
+bool llama_context::kv_self_update() {
+    if (!memory) {
+        return false;
+    }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    need_reserve = kv_self->update(*this);
-
-    // reserve a worst case graph if needed
-    if (need_reserve) {
-        LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
-
-        // build worst-case graph
-        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-        // simulate full KV cache
-        kv_self->set_full();
+    if (!kv_self->update(*this)) {
+        // no updates have been performed
+        return false;
+    }
 
-        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+    // if the KV cache did any computation, we have to reserve a new worst-case graph
+    const auto kv_state = kv_self->init_full();
+    if (!kv_state) {
+        throw std::runtime_error("failed to initialize KV cache");
+    }
 
-        auto * gf = graph_init();
-        graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+    const uint32_t n_seqs = cparams.n_seq_max;
+    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        // initialize scheduler with the worst-case graph
-        ggml_backend_sched_reset(sched.get());
-        if (!ggml_backend_sched_reserve(sched.get(), gf)) {
-            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-        }
+    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
    }
+
+    return true;
 }
 
 enum llama_pooling_type llama_context::pooling_type() const {
@@ -676,6 +649,49 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
+llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
+    if (mstate && !mstate->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
+        ret = GGML_STATUS_FAILED;
+        return nullptr;
+    }
+
+    auto * gf = graph_init();
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+        ret = GGML_STATUS_FAILED;
+        return nullptr;
+    }
+
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
+    if (!res) {
+        LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
+        ret = GGML_STATUS_FAILED;
+        return nullptr;
+    }
+
+    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+
+    if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+        ret = GGML_STATUS_ALLOC_FAILED;
+        return nullptr;
+    }
+
+    res->set_inputs(&ubatch);
+
+    const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
+        ret = status;
+        return nullptr;
+    }
+
+    ret = GGML_STATUS_SUCCESS;
+
+    return res;
+}
+
 int llama_context::encode(llama_batch & inp_batch) {
     if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -737,8 +753,6 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     n_outputs = n_tokens;
 
-    //batch_manager->prepare(ubatch);
-
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
@@ -749,26 +763,18 @@ int llama_context::encode(llama_batch & inp_batch) {
     // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
     cparams.causal_attn = false;
 
-    auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
-
-    ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-    res->set_inputs(&ubatch);
+    ggml_status status;
+    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
 
     cparams.causal_attn = causal_attn_org;
 
-    const auto compute_status = graph_compute(gf, n_tokens > 1);
-    switch (compute_status) {
-        case GGML_STATUS_SUCCESS:
-            break;
-        case GGML_STATUS_ABORTED:
-            return 2;
-        case GGML_STATUS_ALLOC_FAILED:
-            return -2;
-        case GGML_STATUS_FAILED:
-        default:
-            return -3;
+    if (!res) {
+        switch (status) {
+            case GGML_STATUS_ABORTED: return 2;
+            case GGML_STATUS_ALLOC_FAILED: return -2;
+            case GGML_STATUS_FAILED: return -3;
+            case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
+        }
     }
 
     auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
@@ -889,8 +895,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd = hparams.n_embd;
 
-    llama_kv_cache_guard kv_guard(kv_self);
-
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
     // TODO: move the validation to the llama_batch_allocr
@@ -936,7 +940,48 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }
 
-    llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
+    // handle any pending defrags/shifts
+    kv_self_update();
+
+    llama_memory_state_ptr kv_state;
+
+    bool did_defrag = false;
+
+    while (true) {
+        kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!kv_state) {
+            return -2;
+        }
+
+        switch (kv_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                } break;
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+                {
+                    if (!did_defrag) {
+                        did_defrag = true;
+
+                        kv_self->defrag_sched(-1.0f);
+                        if (kv_self_update()) {
+                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+
+                            continue;
+                        }
+                    }
+
+                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+                    return 1;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    return -2;
+                }
+        }
+
+        break;
+    }
 
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -944,13 +989,10 @@ int llama_context::decode(llama_batch & inp_batch) {
        return -2;
    };
 
-    // handle any pending defrags/shifts
-    kv_self_update();
-
    int64_t n_outputs_prev = 0;
 
-    while (sbatch.n_tokens > 0) {
-        llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+    do {
+        const auto & ubatch = kv_state->get_ubatch();
 
        // count the outputs in this u_batch
        {
@@ -969,33 +1011,37 @@ int llama_context::decode(llama_batch & inp_batch) {
            n_outputs = n_outputs_new;
        }
 
-        // find KV slot
-        if (!kv_self->find_slot(ubatch)) {
-            return 1;
-        }
-
        ggml_backend_sched_reset(sched.get());
        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-        auto * gf = graph_init();
-        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
+        ggml_status status;
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
+
+        if (!res) {
+            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+            llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
 
-        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                const auto & seq_id = ubatch.seq_id[i][0];
 
-        ggml_backend_sched_alloc_graph(sched.get(), gf);
+                pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
+            }
 
-        res->set_inputs(&ubatch);
+            for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+                if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
+                    continue;
+                }
 
-        const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
-        if (compute_status != GGML_STATUS_SUCCESS) {
-            switch (compute_status) {
-                case GGML_STATUS_ABORTED:
-                    return 2;
-                case GGML_STATUS_ALLOC_FAILED:
-                    return -2;
-                case GGML_STATUS_FAILED:
-                default:
-                    return -3;
+                LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
+
+                llama_kv_self_seq_rm(this, s, pos_min[s], -1);
+            }
+
+            switch (status) {
+                case GGML_STATUS_ABORTED: return 2;
+                case GGML_STATUS_ALLOC_FAILED: return -2;
+                case GGML_STATUS_FAILED: return -3;
+                case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
            }
        }
 
@@ -1082,10 +1128,7 @@ int llama_context::decode(llama_batch & inp_batch) {
        }
 
        n_outputs_prev += n_outputs;
-    }
-
-    // finalize the batch processing
-    kv_guard.commit();
+    } while (kv_state->next());
 
    // set to total number of outputs in the batch, for use in llama_get_logits_ith
    n_outputs = n_outputs_all;
@@ -1094,7 +1137,7 @@ int llama_context::decode(llama_batch & inp_batch) {
    {
        bool sorted_output = true;
 
-        auto & out_ids = sbatch.out_ids;
+        auto & out_ids = kv_state->out_ids();
 
        GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
@@ -1254,11 +1297,52 @@ ggml_cgraph * llama_context::graph_init() {
     return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
 }
 
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
+    LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
+
+    if (n_tokens % n_seqs != 0) {
+        n_tokens = (n_tokens / n_seqs) * n_seqs;
+        n_outputs = std::min(n_outputs, n_tokens);
+
+        LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
+    }
+
+    // store the n_outputs as it is, and restore it afterwards
+    // TODO: not sure if needed, might simplify in the future by removing this
+    const auto save_n_outputs = this->n_outputs;
+
+    this->n_outputs = n_outputs;
+
+    llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+    llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+    auto * gf = graph_init();
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
+
+    this->n_outputs = save_n_outputs;
+
+    if (!res) {
+        LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
+        return nullptr;
+    }
+
+    ggml_backend_sched_reset(sched.get());
+
+    // initialize scheduler with the specified graph
+    if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+        return nullptr;
+    }
+
+    return gf;
+}
+
 llm_graph_result_ptr llama_context::graph_build(
-        ggml_context * ctx,
-        ggml_cgraph * gf,
-        const llama_ubatch & ubatch,
-        llm_graph_type gtype) {
+        ggml_context * ctx,
+        ggml_cgraph * gf,
+        const llama_ubatch & ubatch,
+        llm_graph_type gtype,
+        const llama_memory_state_i * mstate) {
     return model.build_graph(
         {
             /*.ctx =*/ ctx,
@@ -1270,7 +1354,7 @@ llm_graph_result_ptr llama_context::graph_build(
             /*.backend_cpu =*/ backend_cpu,
             /*.cvec =*/ &cvec,
             /*.loras =*/ &loras,
-            /*.memory =*/ memory.get(),
+            /*.mstate =*/ mstate,
             /*.cross =*/ &cross,
             /*.n_outputs =*/ n_outputs,
             /*.cb =*/ graph_get_cb(),
@@ -1951,7 +2035,6 @@ void llama_context::opt_epoch_iter(
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
     kv_self->clear();
-    llama_kv_cache_guard kv_guard(kv_self);
 
     for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
         batch.n_tokens = n_batch;
@@ -1974,7 +2057,11 @@ void llama_context::opt_epoch_iter(
 
         int64_t n_outputs_all = n_tokens_all;
 
-        llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
+        auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+            LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
+            break;
+        }
 
        // reserve output buffer
        if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -1982,20 +2069,19 @@ void llama_context::opt_epoch_iter(
            GGML_ABORT("TODO: handle this error");
        };
 
-        for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
-            llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+        uint32_t pos_batch = 0;
+        do {
+            const auto & ubatch = kv_state->get_ubatch();
 
            n_outputs = ubatch.n_tokens;
 
-            // TODO: not sure if this is needed
-            if (!kv_self->find_slot(ubatch)) {
-                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
-                GGML_ABORT("TODO: handle this error");
+            if (!kv_state->apply()) {
+                LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
+                break;
            }
 
            auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
 
            struct ggml_context * ctx_compute_opt;
            {
@@ -2010,6 +2096,7 @@ void llama_context::opt_epoch_iter(
            }
            ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
            ggml_opt_alloc(opt_ctx, train);
+
            res->set_inputs(&ubatch);
            {
                struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
@@ -2027,10 +2114,10 @@ void llama_context::opt_epoch_iter(
                callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
            }
            ggml_free(ctx_compute_opt);
-        }
-    }
 
-    kv_guard.commit();
+            pos_batch += ubatch.n_tokens;
+        } while (kv_state->next());
+    }
 }
 
 void llama_context::opt_epoch(
@@ -2194,6 +2281,7 @@ llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
     return ctx->get_kv_self();
 }
 
+// deprecated
 void llama_kv_self_update(llama_context * ctx) {
     ctx->kv_self_update();
 }
@@ -2448,6 +2536,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
     return kv->seq_pos_max(seq_id);
 }
 
+// deprecated
 void llama_kv_self_defrag(llama_context * ctx) {
     auto * kv = ctx->get_kv_self();
     if (!kv) {
@@ -2589,22 +2678,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
         llama_batch batch) {
-    int ret = ctx->decode(batch);
-
-    // defrag and try again
-    // TODO: distinguish return code when we are sure that even after defrag there is no space available
-    if (ret == 1) {
-        llama_kv_self_defrag(ctx);
-        ret = ctx->decode(batch);
-
-        if (ret == 1) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
-            return ret;
-        }
-    }
-
-    if (ret != 0) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
 
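The decode() rework above replaces the explicit sbatch/ubatch_next loop with a memory-state object returned by init_batch(): the state yields one ubatch at a time, apply() commits it to the KV cache, and next() advances until the batch is exhausted. Below is a minimal standalone sketch of that iteration pattern, using hypothetical toy types rather than the llama.cpp API.

// memory_state_loop.cpp - sketch of the init_batch()/get_ubatch()/apply()/next() flow.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <memory>
#include <vector>

struct ubatch {
    int n_tokens;
};

struct memory_state {
    std::vector<ubatch> ubatches;
    std::size_t cur = 0;

    const ubatch & get_ubatch() const { return ubatches[cur]; }

    // in llama.cpp this is where KV-cache slots are reserved for the current ubatch
    bool apply() { return true; }

    // advance to the next ubatch; false once the batch is exhausted
    bool next() { return ++cur < ubatches.size(); }
};

// split a batch of n_tokens into ubatches of at most n_ubatch tokens
static std::unique_ptr<memory_state> init_batch(int n_tokens, int n_ubatch) {
    auto st = std::make_unique<memory_state>();
    for (int i = 0; i < n_tokens; i += n_ubatch) {
        st->ubatches.push_back({ std::min(n_ubatch, n_tokens - i) });
    }
    return st;
}

int main() {
    const auto st = init_batch(/*n_tokens =*/ 10, /*n_ubatch =*/ 4);
    do {
        const auto & ub = st->get_ubatch();
        if (!st->apply()) {
            break; // analogous to bailing out when the memory state cannot be applied
        }
        std::printf("process ubatch of %d tokens\n", ub.n_tokens);
    } while (st->next());
    return 0;
}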
examples/talk-llama/llama-context.h CHANGED
@@ -18,6 +18,9 @@ struct llama_kv_cache;
 class llama_io_read_i;
 class llama_io_write_i;
 
+class llama_memory_i;
+class llama_memory_state_i;
+
 struct llama_context {
     // init scheduler and compute buffers, reserve worst-case graphs
     llama_context(
@@ -47,7 +50,9 @@ struct llama_context {
     llama_kv_cache * get_kv_self();
     const llama_kv_cache * get_kv_self() const;
 
-    void kv_self_update();
+    // return true of the KV cache was updated
+    // TODO: remove
+    bool kv_self_update();
 
     enum llama_pooling_type pooling_type() const;
 
@@ -88,6 +93,16 @@ struct llama_context {
             int32_t il_start,
             int32_t il_end);
 
+    // process a single ubatch with a specific graph type
+    // if memory_state is provided, it will be applied first to the context's memory
+    // ret contains the status of the graph computation
+    // returns nullptr only if ret != GGML_STATUS_SUCCESS
+    llm_graph_result_ptr process_ubatch(
+            const llama_ubatch & ubatch,
+            llm_graph_type gtype,
+            llama_memory_state_i * mstate,
+            ggml_status & ret);
+
     int encode(llama_batch & inp_batch);
     int decode(llama_batch & inp_batch);
 
@@ -180,16 +195,18 @@ public:
     ggml_cgraph * graph_init();
 
     // returns the result of ggml_backend_sched_graph_compute_async execution
-    ggml_status graph_compute(
-            ggml_cgraph * gf,
-            bool batched);
+    ggml_status graph_compute(ggml_cgraph * gf, bool batched);
+
+    // reserve a graph with a dummy ubatch of the specified size
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
 
 private:
     llm_graph_result_ptr graph_build(
-            ggml_context * ctx,
-            ggml_cgraph * gf,
-            const llama_ubatch & ubatch,
-            llm_graph_type gtype);
+            ggml_context * ctx,
+            ggml_cgraph * gf,
+            const llama_ubatch & ubatch,
+            llm_graph_type gtype,
+            const llama_memory_state_i * mstate);
 
     llm_graph_cb graph_get_cb() const;
 
examples/talk-llama/llama-graph.cpp CHANGED
@@ -3,7 +3,10 @@
3
  #include "llama-impl.h"
4
  #include "llama-batch.h"
5
  #include "llama-cparams.h"
6
- #include "llama-kv-cache.h"
 
 
 
7
 
8
  #include <cassert>
9
  #include <cmath>
@@ -83,7 +86,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
83
 
84
  void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
85
  if (pos_bucket) {
86
- kv_self->set_input_pos_bucket(pos_bucket, ubatch);
87
  }
88
  }
89
 
@@ -234,7 +237,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
234
  void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
235
  GGML_UNUSED(ubatch);
236
 
237
- const int64_t n_kv = kv_self->n;
238
 
239
  if (s_copy) {
240
  GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@@ -242,7 +245,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
242
 
243
  // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
244
  for (uint32_t i = 0; i < n_kv; ++i) {
245
- data[i] = kv_self->s_copy(i);
246
  }
247
  }
248
  }
@@ -250,7 +253,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
250
  void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
251
  GGML_UNUSED(ubatch);
252
 
253
- const int64_t n_kv = kv_self->n;
254
 
255
  if (s_mask) {
256
  GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
@@ -258,7 +261,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
258
 
259
  // clear unused states
260
  for (int i = 0; i < n_kv; ++i) {
261
- data[i] = kv_self->s_mask(i);
262
  }
263
  }
264
  }
@@ -362,17 +365,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
362
 
363
  void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
364
  if (self_kq_mask) {
365
- kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
366
  }
367
  }
368
 
369
  void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
370
  if (self_kq_mask) {
371
- kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
372
  }
373
 
374
  if (self_kq_mask_swa) {
375
- kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
376
  }
377
  }
378
 
@@ -448,14 +451,14 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
448
  backend_cpu (params.backend_cpu),
449
  cvec (params.cvec),
450
  loras (params.loras),
451
- memory (params.memory),
452
  cross (params.cross),
453
  cb_func (params.cb),
454
  res (std::make_unique<llm_graph_result>()) {
455
  }
456
 
457
  int64_t llm_graph_context::n_pos_per_embd() const {
458
- return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
459
  }
460
 
461
  void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
@@ -954,11 +957,11 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
954
  }
955
 
956
  ggml_tensor * llm_graph_context::build_inp_s_copy() const {
957
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
958
 
959
- auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
960
 
961
- const auto n_kv = kv_self->n;
962
 
963
  auto & cur = inp->s_copy;
964
 
@@ -971,11 +974,11 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
971
  }
972
 
973
  ggml_tensor * llm_graph_context::build_inp_s_mask() const {
974
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
975
 
976
- auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
977
 
978
- const auto n_kv = kv_self->n;
979
 
980
  auto & cur = inp->s_mask;
981
 
@@ -1025,11 +1028,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
1025
  }
1026
 
1027
  ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
1028
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1029
 
1030
- auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
1031
 
1032
- const auto n_kv = kv_self->get_n();
1033
 
1034
  auto & cur = inp->pos_bucket;
1035
 
@@ -1231,14 +1234,14 @@ ggml_tensor * llm_graph_context::build_attn(
1231
  }
1232
 
1233
  llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
1234
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1235
 
1236
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
1237
 
1238
  {
1239
  GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
1240
 
1241
- const auto n_kv = kv_self->get_n();
1242
 
1243
  inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1244
  //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1268,19 +1271,19 @@ ggml_tensor * llm_graph_context::build_attn(
1268
  ggml_build_forward_expand(gf, k_cur);
1269
  ggml_build_forward_expand(gf, v_cur);
1270
 
1271
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1272
 
1273
  // store to KV cache
1274
  {
1275
- ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
1276
- ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
1277
  }
1278
 
1279
  const auto & kq_mask = inp->get_kq_mask();
1280
 
1281
  ggml_tensor * q = q_cur;
1282
- ggml_tensor * k = kv_self->get_k(ctx0, il);
1283
- ggml_tensor * v = kv_self->get_v(ctx0, il);
1284
 
1285
  ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1286
  cb(cur, "kqv_out", il);
@@ -1301,12 +1304,12 @@ ggml_tensor * llm_graph_context::build_attn(
1301
  }
1302
 
1303
  llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1304
- const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
1305
 
1306
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
1307
 
1308
  {
1309
- const auto n_kv = kv_self->get_kv_base()->get_n();
1310
 
1311
  inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1312
  //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1318,7 +1321,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1318
  {
1319
  GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
1320
 
1321
- const auto n_kv = kv_self->get_kv_swa()->get_n();
1322
 
1323
  inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1324
  //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@@ -1348,23 +1351,23 @@ ggml_tensor * llm_graph_context::build_attn(
1348
  ggml_build_forward_expand(gf, k_cur);
1349
  ggml_build_forward_expand(gf, v_cur);
1350
 
1351
- const bool is_swa = hparams.is_swa(il);
1352
 
1353
- const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
1354
 
1355
- const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
1356
 
1357
  // store to KV cache
1358
  {
1359
- ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
1360
- ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
1361
  }
1362
 
1363
  const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1364
 
1365
  ggml_tensor * q = q_cur;
1366
- ggml_tensor * k = kv->get_k(ctx0, il);
1367
- ggml_tensor * v = kv->get_v(ctx0, il);
1368
 
1369
  ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1370
  cb(cur, "kqv_out", il);
@@ -1446,12 +1449,12 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
1446
  ggml_tensor * state_mask,
1447
  int32_t n_state,
1448
  int32_t n_seqs) const {
1449
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
1450
 
1451
- const auto n_kv = kv_self->n;
1452
- const auto kv_head = kv_self->head;
1453
 
1454
- ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
1455
 
1456
  // copy states
1457
  // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
@@ -1478,13 +1481,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1478
  ggml_tensor * state_mask,
1479
  const llama_ubatch & ubatch,
1480
  int il) const {
1481
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
1482
 
1483
  const auto token_shift_count = hparams.token_shift_count;
1484
 
1485
  const int64_t n_seqs = ubatch.n_seqs;
1486
 
1487
- ggml_tensor * token_shift_all = kv_self->k_l[il];
1488
 
1489
  ggml_tensor * token_shift = build_copy_mask_state(
1490
  gf, token_shift_all, state_copy, state_mask,
@@ -1499,19 +1502,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
1499
  ggml_tensor * token_shift,
1500
  const llama_ubatch & ubatch,
1501
  int il) const {
1502
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
1503
 
1504
  const auto token_shift_count = hparams.token_shift_count;
1505
  const auto n_embd = hparams.n_embd;
1506
 
1507
  const int64_t n_seqs = ubatch.n_seqs;
1508
 
1509
- const auto kv_head = kv_self->head;
1510
 
1511
  return ggml_cpy(
1512
  ctx0,
1513
  ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
1514
- ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
1515
  );
1516
  }
1517
 
@@ -1562,20 +1565,25 @@ void llm_graph_context::build_pooling(
1562
  ggml_tensor * inp_cls = build_inp_cls();
1563
  inp = ggml_get_rows(ctx0, inp, inp_cls);
1564
 
1565
- // classification head
1566
- // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
1567
- GGML_ASSERT(cls != nullptr);
1568
- GGML_ASSERT(cls_b != nullptr);
1569
-
1570
- cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
1571
- cur = ggml_tanh(ctx0, cur);
1572
-
1573
- // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
1574
- // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
1575
- if (cls_out) {
 
 
 
 
1576
  GGML_ASSERT(cls_out_b != nullptr);
1577
-
1578
- cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
 
1579
  }
1580
  } break;
1581
  default:
 
3
  #include "llama-impl.h"
4
  #include "llama-batch.h"
5
  #include "llama-cparams.h"
6
+
7
+ #include "llama-kv-cache-unified.h"
8
+ #include "llama-kv-cache-unified-iswa.h"
9
+ #include "llama-kv-cache-recurrent.h"
10
 
11
  #include <cassert>
12
  #include <cmath>
 
86
 
87
  void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
88
  if (pos_bucket) {
89
+ kv_state->set_input_pos_bucket(pos_bucket, ubatch);
90
  }
91
  }
92
 
 
237
  void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
238
  GGML_UNUSED(ubatch);
239
 
240
+ const int64_t n_kv = kv_state->get_n_kv();
241
 
242
  if (s_copy) {
243
  GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
 
245
 
246
  // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
247
  for (uint32_t i = 0; i < n_kv; ++i) {
248
+ data[i] = kv_state->s_copy(i);
249
  }
250
  }
251
  }
 
253
  void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
254
  GGML_UNUSED(ubatch);
255
 
256
+ const int64_t n_kv = kv_state->get_n_kv();
257
 
258
  if (s_mask) {
259
  GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
 
261
 
262
  // clear unused states
263
  for (int i = 0; i < n_kv; ++i) {
264
+ data[i] = kv_state->s_mask(i);
265
  }
266
  }
267
  }
 
365
 
366
  void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
367
  if (self_kq_mask) {
368
+ kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
369
  }
370
  }
371
 
372
  void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
373
  if (self_kq_mask) {
374
+ kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
375
  }
376
 
377
  if (self_kq_mask_swa) {
378
+ kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
379
  }
380
  }
381
 
 
451
  backend_cpu (params.backend_cpu),
452
  cvec (params.cvec),
453
  loras (params.loras),
454
+ mstate (params.mstate),
455
  cross (params.cross),
456
  cb_func (params.cb),
457
  res (std::make_unique<llm_graph_result>()) {
458
  }
459
 
460
  int64_t llm_graph_context::n_pos_per_embd() const {
461
+ return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
462
  }
463
 
464
  void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
 
957
  }
958
 
959
  ggml_tensor * llm_graph_context::build_inp_s_copy() const {
960
+ const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
961
 
962
+ auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
963
 
964
+ const auto n_kv = kv_state->get_n_kv();
965
 
966
  auto & cur = inp->s_copy;
967
 
 
974
  }
975
 
976
  ggml_tensor * llm_graph_context::build_inp_s_mask() const {
977
+ const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
978
 
979
+ auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
980
 
981
+ const auto n_kv = kv_state->get_n_kv();
982
 
983
  auto & cur = inp->s_mask;
984
 
 
1028
  }
1029
 
1030
  ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
1031
+ const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
1032
 
1033
+ auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
1034
 
1035
+ const auto n_kv = kv_state->get_n_kv();
1036
 
1037
  auto & cur = inp->pos_bucket;
1038
 
 
1234
  }
1235
 
1236
  llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
1237
+ const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
1238
 
1239
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
1240
 
1241
  {
1242
  GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
1243
 
1244
+ const auto n_kv = kv_state->get_n_kv();
1245
 
1246
  inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1247
  //cb(inp->self_kq_mask, "KQ_mask", -1);
 
1271
  ggml_build_forward_expand(gf, k_cur);
1272
  ggml_build_forward_expand(gf, v_cur);
1273
 
1274
+ const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
1275
 
1276
  // store to KV cache
1277
  {
1278
+ ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1279
+ ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1280
  }
1281
 
1282
  const auto & kq_mask = inp->get_kq_mask();
1283
 
1284
  ggml_tensor * q = q_cur;
1285
+ ggml_tensor * k = kv_state->get_k(ctx0, il);
1286
+ ggml_tensor * v = kv_state->get_v(ctx0, il);
1287
 
1288
  ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1289
  cb(cur, "kqv_out", il);
 
1304
  }
1305
 
1306
  llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1307
+ const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
1308
 
1309
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
1310
 
1311
  {
1312
+ const auto n_kv = kv_state->get_base()->get_n_kv();
1313
 
1314
  inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1315
  //cb(inp->self_kq_mask, "KQ_mask", -1);
 
1321
  {
1322
  GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
1323
 
1324
+ const auto n_kv = kv_state->get_swa()->get_n_kv();
1325
 
1326
  inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1327
  //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
 
1351
  ggml_build_forward_expand(gf, k_cur);
1352
  ggml_build_forward_expand(gf, v_cur);
1353
 
1354
+ const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
1355
 
1356
+ const bool is_swa = hparams.is_swa(il);
1357
 
1358
+ const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
1359
 
1360
  // store to KV cache
1361
  {
1362
+ ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1363
+ ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1364
  }
1365
 
1366
  const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1367
 
1368
  ggml_tensor * q = q_cur;
1369
+ ggml_tensor * k = kv_state->get_k(ctx0, il);
1370
+ ggml_tensor * v = kv_state->get_v(ctx0, il);
1371
 
1372
  ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1373
  cb(cur, "kqv_out", il);
 
1449
  ggml_tensor * state_mask,
1450
  int32_t n_state,
1451
  int32_t n_seqs) const {
1452
+ const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1453
 
1454
+ const auto n_kv = kv_state->get_n_kv();
1455
+ const auto kv_head = kv_state->get_head();
1456
 
1457
+ ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
1458
 
1459
  // copy states
1460
  // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
 
1481
  ggml_tensor * state_mask,
1482
  const llama_ubatch & ubatch,
1483
  int il) const {
1484
+ const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1485
 
1486
  const auto token_shift_count = hparams.token_shift_count;
1487
 
1488
  const int64_t n_seqs = ubatch.n_seqs;
1489
 
1490
+ ggml_tensor * token_shift_all = kv_state->get_k_l(il);
1491
 
1492
  ggml_tensor * token_shift = build_copy_mask_state(
1493
  gf, token_shift_all, state_copy, state_mask,
 
1502
  ggml_tensor * token_shift,
1503
  const llama_ubatch & ubatch,
1504
  int il) const {
1505
+ const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1506
 
1507
  const auto token_shift_count = hparams.token_shift_count;
1508
  const auto n_embd = hparams.n_embd;
1509
 
1510
  const int64_t n_seqs = ubatch.n_seqs;
1511
 
1512
+ const auto kv_head = kv_state->get_head();
1513
 
1514
  return ggml_cpy(
1515
  ctx0,
1516
  ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
1517
+ ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
1518
  );
1519
  }
1520
 
 
1565
  ggml_tensor * inp_cls = build_inp_cls();
1566
  inp = ggml_get_rows(ctx0, inp, inp_cls);
1567
 
1568
+ if (cls != nullptr && cls_b != nullptr) {
1569
+ // classification head
1570
+ // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
1571
+ cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
1572
+ cur = ggml_tanh(ctx0, cur);
1573
+
1574
+ // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
1575
+ // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
1576
+ if (cls_out) {
1577
+ GGML_ASSERT(cls_out_b != nullptr);
1578
+ cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
1579
+ }
1580
+ } else if (cls_out) {
1581
+ // Single layer classification head (direct projection)
1582
+ // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
1583
  GGML_ASSERT(cls_out_b != nullptr);
1584
+ cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
1585
+ } else {
1586
+ GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
1587
  }
1588
  } break;
1589
  default:
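For readers skimming the hunk above: the RANK-pooling classification head reduces to tanh(cls·x + cls_b), optionally followed by a linear projection through cls_out/cls_out_b, or a single direct projection when only cls_out is present. Below is a small standalone C++ sketch of that arithmetic using plain std::vector math; the helper names and toy weights are invented for illustration and are not code from this commit or from llama.cpp.

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

using vec = std::vector<float>;
using mat = std::vector<vec>;   // row-major: mat[r][c]

// y = W*x + b
static vec affine(const mat & W, const vec & x, const vec & b) {
    vec y(W.size());
    for (size_t r = 0; r < W.size(); ++r) {
        float acc = b[r];
        for (size_t c = 0; c < x.size(); ++c) {
            acc += W[r][c]*x[c];
        }
        y[r] = acc;
    }
    return y;
}

// mirrors the three branches of the RANK pooling head in the diff above
static vec rank_head(const vec & inp,
                     const mat * cls,     const vec * cls_b,
                     const mat * cls_out, const vec * cls_out_b) {
    if (cls && cls_b) {
        vec cur = affine(*cls, inp, *cls_b);
        for (float & v : cur) {
            v = std::tanh(v);          // classification head: tanh(cls*x + cls_b)
        }
        if (cls_out) {
            assert(cls_out_b);
            cur = affine(*cls_out, cur, *cls_out_b);   // optional output projection
        }
        return cur;
    }
    if (cls_out) {
        assert(cls_out_b);
        return affine(*cls_out, inp, *cls_out_b);      // single-layer head (direct projection)
    }
    assert(false && "RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
    return {};
}

int main() {
    const mat cls       = {{0.5f, -0.25f}, {0.1f, 0.3f}};
    const vec cls_b     = {0.0f, 0.1f};
    const mat cls_out   = {{1.0f, -1.0f}};
    const vec cls_out_b = {0.05f};

    const vec score = rank_head({0.2f, 0.7f}, &cls, &cls_b, &cls_out, &cls_out_b);
    std::printf("score = %f\n", score[0]);
}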
examples/talk-llama/llama-graph.h CHANGED
@@ -17,10 +17,11 @@ struct ggml_tensor;
17
  struct llama_ubatch;
18
  struct llama_cparams;
19
 
20
- class llama_memory_i;
21
- class llama_kv_cache_unified;
22
- class llama_kv_cache_unified_iswa;
23
- class llama_kv_cache_recurrent;
 
24
 
25
  // certain models (typically multi-modal) can produce different types of graphs
26
  enum llm_graph_type {
@@ -133,7 +134,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
133
  public:
134
  llm_graph_input_pos_bucket_kv(
135
  const llama_hparams & hparams,
136
- const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
137
  virtual ~llm_graph_input_pos_bucket_kv() = default;
138
 
139
  void set_input(const llama_ubatch * ubatch) override;
@@ -141,7 +142,7 @@ public:
141
  ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
142
 
143
  const llama_hparams & hparams;
144
- const llama_kv_cache_unified * kv_self;
145
  };
146
 
147
  class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -188,26 +189,26 @@ public:
188
 
189
  class llm_graph_input_s_copy : public llm_graph_input_i {
190
  public:
191
- llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
192
  virtual ~llm_graph_input_s_copy() = default;
193
 
194
  void set_input(const llama_ubatch * ubatch) override;
195
 
196
  ggml_tensor * s_copy; // I32 [kv_size]
197
 
198
- const llama_kv_cache_recurrent * kv_self;
199
  };
200
 
201
  class llm_graph_input_s_mask : public llm_graph_input_i {
202
  public:
203
- llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
204
  virtual ~llm_graph_input_s_mask() = default;
205
 
206
  void set_input(const llama_ubatch * ubatch) override;
207
 
208
  ggml_tensor * s_mask; // F32 [1, n_kv]
209
 
210
- const llama_kv_cache_recurrent * kv_self;
211
  };
212
 
213
  class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -247,10 +248,10 @@ public:
247
  llm_graph_input_attn_kv_unified(
248
  const llama_hparams & hparams,
249
  const llama_cparams & cparams,
250
- const llama_kv_cache_unified * kv_self) :
251
  hparams(hparams),
252
  cparams(cparams),
253
- kv_self(kv_self) {
254
  }
255
  ~llm_graph_input_attn_kv_unified() = default;
256
 
@@ -264,7 +265,7 @@ public:
264
  const llama_hparams & hparams;
265
  const llama_cparams & cparams;
266
 
267
- const llama_kv_cache_unified * kv_self;
268
  };
269
 
270
  class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@@ -272,10 +273,10 @@ public:
272
  llm_graph_input_attn_kv_unified_iswa(
273
  const llama_hparams & hparams,
274
  const llama_cparams & cparams,
275
- const llama_kv_cache_unified_iswa * kv_self) :
276
  hparams(hparams),
277
  cparams(cparams),
278
- kv_self(kv_self) {
279
  }
280
  ~llm_graph_input_attn_kv_unified_iswa() = default;
281
 
@@ -292,7 +293,7 @@ public:
292
  const llama_hparams & hparams;
293
  const llama_cparams & cparams;
294
 
295
- const llama_kv_cache_unified_iswa * kv_self;
296
  };
297
 
298
  class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -383,10 +384,10 @@ struct llm_graph_params {
383
  ggml_backend_sched_t sched;
384
  ggml_backend_t backend_cpu;
385
 
386
- const llama_adapter_cvec * cvec;
387
- const llama_adapter_loras * loras;
388
- const llama_memory_i * memory;
389
- const llama_cross * cross;
390
 
391
  int32_t n_outputs;
392
 
@@ -435,10 +436,10 @@ struct llm_graph_context {
435
 
436
  ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
437
 
438
- const llama_adapter_cvec * cvec;
439
- const llama_adapter_loras * loras;
440
- const llama_memory_i * memory;
441
- const llama_cross * cross;
442
 
443
  const llm_graph_cb & cb_func;
444
 
 
17
  struct llama_ubatch;
18
  struct llama_cparams;
19
 
20
+ class llama_memory_state_i;
21
+
22
+ class llama_kv_cache_unified_state;
23
+ class llama_kv_cache_unified_iswa_state;
24
+ class llama_kv_cache_recurrent_state;
25
 
26
  // certain models (typically multi-modal) can produce different types of graphs
27
  enum llm_graph_type {
 
134
  public:
135
  llm_graph_input_pos_bucket_kv(
136
  const llama_hparams & hparams,
137
+ const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
138
  virtual ~llm_graph_input_pos_bucket_kv() = default;
139
 
140
  void set_input(const llama_ubatch * ubatch) override;
 
142
  ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
143
 
144
  const llama_hparams & hparams;
145
+ const llama_kv_cache_unified_state * kv_state;
146
  };
147
 
148
  class llm_graph_input_out_ids : public llm_graph_input_i {
 
189
 
190
  class llm_graph_input_s_copy : public llm_graph_input_i {
191
  public:
192
+ llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
193
  virtual ~llm_graph_input_s_copy() = default;
194
 
195
  void set_input(const llama_ubatch * ubatch) override;
196
 
197
  ggml_tensor * s_copy; // I32 [kv_size]
198
 
199
+ const llama_kv_cache_recurrent_state * kv_state;
200
  };
201
 
202
  class llm_graph_input_s_mask : public llm_graph_input_i {
203
  public:
204
+ llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
205
  virtual ~llm_graph_input_s_mask() = default;
206
 
207
  void set_input(const llama_ubatch * ubatch) override;
208
 
209
  ggml_tensor * s_mask; // F32 [1, n_kv]
210
 
211
+ const llama_kv_cache_recurrent_state * kv_state;
212
  };
213
 
214
  class llm_graph_input_cross_embd : public llm_graph_input_i {
 
248
  llm_graph_input_attn_kv_unified(
249
  const llama_hparams & hparams,
250
  const llama_cparams & cparams,
251
+ const llama_kv_cache_unified_state * kv_state) :
252
  hparams(hparams),
253
  cparams(cparams),
254
+ kv_state(kv_state) {
255
  }
256
  ~llm_graph_input_attn_kv_unified() = default;
257
 
 
265
  const llama_hparams & hparams;
266
  const llama_cparams & cparams;
267
 
268
+ const llama_kv_cache_unified_state * kv_state;
269
  };
270
 
271
  class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
 
273
  llm_graph_input_attn_kv_unified_iswa(
274
  const llama_hparams & hparams,
275
  const llama_cparams & cparams,
276
+ const llama_kv_cache_unified_iswa_state * kv_state) :
277
  hparams(hparams),
278
  cparams(cparams),
279
+ kv_state(kv_state) {
280
  }
281
  ~llm_graph_input_attn_kv_unified_iswa() = default;
282
 
 
293
  const llama_hparams & hparams;
294
  const llama_cparams & cparams;
295
 
296
+ const llama_kv_cache_unified_iswa_state * kv_state;
297
  };
298
 
299
  class llm_graph_input_attn_cross : public llm_graph_input_i {
 
384
  ggml_backend_sched_t sched;
385
  ggml_backend_t backend_cpu;
386
 
387
+ const llama_adapter_cvec * cvec;
388
+ const llama_adapter_loras * loras;
389
+ const llama_memory_state_i * mstate;
390
+ const llama_cross * cross;
391
 
392
  int32_t n_outputs;
393
 
 
436
 
437
  ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
438
 
439
+ const llama_adapter_cvec * cvec;
440
+ const llama_adapter_loras * loras;
441
+ const llama_memory_state_i * mstate;
442
+ const llama_cross * cross;
443
 
444
  const llm_graph_cb & cb_func;
445
 
examples/talk-llama/llama-hparams.h CHANGED
@@ -131,6 +131,9 @@ struct llama_hparams {
131
  bool attn_soft_cap = false;
132
  bool use_kq_norm = true;
133
 
 
 
 
134
  // llama4
135
  uint32_t n_moe_layer_step = 0;
136
  uint32_t n_no_rope_layer_step = 4;
 
131
  bool attn_soft_cap = false;
132
  bool use_kq_norm = true;
133
 
134
+ // for Classifiers
135
+ uint32_t n_cls_out = 1;
136
+
137
  // llama4
138
  uint32_t n_moe_layer_step = 0;
139
  uint32_t n_no_rope_layer_step = 4;
examples/talk-llama/llama-kv-cache-recurrent.cpp ADDED
@@ -0,0 +1,1132 @@
1
+ #include "llama-kv-cache-recurrent.h"
2
+
3
+ #include "llama-impl.h"
4
+ #include "llama-batch.h"
5
+ #include "llama-model.h"
6
+
7
+ #include <algorithm>
8
+ #include <cassert>
9
+ #include <limits>
10
+ #include <map>
11
+ #include <stdexcept>
12
+
13
+ //
14
+ // llama_kv_cache_recurrent
15
+ //
16
+
17
+ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
18
+ const llama_model & model,
19
+ ggml_type type_k,
20
+ ggml_type type_v,
21
+ bool offload,
22
+ uint32_t kv_size,
23
+ uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
24
+ const int32_t n_layer = hparams.n_layer;
25
+
26
+ LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
27
+ __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
28
+
29
+ head = 0;
30
+ size = kv_size;
31
+ used = 0;
32
+
33
+ cells.clear();
34
+ cells.resize(kv_size);
35
+
36
+ // create a context for each buffer type
37
+ std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
38
+ auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
39
+ auto it = ctx_map.find(buft);
40
+ if (it == ctx_map.end()) {
41
+ ggml_init_params params = {
42
+ /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
43
+ /*.mem_buffer =*/ NULL,
44
+ /*.no_alloc =*/ true,
45
+ };
46
+
47
+ ggml_context * ctx = ggml_init(params);
48
+ if (!ctx) {
49
+ return nullptr;
50
+ }
51
+
52
+ ctx_map[buft] = ctx;
53
+ ctxs.emplace_back(ctx);
54
+
55
+ return ctx;
56
+ }
57
+
58
+ return it->second;
59
+ };
60
+
61
+ k_l.reserve(n_layer);
62
+ v_l.reserve(n_layer);
63
+
64
+ for (int i = 0; i < n_layer; i++) {
65
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
66
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
67
+
68
+ const char * dev_name = "CPU";
69
+
70
+ ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
71
+
72
+ if (offload) {
73
+ auto * dev = model.dev_layer(i);
74
+ buft = ggml_backend_dev_buffer_type(dev);
75
+
76
+ dev_name = ggml_backend_dev_name(dev);
77
+ }
78
+
79
+ LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name);
80
+
81
+ ggml_context * ctx = ctx_for_buft(buft);
82
+ if (!ctx) {
83
+ throw std::runtime_error("failed to create ggml context for kv cache");
84
+ }
85
+
86
+ ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
87
+ ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
88
+ ggml_format_name(k, "cache_k_l%d", i);
89
+ ggml_format_name(v, "cache_v_l%d", i);
90
+ k_l.push_back(k);
91
+ v_l.push_back(v);
92
+ }
93
+
94
+ // allocate tensors and initialize the buffers to avoid NaNs in the padding
95
+ for (auto it : ctx_map) {
96
+ auto * buft = it.first;
97
+ auto * ctx = it.second;
98
+
99
+ ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
100
+ if (!buf) {
101
+ throw std::runtime_error("failed to allocate buffer for kv cache");
102
+ }
103
+ ggml_backend_buffer_clear(buf, 0);
104
+ LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
105
+ bufs.emplace_back(buf);
106
+ }
107
+
108
+ {
109
+ const size_t memory_size_k = size_k_bytes();
110
+ const size_t memory_size_v = size_v_bytes();
111
+
112
+ LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
113
+ (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
114
+ ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
115
+ ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
116
+ }
117
+ }
118
+
119
+ void llama_kv_cache_recurrent::clear() {
120
+ for (int32_t i = 0; i < (int32_t) size; ++i) {
121
+ cells[i].pos = -1;
122
+ cells[i].seq_id.clear();
123
+ cells[i].src = -1;
124
+ cells[i].tail = -1;
125
+ }
126
+ head = 0;
127
+ used = 0;
128
+
129
+ for (auto & buf : bufs) {
130
+ ggml_backend_buffer_clear(buf.get(), 0);
131
+ }
132
+ }
133
+
134
+ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
135
+ uint32_t new_head = size;
136
+
137
+ if (p0 < 0) {
138
+ p0 = 0;
139
+ }
140
+
141
+ if (p1 < 0) {
142
+ p1 = std::numeric_limits<llama_pos>::max();
143
+ }
144
+
145
+ // models like Mamba or RWKV can't have a state partially erased
146
+ if (seq_id >= (int64_t) size) {
147
+ // could be fatal
148
+ return false;
149
+ }
150
+ if (0 <= seq_id) {
151
+ int32_t & tail_id = cells[seq_id].tail;
152
+ if (tail_id >= 0) {
153
+ const kv_cell & cell = cells[tail_id];
154
+ // partial intersection is invalid
155
+ if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
156
+ return false;
157
+ }
158
+ // invalidate tails which will be cleared
159
+ if (p0 <= cell.pos && cell.pos < p1) {
160
+ tail_id = -1;
161
+ }
162
+ }
163
+ } else {
164
+ // when seq_id is negative, the range should include everything or nothing
165
+ if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
166
+ return false;
167
+ }
168
+ }
169
+
170
+ for (uint32_t i = 0; i < size; ++i) {
171
+ if (cells[i].pos >= p0 && cells[i].pos < p1) {
172
+ if (seq_id < 0) {
173
+ cells[i].seq_id.clear();
174
+ } else if (cells[i].has_seq_id(seq_id)) {
175
+ cells[i].seq_id.erase(seq_id);
176
+ } else {
177
+ continue;
178
+ }
179
+ if (cells[i].is_empty()) {
180
+ // keep count of the number of used cells
181
+ if (cells[i].pos >= 0) {
182
+ used--;
183
+ }
184
+ cells[i].pos = -1;
185
+ cells[i].src = -1;
186
+ if (new_head == size) {
187
+ new_head = i;
188
+ }
189
+ }
190
+ }
191
+ }
192
+
193
+ // If we freed up a slot, set head to it so searching can start there.
194
+ if (new_head != size && new_head < head) {
195
+ head = new_head;
196
+ }
197
+
198
+ return true;
199
+ }
200
+
201
+ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
202
+ if (seq_id_src == seq_id_dst) {
203
+ return;
204
+ }
205
+
206
+ if (p0 < 0) {
207
+ p0 = 0;
208
+ }
209
+
210
+ if (p1 < 0) {
211
+ p1 = std::numeric_limits<llama_pos>::max();
212
+ }
213
+
214
+ if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
215
+ kv_cell & tail_src = cells[seq_id_src];
216
+ kv_cell & tail_dst = cells[seq_id_dst];
217
+ if (tail_dst.tail >= 0) {
218
+ // clear destination seq_id if it wasn't empty
219
+ kv_cell & cell_dst = cells[tail_dst.tail];
220
+
221
+ cell_dst.seq_id.erase(seq_id_dst);
222
+ tail_dst.tail = -1;
223
+ if (cell_dst.seq_id.empty()) {
224
+ cell_dst.pos = -1;
225
+ cell_dst.src = -1;
226
+ used -= 1;
227
+ }
228
+ }
229
+ if (tail_src.tail >= 0) {
230
+ kv_cell & cell_src = cells[tail_src.tail];
231
+
232
+ cell_src.seq_id.insert(seq_id_dst);
233
+ tail_dst.tail = tail_src.tail;
234
+ }
235
+ }
236
+ }
237
+
238
+ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
239
+ uint32_t new_head = size;
240
+
241
+ for (uint32_t i = 0; i < size; ++i) {
242
+ if ((llama_seq_id) i != seq_id) {
243
+ cells[i].tail = -1;
244
+ }
245
+
246
+ if (!cells[i].has_seq_id(seq_id)) {
247
+ if (cells[i].pos >= 0) {
248
+ used--;
249
+ }
250
+
251
+ cells[i].pos = -1;
252
+ cells[i].src = -1;
253
+ cells[i].seq_id.clear();
254
+
255
+ if (new_head == size){
256
+ new_head = i;
257
+ }
258
+ } else {
259
+ cells[i].seq_id.clear();
260
+ cells[i].seq_id.insert(seq_id);
261
+ }
262
+ }
263
+
264
+ // If we freed up a slot, set head to it so searching can start there.
265
+ if (new_head != size && new_head < head) {
266
+ head = new_head;
267
+ }
268
+ }
269
+
270
+ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
271
+ if (shift == 0) {
272
+ return;
273
+ }
274
+
275
+ if (p0 < 0) {
276
+ p0 = 0;
277
+ }
278
+
279
+ if (p1 < 0) {
280
+ p1 = std::numeric_limits<llama_pos>::max();
281
+ }
282
+
283
+ // If there is no range then return early to avoid looping over the cache.
284
+ if (p0 == p1) {
285
+ return;
286
+ }
287
+
288
+ // for Mamba-like or RWKV models, only the pos needs to be shifted
289
+ if (0 <= seq_id && seq_id < (int64_t) size) {
290
+ const int32_t tail_id = cells[seq_id].tail;
291
+ if (tail_id >= 0) {
292
+ kv_cell & cell = cells[tail_id];
293
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
294
+ cell.pos += shift;
295
+ }
296
+ }
297
+ }
298
+ }
299
+
300
+ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
301
+ if (d == 1) {
302
+ return;
303
+ }
304
+
305
+ if (p0 < 0) {
306
+ p0 = 0;
307
+ }
308
+
309
+ if (p1 < 0) {
310
+ p1 = std::numeric_limits<llama_pos>::max();
311
+ }
312
+
313
+ // If there is no range then return early to avoid looping over the cache.
314
+ if (p0 == p1) {
315
+ return;
316
+ }
317
+
318
+ // for Mamba-like or RWKV models, only the pos needs to be changed
319
+ if (0 <= seq_id && seq_id < (int64_t) size) {
320
+ const int32_t tail_id = cells[seq_id].tail;
321
+ if (tail_id >= 0) {
322
+ kv_cell & cell = cells[tail_id];
323
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
324
+ cell.pos /= d;
325
+ }
326
+ }
327
+ }
328
+ }
329
+
330
+ llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
331
+ llama_pos result = std::numeric_limits<llama_pos>::max();
332
+
333
+ for (uint32_t i = 0; i < size; ++i) {
334
+ if (cells[i].has_seq_id(seq_id)) {
335
+ result = std::min(result, cells[i].pos);
336
+ }
337
+ }
338
+
339
+ if (result == std::numeric_limits<llama_pos>::max()) {
340
+ result = -1;
341
+ }
342
+
343
+ return result;
344
+ }
345
+
346
+ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
347
+ llama_pos result = -1;
348
+
349
+ for (uint32_t i = 0; i < size; ++i) {
350
+ if (cells[i].has_seq_id(seq_id)) {
351
+ result = std::max(result, cells[i].pos);
352
+ }
353
+ }
354
+
355
+ return result;
356
+ }
357
+
358
+ llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
359
+ GGML_UNUSED(embd_pooled);
360
+
361
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
362
+
363
+ std::vector<llama_ubatch> ubatches;
364
+
365
+ while (sbatch.n_tokens > 0) {
366
+ llama_ubatch ubatch;
367
+
368
+ if (embd_pooled) {
369
+ // Pooled embeddings cannot be split across ubatches (yet)
370
+ ubatch = sbatch.split_seq(n_ubatch);
371
+ } else {
372
+ ubatch = sbatch.split_equal(n_ubatch);
373
+ }
374
+
375
+ ubatches.push_back(ubatch);
376
+ }
377
+
378
+ if (!prepare(ubatches)) {
379
+ return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
380
+ }
381
+
382
+ return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches));
383
+ }
384
+
385
+ llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
386
+ return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
387
+ }
388
+
389
+ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
390
+ // simply remember the full state because it is very small for this type of cache
391
+ // TODO: optimize
392
+ auto org_cells = cells;
393
+ auto org_used = used;
394
+ auto org_head = head;
395
+
396
+ bool success = true;
397
+
398
+ // TODO: here we have to verify that all ubatches can fit in the cells
399
+ // however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
400
+ // during the compute of each ubatch. to reproduce, uncomment the following loop and run:
401
+ //
402
+ // $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
403
+ //
404
+ // recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
405
+ //
406
+ GGML_UNUSED(ubatches);
407
+ //for (const auto & ubatch : ubatches) {
408
+ // if (!find_slot(ubatch)) {
409
+ // success = false;
410
+ // break;
411
+ // }
412
+ //}
413
+
414
+ // restore the original state
415
+ cells = std::move(org_cells);
416
+ used = org_used;
417
+ head = org_head;
418
+
419
+ return success;
420
+ }
421
+
422
+ bool llama_kv_cache_recurrent::update(llama_context & lctx) {
423
+ GGML_UNUSED(lctx);
424
+ // noop
425
+ return false;
426
+ }
427
+
428
+ void llama_kv_cache_recurrent::defrag_sched(float thold) {
429
+ GGML_UNUSED(thold);
430
+ // noop
431
+ }
432
+
433
+ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
434
+ const uint32_t n_tokens = ubatch.n_tokens;
435
+ const uint32_t n_seqs = ubatch.n_seqs;
436
+
437
+ const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
438
+
439
+ // if we have enough unused cells before the current head ->
440
+ // better to start searching from the beginning of the cache, hoping to fill it
441
+ if (head > used + 2*n_tokens) {
442
+ head = 0;
443
+ }
444
+
445
+ // For recurrent state architectures (like Mamba or RWKV),
446
+ // each cache cell can store the state for a whole sequence.
447
+ // A slot should be always be contiguous.
448
+
449
+ // can only process batches with an equal number of new tokens in each sequence
450
+ GGML_ASSERT(ubatch.equal_seqs);
451
+
452
+ int32_t min = size - 1;
453
+ int32_t max = 0;
454
+
455
+ // everything should fit if all seq_ids are smaller than the max
456
+ for (uint32_t s = 0; s < n_seqs; ++s) {
457
+ const uint32_t n_seq_id = ubatch.n_seq_id[s];
458
+ for (uint32_t j = 0; j < n_seq_id; ++j) {
459
+ const llama_seq_id seq_id = ubatch.seq_id[s][j];
460
+
461
+ if (seq_id < 0 || (uint32_t) seq_id >= size) {
462
+ // too big seq_id
463
+ // TODO: would it be possible to resize the cache instead?
464
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
465
+ return false;
466
+ }
467
+ if (j > 0) {
468
+ kv_cell & seq = cells[seq_id];
469
+ if (seq.tail >= 0) {
470
+ kv_cell & cell = cells[seq.tail];
471
+ // clear cells from seq_ids that become shared
472
+ // (should not normally happen, but let's handle it anyway)
473
+ cell.seq_id.erase(seq_id);
474
+ seq.tail = -1;
475
+ if (cell.seq_id.empty()) {
476
+ cell.pos = -1;
477
+ cell.src = -1;
478
+ used -= 1;
479
+ }
480
+ }
481
+ }
482
+ }
483
+ }
484
+
485
+ #ifndef NDEBUG
486
+ {
487
+ std::vector<int32_t> tails_verif;
488
+ tails_verif.assign(size, -1);
489
+ for (uint32_t i = 0; i < size; ++i) {
490
+ kv_cell & cell = cells[i];
491
+ for (llama_seq_id seq_id : cell.seq_id) {
492
+ if (tails_verif[seq_id] != -1) {
493
+ LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
494
+ }
495
+ tails_verif[seq_id] = i;
496
+ }
497
+ }
498
+ for (uint32_t i = 0; i < size; ++i) {
499
+ if (tails_verif[i] != cells[i].tail) {
500
+ LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
501
+ }
502
+ }
503
+ }
504
+ #endif
505
+
506
+ // find next empty cell
507
+ uint32_t next_empty_cell = head;
508
+
509
+ for (uint32_t i = 0; i < size; ++i) {
510
+ if (next_empty_cell >= size) { next_empty_cell -= size; }
511
+ kv_cell & cell = cells[next_empty_cell];
512
+ if (cell.is_empty()) { break; }
513
+ next_empty_cell += 1;
514
+ }
515
+
516
+ // find usable cell range
517
+ for (uint32_t s = 0; s < n_seqs; ++s) {
518
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
519
+ kv_cell & seq_meta = cells[seq_id];
520
+ bool has_cell = false;
521
+ if (seq_meta.tail >= 0) {
522
+ kv_cell & cell = cells[seq_meta.tail];
523
+ GGML_ASSERT(cell.has_seq_id(seq_id));
524
+ // does this seq_id "own" the cell?
525
+ if (cell.seq_id.size() == 1) { has_cell = true; }
526
+ }
527
+ if (!has_cell) {
528
+ kv_cell & empty_cell = cells[next_empty_cell];
529
+ GGML_ASSERT(empty_cell.is_empty());
530
+ // copy old tail into the empty cell
531
+ if (seq_meta.tail >= 0) {
532
+ kv_cell & orig_cell = cells[seq_meta.tail];
533
+ empty_cell.pos = orig_cell.pos;
534
+ empty_cell.src = orig_cell.src;
535
+ orig_cell.seq_id.erase(seq_id);
536
+ empty_cell.seq_id.insert(seq_id); // will be overwritten
537
+ }
538
+ seq_meta.tail = next_empty_cell;
539
+ // find next empty cell
540
+ if (s + 1 < n_seqs) {
541
+ next_empty_cell += 1;
542
+ for (uint32_t i = 0; i < size; ++i) {
543
+ if (next_empty_cell >= size) { next_empty_cell -= size; }
544
+ kv_cell & cell = cells[next_empty_cell];
545
+ if (cell.is_empty()) { break; }
546
+ next_empty_cell += 1;
547
+ }
548
+ }
549
+ }
550
+ if (min > seq_meta.tail) { min = seq_meta.tail; }
551
+ if (max < seq_meta.tail) { max = seq_meta.tail; }
552
+ }
553
+
554
+ // gather and re-order
555
+ for (uint32_t s = 0; s < n_seqs; ++s) {
556
+ int32_t dst_id = s + min;
557
+ int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
558
+ if (dst_id != src_id) {
559
+ kv_cell & dst_cell = cells[dst_id];
560
+ kv_cell & src_cell = cells[src_id];
561
+
562
+ std::swap(dst_cell.pos, src_cell.pos);
563
+ std::swap(dst_cell.src, src_cell.src);
564
+ std::swap(dst_cell.seq_id, src_cell.seq_id);
565
+
566
+ // swap tails (assuming they NEVER overlap)
567
+ for (const llama_seq_id seq_id : src_cell.seq_id) {
568
+ cells[seq_id].tail = src_id;
569
+ }
570
+ for (const llama_seq_id seq_id : dst_cell.seq_id) {
571
+ cells[seq_id].tail = dst_id;
572
+ }
573
+ }
574
+ }
575
+
576
+ // update the pos of the used seqs
577
+ for (uint32_t s = 0; s < n_seqs; ++s) {
578
+ const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
579
+ int32_t cell_id = s + min;
580
+ kv_cell & cell = cells[cell_id];
581
+
582
+ if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
583
+ // What should happen when the pos backtracks or skips a value?
584
+ // Clearing the state mid-batch would require special-casing which isn't done.
585
+ LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
586
+ __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
587
+ }
588
+ cell.pos = last_pos;
589
+ cell.seq_id.clear();
590
+ for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
591
+ const llama_seq_id seq_id = ubatch.seq_id[s][j];
592
+ cell.seq_id.insert(seq_id);
593
+ cells[seq_id].tail = cell_id;
594
+ }
595
+ }
596
+
597
+ // allow getting the range of used cells, from head to head + n
598
+ head = min;
599
+ n = max - min + 1;
600
+ used = std::count_if(cells.begin(), cells.end(),
601
+ [](const kv_cell & cell){ return !cell.is_empty(); });
602
+
603
+ // sanity check
604
+ return n >= n_seqs;
605
+ }
606
+
607
+ bool llama_kv_cache_recurrent::get_can_shift() const {
608
+ return false;
609
+ }
610
+
611
+ int32_t llama_kv_cache_recurrent::s_copy(int i) const {
612
+ const uint32_t cell_id = i + head;
613
+
614
+ //////////////////////////////////////////////
615
+ // TODO: this should not mutate the KV cache !
616
+ kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
617
+
618
+ // prevent out-of-bound sources
619
+ if (cell.src < 0 || (uint32_t) cell.src >= size) {
620
+ cell.src = cell_id;
621
+ }
622
+
623
+ int32_t res = cell.src;
624
+
625
+ // TODO: do not mutate the KV cache
626
+ // ensure copy only happens once
627
+ if (cell.src != (int32_t) cell_id) {
628
+ cell.src = cell_id;
629
+ }
630
+
631
+ return res;
632
+ }
633
+
634
+ float llama_kv_cache_recurrent::s_mask(int i) const {
635
+ const uint32_t cell_id = i + head;
636
+
637
+ //////////////////////////////////////////////
638
+ // TODO: this should not mutate the KV cache !
639
+ kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
640
+
641
+ float res = (float) (cell.src >= 0);
642
+
643
+ // only clear once
644
+ if (cell.src < 0) {
645
+ cell.src = cell_id;
646
+ }
647
+
648
+ return res;
649
+ }
650
+
651
+ size_t llama_kv_cache_recurrent::total_size() const {
652
+ size_t size = 0;
653
+ for (const auto & buf : bufs) {
654
+ size += ggml_backend_buffer_get_size(buf.get());
655
+ }
656
+
657
+ return size;
658
+ }
659
+
660
+ size_t llama_kv_cache_recurrent::size_k_bytes() const {
661
+ size_t size_k_bytes = 0;
662
+
663
+ for (const auto & k : k_l) {
664
+ size_k_bytes += ggml_nbytes(k);
665
+ }
666
+
667
+ return size_k_bytes;
668
+ }
669
+
670
+ size_t llama_kv_cache_recurrent::size_v_bytes() const {
671
+ size_t size_v_bytes = 0;
672
+
673
+ for (const auto & v : v_l) {
674
+ size_v_bytes += ggml_nbytes(v);
675
+ }
676
+
677
+ return size_v_bytes;
678
+ }
679
+
680
+ void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
681
+ std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
682
+ uint32_t cell_count = 0;
683
+
684
+ // Count the number of cells with the specified seq_id
685
+ // Find all the ranges of cells with this seq id (or all, when -1)
686
+ uint32_t cell_range_begin = size;
687
+ for (uint32_t i = 0; i < size; ++i) {
688
+ const auto & cell = cells[i];
689
+ if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
690
+ ++cell_count;
691
+ if (cell_range_begin == size) {
692
+ cell_range_begin = i;
693
+ }
694
+ } else {
695
+ if (cell_range_begin != size) {
696
+ cell_ranges.emplace_back(cell_range_begin, i);
697
+ cell_range_begin = size;
698
+ }
699
+ }
700
+ }
701
+ if (cell_range_begin != size) {
702
+ cell_ranges.emplace_back(cell_range_begin, size);
703
+ }
704
+
705
+ // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
706
+ uint32_t cell_count_check = 0;
707
+ for (const auto & range : cell_ranges) {
708
+ cell_count_check += range.second - range.first;
709
+ }
710
+ GGML_ASSERT(cell_count == cell_count_check);
711
+
712
+ io.write(&cell_count, sizeof(cell_count));
713
+
714
+ state_write_meta(io, cell_ranges, seq_id);
715
+ state_write_data(io, cell_ranges);
716
+ }
717
+
718
+ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
719
+ uint32_t cell_count;
720
+ io.read_to(&cell_count, sizeof(cell_count));
721
+
722
+ bool res = true;
723
+
724
+ res = res && state_read_meta(io, cell_count, seq_id);
725
+ res = res && state_read_data(io, cell_count);
726
+
727
+ if (!res) {
728
+ if (seq_id == -1) {
729
+ clear();
730
+ } else {
731
+ seq_rm(seq_id, -1, -1);
732
+ }
733
+ throw std::runtime_error("failed to restore kv cache");
734
+ }
735
+ }
736
+
737
+ void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
738
+ for (const auto & range : cell_ranges) {
739
+ for (uint32_t i = range.first; i < range.second; ++i) {
740
+ const auto & cell = cells[i];
741
+ const llama_pos pos = cell.pos;
742
+ const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
743
+
744
+ io.write(&pos, sizeof(pos));
745
+ io.write(&n_seq_id, sizeof(n_seq_id));
746
+
747
+ if (n_seq_id) {
748
+ for (auto seq_id : cell.seq_id) {
749
+ io.write(&seq_id, sizeof(seq_id));
750
+ }
751
+ }
752
+ }
753
+ }
754
+ }
755
+
756
+ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
757
+ const uint32_t v_trans = 0;
758
+ const uint32_t n_layer = hparams.n_layer;
759
+
760
+ io.write(&v_trans, sizeof(v_trans));
761
+ io.write(&n_layer, sizeof(n_layer));
762
+
763
+ std::vector<uint8_t> tmp_buf;
764
+
765
+ // Iterate and write all the keys first, each row is a cell
766
+ // Get whole range at a time
767
+ for (uint32_t il = 0; il < n_layer; ++il) {
768
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
769
+
770
+ // Write key type
771
+ const int32_t k_type_i = (int32_t)k_l[il]->type;
772
+ io.write(&k_type_i, sizeof(k_type_i));
773
+
774
+ // Write row size of key
775
+ const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
776
+ io.write(&k_size_row, sizeof(k_size_row));
777
+
778
+ // Read each range of cells of k_size length each into tmp_buf and write out
779
+ for (const auto & range : cell_ranges) {
780
+ const size_t range_size = range.second - range.first;
781
+ const size_t buf_size = range_size * k_size_row;
782
+ io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
783
+ }
784
+ }
785
+
786
+ if (!v_trans) {
787
+ for (uint32_t il = 0; il < n_layer; ++il) {
788
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
789
+
790
+ // Write value type
791
+ const int32_t v_type_i = (int32_t)v_l[il]->type;
792
+ io.write(&v_type_i, sizeof(v_type_i));
793
+
794
+ // Write row size of value
795
+ const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
796
+ io.write(&v_size_row, sizeof(v_size_row));
797
+
798
+ // Read each range of cells of v_size length each into tmp_buf and write out
799
+ for (const auto & range : cell_ranges) {
800
+ const size_t range_size = range.second - range.first;
801
+ const size_t buf_size = range_size * v_size_row;
802
+ io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
803
+ }
804
+ }
805
+ } else {
806
+ // When v is transposed, we also need the element size and get the element ranges from each row
807
+ const uint32_t kv_size = size;
808
+ for (uint32_t il = 0; il < n_layer; ++il) {
809
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
810
+
811
+ // Write value type
812
+ const int32_t v_type_i = (int32_t)v_l[il]->type;
813
+ io.write(&v_type_i, sizeof(v_type_i));
814
+
815
+ // Write element size
816
+ const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
817
+ io.write(&v_size_el, sizeof(v_size_el));
818
+
819
+ // Write GQA embedding size
820
+ io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
821
+
822
+ // For each row, we get the element values of each cell
823
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
824
+ // Read each range of cells of v_size_el length each into tmp_buf and write out
825
+ for (const auto & range : cell_ranges) {
826
+ const size_t range_size = range.second - range.first;
827
+ const size_t src_offset = (range.first + j * kv_size) * v_size_el;
828
+ const size_t buf_size = range_size * v_size_el;
829
+ io.write_tensor(v_l[il], src_offset, buf_size);
830
+ }
831
+ }
832
+ }
833
+ }
834
+ }
835
+
836
+ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
837
+ if (dest_seq_id != -1) {
838
+ // single sequence
839
+
840
+ seq_rm(dest_seq_id, -1, -1);
841
+
842
+ llama_sbatch sbatch;
843
+ llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
844
+
845
+ batch.n_tokens = cell_count;
846
+ batch.n_seq_tokens = cell_count;
847
+ batch.n_seqs = 1;
848
+
849
+ for (uint32_t i = 0; i < cell_count; ++i) {
850
+ llama_pos pos;
851
+ uint32_t n_seq_id;
852
+
853
+ io.read_to(&pos, sizeof(pos));
854
+ io.read_to(&n_seq_id, sizeof(n_seq_id));
855
+
856
+ if (n_seq_id != 0) {
857
+ LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
858
+ return false;
859
+ }
860
+
861
+ batch.pos[i] = pos;
862
+ }
863
+ batch.n_seq_id[0] = 1;
864
+ batch.seq_id[0] = &dest_seq_id;
865
+
866
+ if (!find_slot(batch)) {
867
+ LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
868
+ return false;
869
+ }
870
+
871
+ // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
872
+ // Assume that this is one contiguous block of cells
873
+ GGML_ASSERT(head + cell_count <= size);
874
+ GGML_ASSERT(cells[head].pos == batch.pos[0]);
875
+ GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
876
+ GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
877
+ GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
878
+ } else {
879
+ // whole KV cache restore
880
+
881
+ if (cell_count > size) {
882
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
883
+ return false;
884
+ }
885
+
886
+ clear();
887
+
888
+ for (uint32_t i = 0; i < cell_count; ++i) {
889
+ kv_cell & cell = cells[i];
890
+
891
+ llama_pos pos;
892
+ uint32_t n_seq_id;
893
+
894
+ io.read_to(&pos, sizeof(pos));
895
+ io.read_to(&n_seq_id, sizeof(n_seq_id));
896
+
897
+ cell.pos = pos;
898
+
899
+ for (uint32_t j = 0; j < n_seq_id; ++j) {
900
+ llama_seq_id seq_id;
901
+ io.read_to(&seq_id, sizeof(seq_id));
902
+
903
+ // TODO: llama_kv_cache_recurrent should have a notion of max sequences
904
+ //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
905
+ if (seq_id < 0) {
906
+ //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
907
+ LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
908
+ return false;
909
+ }
910
+
911
+ cell.seq_id.insert(seq_id);
912
+
913
+ int32_t & tail = cells[seq_id].tail;
914
+ if (tail != -1) {
915
+ LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
916
+ return false;
917
+ }
918
+ tail = i;
919
+ }
920
+ }
921
+
922
+ head = 0;
923
+ used = cell_count;
924
+ }
925
+
926
+ for (uint32_t i = 0; i < cell_count; ++i) {
927
+ uint32_t cell_id = head + i;
928
+ // make sure the recurrent states will keep their restored state
929
+ cells[cell_id].src = cell_id;
930
+ }
931
+
932
+ return true;
933
+ }
934
+
935
+ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
936
+ uint32_t v_trans;
937
+ uint32_t n_layer;
938
+ io.read_to(&v_trans, sizeof(v_trans));
939
+ io.read_to(&n_layer, sizeof(n_layer));
940
+
941
+ if (n_layer != hparams.n_layer) {
942
+ LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
943
+ return false;
944
+ }
945
+ if (cell_count > size) {
946
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
947
+ return false;
948
+ }
949
+ if (false != (bool) v_trans) {
950
+ LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
951
+ return false;
952
+ }
953
+
954
+ // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
955
+ for (uint32_t il = 0; il < n_layer; ++il) {
956
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
957
+
958
+ // Read type of key
959
+ int32_t k_type_i_ref;
960
+ io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
961
+ const int32_t k_type_i = (int32_t) k_l[il]->type;
962
+ if (k_type_i != k_type_i_ref) {
963
+ LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
964
+ return false;
965
+ }
966
+
967
+ // Read row size of key
968
+ uint64_t k_size_row_ref;
969
+ io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
970
+ const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
971
+ if (k_size_row != k_size_row_ref) {
972
+ LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
973
+ return false;
974
+ }
975
+
976
+ if (cell_count) {
977
+ // Read and set the keys for the whole cell range
978
+ ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
979
+ }
980
+ }
981
+
982
+ if (!v_trans) {
983
+ for (uint32_t il = 0; il < n_layer; ++il) {
984
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
985
+
986
+ // Read type of value
987
+ int32_t v_type_i_ref;
988
+ io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
989
+ const int32_t v_type_i = (int32_t)v_l[il]->type;
990
+ if (v_type_i != v_type_i_ref) {
991
+ LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
992
+ return false;
993
+ }
994
+
995
+ // Read row size of value
996
+ uint64_t v_size_row_ref;
997
+ io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
998
+ const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
999
+ if (v_size_row != v_size_row_ref) {
1000
+ LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
1001
+ return false;
1002
+ }
1003
+
1004
+ if (cell_count) {
1005
+ // Read and set the values for the whole cell range
1006
+ ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
1007
+ }
1008
+ }
1009
+ } else {
1010
+ // For each layer, read the values for each cell (transposed)
1011
+ for (uint32_t il = 0; il < n_layer; ++il) {
1012
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1013
+
1014
+ // Read type of value
1015
+ int32_t v_type_i_ref;
1016
+ io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1017
+ const int32_t v_type_i = (int32_t)v_l[il]->type;
1018
+ if (v_type_i != v_type_i_ref) {
1019
+ LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1020
+ return false;
1021
+ }
1022
+
1023
+ // Read element size of value
1024
+ uint32_t v_size_el_ref;
1025
+ io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
1026
+ const size_t v_size_el = ggml_type_size(v_l[il]->type);
1027
+ if (v_size_el != v_size_el_ref) {
1028
+ LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
1029
+ return false;
1030
+ }
1031
+
1032
+ // Read GQA embedding size
1033
+ uint32_t n_embd_v_gqa_ref;
1034
+ io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
1035
+ if (n_embd_v_gqa != n_embd_v_gqa_ref) {
1036
+ LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
1037
+ return false;
1038
+ }
1039
+
1040
+ if (cell_count) {
1041
+ // For each row in the transposed matrix, read the values for the whole cell range
1042
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1043
+ const size_t dst_offset = (head + j * size) * v_size_el;
1044
+ ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
1045
+ }
1046
+ }
1047
+ }
1048
+ }
1049
+
1050
+ return true;
1051
+ }
1052
+
1053
+ //
1054
+ // llama_kv_cache_recurrent_state
1055
+ //
1056
+
1057
+ llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {}
1058
+
1059
+ llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
1060
+ llama_memory_status status,
1061
+ llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
1062
+ }
1063
+
1064
+ llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
1065
+ llama_memory_status status,
1066
+ llama_kv_cache_recurrent * kv,
1067
+ llama_sbatch sbatch,
1068
+ std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
1069
+
1070
+ llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default;
1071
+
1072
+ bool llama_kv_cache_recurrent_state::next() {
1073
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1074
+
1075
+ if (++i_next >= ubatches.size()) {
1076
+ return false;
1077
+ }
1078
+
1079
+ return true;
1080
+ }
1081
+
1082
+ bool llama_kv_cache_recurrent_state::apply() {
1083
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1084
+
1085
+ kv->find_slot(ubatches[i_next]);
1086
+
1087
+ return true;
1088
+ }
1089
+
1090
+ std::vector<int64_t> & llama_kv_cache_recurrent_state::out_ids() {
1091
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1092
+
1093
+ return sbatch.out_ids;
1094
+ }
1095
+
1096
+ llama_memory_status llama_kv_cache_recurrent_state::get_status() const {
1097
+ return status;
1098
+ }
1099
+
1100
+ const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const {
1101
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1102
+
1103
+ return ubatches[i_next];
1104
+ }
1105
+
1106
+ uint32_t llama_kv_cache_recurrent_state::get_n_kv() const {
1107
+ return is_full ? kv->size : kv->n;
1108
+ }
1109
+
1110
+ uint32_t llama_kv_cache_recurrent_state::get_head() const {
1111
+ return is_full ? 0 : kv->head;
1112
+ }
1113
+
1114
+ uint32_t llama_kv_cache_recurrent_state::get_size() const {
1115
+ return kv->size;
1116
+ }
1117
+
1118
+ ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const {
1119
+ return kv->k_l[il];
1120
+ }
1121
+
1122
+ ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
1123
+ return kv->v_l[il];
1124
+ }
1125
+
1126
+ int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
1127
+ return kv->s_copy(i);
1128
+ }
1129
+
1130
+ float llama_kv_cache_recurrent_state::s_mask(int i) const {
1131
+ return kv->s_mask(i);
1132
+ }
examples/talk-llama/llama-kv-cache-recurrent.h ADDED
@@ -0,0 +1,191 @@
1
+ #pragma once
2
+
3
+ #include "llama-batch.h"
4
+ #include "llama-graph.h"
5
+ #include "llama-kv-cache.h"
6
+
7
+ #include <set>
8
+ #include <vector>
9
+
10
+ //
11
+ // llama_kv_cache_recurrent
12
+ //
13
+
14
+ // TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
15
+ // see the implementation of llama_kv_cache_unified_state_i for an example of how to do it
16
+ class llama_kv_cache_recurrent : public llama_kv_cache {
17
+ public:
18
+ llama_kv_cache_recurrent(
19
+ const llama_model & model,
20
+ ggml_type type_k,
21
+ ggml_type type_v,
22
+ bool offload,
23
+ uint32_t kv_size,
24
+ uint32_t n_seq_max);
25
+
26
+ ~llama_kv_cache_recurrent() = default;
27
+
28
+ //
29
+ // llama_memory_i
30
+ //
31
+
32
+ void clear() override;
33
+
34
+ bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
35
+ void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
36
+ void seq_keep(llama_seq_id seq_id) override;
37
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
38
+ void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
39
+
40
+ llama_pos seq_pos_min(llama_seq_id seq_id) const override;
41
+ llama_pos seq_pos_max(llama_seq_id seq_id) const override;
42
+
43
+ //
44
+ // llama_kv_cache
45
+ //
46
+
47
+ llama_memory_state_ptr init_batch(
48
+ const llama_batch & batch,
49
+ uint32_t n_ubatch,
50
+ bool embd_pooled,
51
+ bool logits_all) override;
52
+
53
+ llama_memory_state_ptr init_full() override;
54
+
55
+ bool update(llama_context & lctx) override;
56
+
57
+ void defrag_sched(float thold) override;
58
+
59
+ bool prepare(const std::vector<llama_ubatch> & ubatches);
60
+
61
+ // find a contiguous slot of kv cells and emplace the ubatch there
62
+ bool find_slot(const llama_ubatch & ubatch);
63
+
64
+ bool get_can_shift() const override;
65
+
66
+ // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
67
+ int32_t s_copy(int i) const;
68
+ float s_mask(int i) const;
69
+
70
+ // state write/load
71
+
72
+ void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
73
+ void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
74
+
75
+ uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
76
+ uint32_t size = 0; // total number of cells, shared across all sequences
77
+ uint32_t used = 0; // used cells (i.e. at least one seq_id)
78
+
79
+ // computed before each graph build
80
+ uint32_t n = 0;
81
+
82
+ // TODO: optimize for recurrent state needs
83
+ struct kv_cell {
84
+ llama_pos pos = -1;
85
+ int32_t src = -1; // used to copy states
86
+ int32_t tail = -1;
87
+
88
+ std::set<llama_seq_id> seq_id;
89
+
90
+ bool has_seq_id(const llama_seq_id & id) const {
91
+ return seq_id.find(id) != seq_id.end();
92
+ }
93
+
94
+ bool is_empty() const {
95
+ return seq_id.empty();
96
+ }
97
+
98
+ bool is_same_seq(const kv_cell & other) const {
99
+ return seq_id == other.seq_id;
100
+ }
101
+ };
102
+
103
+ std::vector<kv_cell> cells;
104
+
105
+ std::vector<ggml_tensor *> k_l; // per layer
106
+ std::vector<ggml_tensor *> v_l;
107
+
108
+ private:
109
+ //const llama_model & model;
110
+ const llama_hparams & hparams;
111
+
112
+ const uint32_t n_seq_max = 1;
113
+
114
+ std::vector<ggml_context_ptr> ctxs;
115
+ std::vector<ggml_backend_buffer_ptr> bufs;
116
+
117
+ size_t total_size() const;
118
+
119
+ size_t size_k_bytes() const;
120
+ size_t size_v_bytes() const;
121
+
122
+ void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
123
+ void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
124
+
125
+ bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
126
+ bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
127
+ };
128
+
129
+ class llama_kv_cache_recurrent_state : public llama_memory_state_i {
130
+ public:
131
+ // used for errors
132
+ llama_kv_cache_recurrent_state(llama_memory_status status);
133
+
134
+ // used to create a full-cache state
135
+ llama_kv_cache_recurrent_state(
136
+ llama_memory_status status,
137
+ llama_kv_cache_recurrent * kv);
138
+
139
+ // used to create a state from a batch
140
+ llama_kv_cache_recurrent_state(
141
+ llama_memory_status status,
142
+ llama_kv_cache_recurrent * kv,
143
+ llama_sbatch sbatch,
144
+ std::vector<llama_ubatch> ubatches);
145
+
146
+ virtual ~llama_kv_cache_recurrent_state();
147
+
148
+ //
149
+ // llama_memory_state_i
150
+ //
151
+
152
+ bool next() override;
153
+ bool apply() override;
154
+
155
+ std::vector<int64_t> & out_ids() override;
156
+
157
+ llama_memory_status get_status() const override;
158
+ const llama_ubatch & get_ubatch() const override;
159
+
160
+ //
161
+ // llama_kv_cache_recurrent_state specific API
162
+ //
163
+
164
+ uint32_t get_n_kv() const;
165
+ uint32_t get_head() const;
166
+ uint32_t get_size() const;
167
+
168
+ ggml_tensor * get_k_l(int32_t il) const;
169
+ ggml_tensor * get_v_l(int32_t il) const;
170
+
171
+ int32_t s_copy(int i) const;
172
+ float s_mask(int i) const;
173
+
174
+ private:
175
+ const llama_memory_status status;
176
+
177
+ llama_kv_cache_recurrent * kv;
178
+
179
+ llama_sbatch sbatch;
180
+
181
+ size_t i_next = 0;
182
+
183
+ std::vector<llama_ubatch> ubatches;
184
+
185
+ //
186
+ // data needed for building the compute graph for the current ubatch:
187
+ // TODO: extract all the state like `head` and `n` here
188
+ //
189
+
190
+ const bool is_full = false;
191
+ };
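A note on the data structure declared above: in llama_kv_cache_recurrent each sequence owns at most one state cell, and cells[seq_id].tail records which cell currently holds that sequence's state, so operations like seq_cp only touch tail pointers and seq_id sets instead of copying tensor data. The following is a deliberately simplified, standalone C++ sketch of that bookkeeping (type and function names are invented for illustration; it is not the implementation from this commit):

#include <cassert>
#include <cstdio>
#include <set>
#include <vector>

struct toy_cell {
    int pos  = -1;
    int tail = -1;          // for cell i: index of the cell holding the state of seq_id == i
    std::set<int> seq_id;   // sequences currently sharing this cell's state
};

struct toy_recurrent_cache {
    std::vector<toy_cell> cells;

    explicit toy_recurrent_cache(int n) : cells(n) {}

    // place the latest state of seq_id into a cell (very simplified find_slot)
    void set_state(int seq_id, int cell_id, int pos) {
        cells[cell_id].pos = pos;
        cells[cell_id].seq_id.insert(seq_id);
        cells[seq_id].tail = cell_id;
    }

    // let seq_dst share the state of seq_src (very simplified seq_cp)
    void copy_seq(int seq_src, int seq_dst) {
        const int t = cells[seq_src].tail;
        if (t >= 0) {
            cells[t].seq_id.insert(seq_dst);
            cells[seq_dst].tail = t;
        }
    }
};

int main() {
    toy_recurrent_cache kv(8);

    kv.set_state(/*seq_id=*/0, /*cell_id=*/3, /*pos=*/41);
    kv.copy_seq(0, 1);   // seq 1 now shares seq 0's state, no tensor copy needed

    assert(kv.cells[1].tail == 3);
    assert(kv.cells[3].seq_id.count(1) == 1);
    std::printf("seq 1 reads its state from cell %d (pos %d)\n",
                kv.cells[1].tail, kv.cells[kv.cells[1].tail].pos);
}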
examples/talk-llama/llama-kv-cache-unified-iswa.cpp ADDED
@@ -0,0 +1,249 @@
1
+ #include "llama-kv-cache-unified-iswa.h"
2
+
3
+ #include "llama-impl.h"
4
+ #include "llama-batch.h"
5
+ #include "llama-model.h"
6
+
7
+ #include <algorithm>
8
+ #include <cassert>
9
+
10
+ //
11
+ // llama_kv_cache_unified_iswa
12
+ //
13
+
14
+ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
15
+ const llama_model & model,
16
+ ggml_type type_k,
17
+ ggml_type type_v,
18
+ bool v_trans,
19
+ bool offload,
20
+ bool swa_full,
21
+ uint32_t kv_size,
22
+ uint32_t n_seq_max,
23
+ uint32_t n_ubatch,
24
+ uint32_t n_pad) : hparams(model.hparams) {
25
+ llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
26
+ llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
27
+
28
+ const uint32_t size_base = kv_size;
29
+
30
+ uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
31
+
32
+ // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
33
+ if (swa_full) {
34
+ LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
35
+ __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
36
+
37
+ size_swa = size_base;
38
+ }
39
+
40
+ LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
41
+
42
+ kv_base = std::make_unique<llama_kv_cache_unified>(
43
+ model, std::move(filter_base), type_k, type_v,
44
+ v_trans, offload, size_base, n_seq_max, n_pad,
45
+ 0, LLAMA_SWA_TYPE_NONE);
46
+
47
+ LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
48
+
49
+ kv_swa = std::make_unique<llama_kv_cache_unified>(
50
+ model, std::move(filter_swa), type_k, type_v,
51
+ v_trans, offload, size_swa, n_seq_max, n_pad,
52
+ hparams.n_swa, hparams.swa_type);
53
+ }
54
+
55
+ void llama_kv_cache_unified_iswa::clear() {
56
+ kv_base->clear();
57
+ kv_swa ->clear();
58
+ }
59
+
60
+ bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
61
+ bool res = true;
62
+
63
+ res = res & kv_base->seq_rm(seq_id, p0, p1);
64
+ res = res & kv_swa ->seq_rm(seq_id, p0, p1);
65
+
66
+ return res;
67
+ }
68
+
69
+ void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
70
+ kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
71
+ kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
72
+ }
73
+
74
+ void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
75
+ kv_base->seq_keep(seq_id);
76
+ kv_swa ->seq_keep(seq_id);
77
+ }
78
+
79
+ void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
80
+ kv_base->seq_add(seq_id, p0, p1, shift);
81
+ kv_swa ->seq_add(seq_id, p0, p1, shift);
82
+ }
83
+
84
+ void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
85
+ kv_base->seq_div(seq_id, p0, p1, d);
86
+ kv_swa ->seq_div(seq_id, p0, p1, d);
87
+ }
88
+
89
+ llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
90
+ // the base cache is a superset of the SWA cache, so we can just check the SWA cache
91
+ return kv_swa->seq_pos_min(seq_id);
92
+ }
93
+
94
+ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
95
+ return kv_swa->seq_pos_max(seq_id);
96
+ }
97
+
98
+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
99
+ GGML_UNUSED(embd_pooled);
100
+
101
+ // TODO: if we fail with split_simple, we should attempt different splitting strategies
102
+ // but to do that properly, we first have to refactor the batches to be more flexible
103
+
104
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
105
+
106
+ std::vector<llama_ubatch> ubatches;
107
+
108
+ while (sbatch.n_tokens > 0) {
109
+ auto ubatch = sbatch.split_simple(n_ubatch);
110
+
111
+ ubatches.push_back(ubatch);
112
+ }
113
+
114
+ auto heads_base = kv_base->prepare(ubatches);
115
+ if (heads_base.empty()) {
116
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
117
+ }
118
+
119
+ auto heads_swa = kv_swa->prepare(ubatches);
120
+ if (heads_swa.empty()) {
121
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
122
+ }
123
+
124
+ assert(heads_base.size() == heads_swa.size());
125
+
126
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
127
+ this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
128
+ }
129
+
130
+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
131
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
132
+ }
133
+
134
+ bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
135
+ bool res = false;
136
+
137
+ res = res | kv_base->update(lctx);
138
+ res = res | kv_swa ->update(lctx);
139
+
140
+ return res;
141
+ }
142
+
143
+ void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
144
+ kv_base->defrag_sched(thold);
145
+ kv_swa ->defrag_sched(thold);
146
+ }
147
+
148
+ bool llama_kv_cache_unified_iswa::get_can_shift() const {
149
+ return kv_base->get_size() == kv_swa->get_size();
150
+ }
151
+
152
+ void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
153
+ kv_base->state_write(io, seq_id);
154
+ kv_swa ->state_write(io, seq_id);
155
+ }
156
+
157
+ void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
158
+ kv_base->state_read(io, seq_id);
159
+ kv_swa ->state_read(io, seq_id);
160
+ }
161
+
162
+ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
163
+ return kv_base.get();
164
+ }
165
+
166
+ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
167
+ return kv_swa.get();
168
+ }
169
+
170
+ //
171
+ // llama_kv_cache_unified_iswa_state
172
+ //
173
+
174
+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
175
+
176
+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
177
+ llama_memory_status status,
178
+ llama_kv_cache_unified_iswa * kv) : status(status) {
179
+ state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
180
+ state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
181
+ }
182
+
183
+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
184
+ llama_memory_status status,
185
+ llama_kv_cache_unified_iswa * kv,
186
+ llama_sbatch sbatch,
187
+ std::vector<uint32_t> heads_base,
188
+ std::vector<uint32_t> heads_swa,
189
+ std::vector<llama_ubatch> ubatches)
190
+ : status(status),
191
+ sbatch(std::move(sbatch)),
192
+ ubatches(std::move(ubatches)) {
193
+ // note: here we copy the ubatches. not sure if this is ideal
194
+ state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
195
+ state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
196
+ }
197
+
198
+ llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
199
+
200
+ bool llama_kv_cache_unified_iswa_state::next() {
201
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
202
+
203
+ state_base->next();
204
+ state_swa ->next();
205
+
206
+ if (++i_next >= ubatches.size()) {
207
+ return false;
208
+ }
209
+
210
+ return true;
211
+ }
212
+
213
+ bool llama_kv_cache_unified_iswa_state::apply() {
214
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
215
+
216
+ bool res = true;
217
+
218
+ res = res & state_base->apply();
219
+ res = res & state_swa ->apply();
220
+
221
+ return res;
222
+ }
223
+
224
+ std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
225
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
226
+
227
+ return sbatch.out_ids;
228
+ }
229
+
230
+ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
231
+ return status;
232
+ }
233
+
234
+ const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
235
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
236
+ return ubatches[i_next];
237
+ }
238
+
239
+ const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
240
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
241
+
242
+ return state_base.get();
243
+ }
244
+
245
+ const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
246
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
247
+
248
+ return state_swa.get();
249
+ }
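
The behavior above is easier to see in isolation: every call fans out to a base cache and a SWA cache, and a batch is accepted only when both sub-caches can place it. Below is a minimal standalone sketch of that pattern with toy stand-in types (toy_cache and toy_iswa are illustrative only, not llama.cpp classes, and the sizes are made up):

#include <cstdio>
#include <memory>

struct toy_cache {
    size_t size; // number of KV cells this cache can hold

    explicit toy_cache(size_t size) : size(size) {}

    // stand-in for a prepare() step: fails when the batch does not fit
    bool prepare(size_t n_tokens) const { return n_tokens <= size; }
};

struct toy_iswa {
    std::unique_ptr<toy_cache> base; // full-context cache (non-SWA layers)
    std::unique_ptr<toy_cache> swa;  // smaller sliding-window cache (SWA layers)

    toy_iswa(size_t size_base, size_t size_swa)
        : base(new toy_cache(size_base)), swa(new toy_cache(size_swa)) {}

    // mirrors the fan-out: the batch is accepted only if both sub-caches accept it
    bool init_batch(size_t n_tokens) const {
        return base->prepare(n_tokens) && swa->prepare(n_tokens);
    }
};

int main() {
    const toy_iswa kv(/*size_base =*/ 4096, /*size_swa =*/ 1024);

    std::printf("batch of  512 tokens accepted: %d\n", kv.init_batch(512));  // 1
    std::printf("batch of 2048 tokens accepted: %d\n", kv.init_batch(2048)); // 0 - SWA cache too small
    return 0;
}

In the real class the same all-or-nothing rule shows up as init_batch() returning LLAMA_MEMORY_STATUS_FAILED_PREPARE when either kv_base->prepare() or kv_swa->prepare() comes back empty.
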
examples/talk-llama/llama-kv-cache-unified-iswa.h ADDED
@@ -0,0 +1,136 @@
1
+ #pragma once
2
+
3
+ #include "llama-kv-cache-unified.h"
4
+
5
+ #include <vector>
6
+
7
+ //
8
+ // llama_kv_cache_unified_iswa
9
+ //
10
+
11
+ // utilizes two instances of llama_kv_cache_unified
12
+ // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
13
+
14
+ class llama_kv_cache_unified_iswa : public llama_kv_cache {
15
+ public:
16
+ llama_kv_cache_unified_iswa(
17
+ const llama_model & model,
18
+ ggml_type type_k,
19
+ ggml_type type_v,
20
+ bool v_trans,
21
+ bool offload,
22
+ bool swa_full,
23
+ uint32_t kv_size,
24
+ uint32_t n_seq_max,
25
+ uint32_t n_ubatch,
26
+ uint32_t n_pad);
27
+
28
+ ~llama_kv_cache_unified_iswa() = default;
29
+
30
+ //
31
+ // llama_memory_i
32
+ //
33
+
34
+ void clear() override;
35
+
36
+ bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
37
+ void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
38
+ void seq_keep(llama_seq_id seq_id) override;
39
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
40
+ void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
41
+
42
+ llama_pos seq_pos_min(llama_seq_id seq_id) const override;
43
+ llama_pos seq_pos_max(llama_seq_id seq_id) const override;
44
+
45
+ //
46
+ // llama_kv_cache
47
+ //
48
+
49
+ llama_memory_state_ptr init_batch(
50
+ const llama_batch & batch,
51
+ uint32_t n_ubatch,
52
+ bool embd_pooled,
53
+ bool logits_all) override;
54
+
55
+ llama_memory_state_ptr init_full() override;
56
+
57
+ bool update(llama_context & lctx) override;
58
+
59
+ void defrag_sched(float thold) override;
60
+
61
+ bool get_can_shift() const override;
62
+
63
+ // state write/load
64
+
65
+ void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
66
+ void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
67
+
68
+ //
69
+ // llama_kv_cache_unified_iswa specific API
70
+ //
71
+
72
+ llama_kv_cache_unified * get_base() const;
73
+ llama_kv_cache_unified * get_swa () const;
74
+
75
+ private:
76
+ const llama_hparams & hparams;
77
+
78
+ std::unique_ptr<llama_kv_cache_unified> kv_base;
79
+ std::unique_ptr<llama_kv_cache_unified> kv_swa;
80
+ };
81
+
82
+ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
83
+ public:
84
+ // used for errors
85
+ llama_kv_cache_unified_iswa_state(llama_memory_status status);
86
+
87
+ // used to create a full-cache state
88
+ llama_kv_cache_unified_iswa_state(
89
+ llama_memory_status status,
90
+ llama_kv_cache_unified_iswa * kv);
91
+
92
+ // used to create a state from a batch
93
+ llama_kv_cache_unified_iswa_state(
94
+ llama_memory_status status,
95
+ llama_kv_cache_unified_iswa * kv,
96
+ llama_sbatch sbatch,
97
+ std::vector<uint32_t> heads_base,
98
+ std::vector<uint32_t> heads_swa,
99
+ std::vector<llama_ubatch> ubatches);
100
+
101
+ virtual ~llama_kv_cache_unified_iswa_state();
102
+
103
+ //
104
+ // llama_memory_state_i
105
+ //
106
+
107
+ bool next() override;
108
+ bool apply() override;
109
+
110
+ std::vector<int64_t> & out_ids() override;
111
+
112
+ llama_memory_status get_status() const override;
113
+ const llama_ubatch & get_ubatch() const override;
114
+
115
+ //
116
+ // llama_kv_cache_unified_iswa_state specific API
117
+ //
118
+
119
+ const llama_kv_cache_unified_state * get_base() const;
120
+ const llama_kv_cache_unified_state * get_swa() const;
121
+
122
+ private:
123
+ const llama_memory_status status;
124
+
125
+ //llama_kv_cache_unified_iswa * kv;
126
+
127
+ llama_sbatch sbatch;
128
+
129
+ // the index of the next ubatch to process
130
+ size_t i_next = 0;
131
+
132
+ std::vector<llama_ubatch> ubatches;
133
+
134
+ std::unique_ptr<llama_kv_cache_unified_state> state_base;
135
+ std::unique_ptr<llama_kv_cache_unified_state> state_swa;
136
+ };
examples/talk-llama/llama-kv-cache-unified.cpp ADDED
@@ -0,0 +1,1717 @@
1
+ #include "llama-kv-cache-unified.h"
2
+
3
+ #include "llama-impl.h"
4
+ #include "llama-model.h"
5
+ #include "llama-context.h"
6
+
7
+ #include <algorithm>
8
+ #include <cassert>
9
+ #include <cmath>
10
+ #include <limits>
11
+ #include <map>
12
+ #include <stdexcept>
13
+
14
+ //
15
+ // llama_kv_cache_unified
16
+ //
17
+
18
+ llama_kv_cache_unified::llama_kv_cache_unified(
19
+ const llama_model & model,
20
+ layer_filter_cb && filter,
21
+ ggml_type type_k,
22
+ ggml_type type_v,
23
+ bool v_trans,
24
+ bool offload,
25
+ uint32_t kv_size,
26
+ uint32_t n_seq_max,
27
+ uint32_t n_pad,
28
+ uint32_t n_swa,
29
+ llama_swa_type swa_type) :
30
+ model(model), hparams(model.hparams), v_trans(v_trans),
31
+ n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
32
+
33
+ GGML_ASSERT(kv_size % n_pad == 0);
34
+
35
+ // create a context for each buffer type
36
+ std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
37
+ auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
38
+ auto it = ctx_map.find(buft);
39
+ if (it == ctx_map.end()) {
40
+ ggml_init_params params = {
41
+ /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
42
+ /*.mem_buffer =*/ NULL,
43
+ /*.no_alloc =*/ true,
44
+ };
45
+
46
+ ggml_context * ctx = ggml_init(params);
47
+ if (!ctx) {
48
+ return nullptr;
49
+ }
50
+
51
+ ctx_map[buft] = ctx;
52
+ ctxs.emplace_back(ctx);
53
+
54
+ return ctx;
55
+ }
56
+
57
+ return it->second;
58
+ };
59
+
60
+ head = 0;
61
+
62
+ cells.resize(kv_size);
63
+
64
+ for (uint32_t il = 0; il < hparams.n_layer; il++) {
65
+ if (filter && !filter(il)) {
66
+ LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
67
+ continue;
68
+ }
69
+
70
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
71
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
72
+
73
+ const char * dev_name = "CPU";
74
+
75
+ ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
76
+
77
+ if (offload) {
78
+ auto * dev = model.dev_layer(il);
79
+ buft = ggml_backend_dev_buffer_type(dev);
80
+
81
+ dev_name = ggml_backend_dev_name(dev);
82
+ }
83
+
84
+ LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);
85
+
86
+ ggml_context * ctx = ctx_for_buft(buft);
87
+ if (!ctx) {
88
+ throw std::runtime_error("failed to create ggml context for kv cache");
89
+ }
90
+
91
+ ggml_tensor * k;
92
+ ggml_tensor * v;
93
+
94
+ k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
95
+ v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
96
+
97
+ ggml_format_name(k, "cache_k_l%d", il);
98
+ ggml_format_name(v, "cache_v_l%d", il);
99
+
100
+ map_layer_ids[il] = layers.size();
101
+ layers.push_back({ il, k, v });
102
+ }
103
+
104
+ // allocate tensors and initialize the buffers to avoid NaNs in the padding
105
+ for (auto it : ctx_map) {
106
+ auto * buft = it.first;
107
+ auto * ctx = it.second;
108
+
109
+ ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
110
+ if (!buf) {
111
+ throw std::runtime_error("failed to allocate buffer for kv cache");
112
+ }
113
+
114
+ LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
115
+
116
+ ggml_backend_buffer_clear(buf, 0);
117
+ bufs.emplace_back(buf);
118
+ }
119
+
120
+ {
121
+ const size_t memory_size_k = size_k_bytes();
122
+ const size_t memory_size_v = size_v_bytes();
123
+
124
+ LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
125
+ (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
126
+ ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
127
+ ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
128
+ }
129
+ }
130
+
131
+ void llama_kv_cache_unified::clear() {
132
+ cells.reset();
133
+
134
+ head = 0;
135
+
136
+ for (auto & buf : bufs) {
137
+ ggml_backend_buffer_clear(buf.get(), 0);
138
+ }
139
+ }
140
+
141
+ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
142
+ uint32_t new_head = cells.size();
143
+
144
+ if (p0 < 0) {
145
+ p0 = 0;
146
+ }
147
+
148
+ if (p1 < 0) {
149
+ p1 = std::numeric_limits<llama_pos>::max();
150
+ }
151
+
152
+ for (uint32_t i = 0; i < cells.size(); ++i) {
153
+ if (!cells.pos_in(i, p0, p1)) {
154
+ continue;
155
+ }
156
+
157
+ if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
158
+ if (new_head == cells.size()) {
159
+ new_head = i;
160
+ }
161
+ }
162
+ }
163
+
164
+ // If we freed up a slot, set head to it so searching can start there.
165
+ if (new_head != cells.size() && new_head < head) {
166
+ head = new_head;
167
+ }
168
+
169
+ return true;
170
+ }
171
+
172
+ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
173
+ if (seq_id_src == seq_id_dst) {
174
+ return;
175
+ }
176
+
177
+ if (p0 < 0) {
178
+ p0 = 0;
179
+ }
180
+
181
+ if (p1 < 0) {
182
+ p1 = std::numeric_limits<llama_pos>::max();
183
+ }
184
+
185
+ for (uint32_t i = 0; i < cells.size(); ++i) {
186
+ if (!cells.pos_in(i, p0, p1)) {
187
+ continue;
188
+ }
189
+
190
+ if (cells.seq_has(i, seq_id_src)) {
191
+ cells.seq_add(i, seq_id_dst);
192
+ }
193
+ }
194
+ }
195
+
196
+ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
197
+ uint32_t new_head = cells.size();
198
+
199
+ for (uint32_t i = 0; i < cells.size(); ++i) {
200
+ if (cells.seq_keep(i, seq_id)) {
201
+ if (new_head == cells.size()) {
202
+ new_head = i;
203
+ }
204
+ }
205
+ }
206
+
207
+ // If we freed up a slot, set head to it so searching can start there.
208
+ if (new_head != cells.size() && new_head < head) {
209
+ head = new_head;
210
+ }
211
+ }
212
+
213
+ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
214
+ if (shift == 0) {
215
+ return;
216
+ }
217
+
218
+ uint32_t new_head = cells.size();
219
+
220
+ if (p0 < 0) {
221
+ p0 = 0;
222
+ }
223
+
224
+ if (p1 < 0) {
225
+ p1 = std::numeric_limits<llama_pos>::max();
226
+ }
227
+
228
+ // If there is no range then return early to avoid looping over all cells.
229
+ if (p0 == p1) {
230
+ return;
231
+ }
232
+
233
+ for (uint32_t i = 0; i < cells.size(); ++i) {
234
+ if (!cells.pos_in(i, p0, p1)) {
235
+ continue;
236
+ }
237
+
238
+ if (cells.seq_has(i, seq_id)) {
239
+ if (cells.pos_add(i, shift)) {
240
+ if (new_head == cells.size()) {
241
+ new_head = i;
242
+ }
243
+ }
244
+ }
245
+ }
246
+
247
+ // If we freed up a slot, set head to it so searching can start there.
248
+ // Otherwise we just start the next search from the beginning.
249
+ head = new_head != cells.size() ? new_head : 0;
250
+ }
251
+
252
+ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
253
+ if (d == 1) {
254
+ return;
255
+ }
256
+
257
+ if (p0 < 0) {
258
+ p0 = 0;
259
+ }
260
+
261
+ if (p1 < 0) {
262
+ p1 = std::numeric_limits<llama_pos>::max();
263
+ }
264
+
265
+ // If there is no range then return early to avoid looping over the cache.
266
+ if (p0 == p1) {
267
+ return;
268
+ }
269
+
270
+ for (uint32_t i = 0; i < cells.size(); ++i) {
271
+ if (!cells.pos_in(i, p0, p1)) {
272
+ continue;
273
+ }
274
+
275
+ if (cells.seq_has(i, seq_id)) {
276
+ cells.pos_div(i, d);
277
+ }
278
+ }
279
+ }
280
+
281
+ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
282
+ return cells.seq_pos_min(seq_id);
283
+ }
284
+
285
+ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
286
+ return cells.seq_pos_max(seq_id);
287
+ }
288
+
289
+ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
290
+ const llama_batch & batch,
291
+ uint32_t n_ubatch,
292
+ bool embd_pooled,
293
+ bool logits_all) {
294
+ GGML_UNUSED(embd_pooled);
295
+
296
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
297
+
298
+ std::vector<llama_ubatch> ubatches;
299
+ while (sbatch.n_tokens > 0) {
300
+ ubatches.push_back(sbatch.split_simple(n_ubatch));
301
+ }
302
+
303
+ auto heads = prepare(ubatches);
304
+ if (heads.empty()) {
305
+ return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
306
+ }
307
+
308
+ return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS,
309
+ this, std::move(sbatch), std::move(heads), std::move(ubatches));
310
+ }
311
+
312
+ llama_memory_state_ptr llama_kv_cache_unified::init_full() {
313
+ return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
314
+ }
315
+
316
+ std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
317
+ std::vector<uint32_t> res;
318
+
319
+ struct state {
320
+ uint32_t head_old; // old position of the head, before placing the ubatch
321
+ uint32_t head_new; // new position of the head, after placing the ubatch
322
+
323
+ llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
324
+ };
325
+
326
+ // remember the old state of the cells so we can restore it in the end
327
+ std::vector<state> states;
328
+
329
+ bool success = true;
330
+
331
+ for (const auto & ubatch : ubatches) {
332
+ // only find a suitable slot for the ubatch. don't modify the cells yet
333
+ const int32_t head_new = find_slot(ubatch);
334
+ if (head_new < 0) {
335
+ success = false;
336
+ break;
337
+ }
338
+
339
+ // remember the position that we found
340
+ res.push_back(head_new);
341
+
342
+ // store the old state of the cells in the recovery stack
343
+ states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
344
+
345
+ // now emplace the ubatch
346
+ apply_ubatch(head_new, ubatch);
347
+ }
348
+
349
+ // iterate backwards and restore the cells to their original state
350
+ for (auto it = states.rbegin(); it != states.rend(); ++it) {
351
+ cells.set(it->head_new, it->cells);
352
+ head = it->head_old;
353
+ }
354
+
355
+ if (!success) {
356
+ return {};
357
+ }
358
+
359
+ return res;
360
+ }
361
+
362
+ bool llama_kv_cache_unified::update(llama_context & lctx) {
363
+ bool updated = false;
364
+
365
+ auto * sched = lctx.get_sched();
366
+
367
+ if (cells.get_has_shift()) {
368
+ if (!get_can_shift()) {
369
+ GGML_ABORT("The current KV cache / model configuration does not support K-shift");
370
+ }
371
+
372
+ LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
373
+
374
+ // apply K-shift if needed
375
+ if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
376
+ ggml_backend_sched_reset(sched);
377
+
378
+ auto * gf = lctx.graph_init();
379
+
380
+ auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
381
+ if (!res) {
382
+ LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
383
+ return updated;
384
+ }
385
+
386
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
387
+ LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
388
+ return updated;
389
+ }
390
+
391
+ res->set_inputs(nullptr);
392
+
393
+ if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
394
+ LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
395
+ return updated;
396
+ }
397
+
398
+ updated = true;
399
+ }
400
+
401
+ cells.reset_shift();
402
+ }
403
+
404
+ if (do_defrag) {
405
+ LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
406
+
407
+ if (defrag_prepare(lctx.graph_max_nodes())) {
408
+ ggml_backend_sched_reset(sched);
409
+
410
+ auto * gf = lctx.graph_init();
411
+
412
+ auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
413
+ if (!res) {
414
+ LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
415
+ return updated;
416
+ }
417
+
418
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
419
+ LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
420
+ return updated;
421
+ }
422
+
423
+ res->set_inputs(nullptr);
424
+
425
+ if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
426
+ LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
427
+ return updated;
428
+ }
429
+
430
+ updated = true;
431
+ }
432
+
433
+ do_defrag = false;
434
+ }
435
+
436
+ return updated;
437
+ }
438
+
439
+ void llama_kv_cache_unified::defrag_sched(float thold) {
440
+ const auto n_kv = cells.used_max_p1();
441
+
442
+ // - do not defrag small contexts (i.e. < 2048 tokens)
443
+ // - count the padding towards the number of used tokens
444
+ const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
445
+
446
+ // queue defragmentation for next llama_kv_cache_update
447
+ if (fragmentation > thold) {
448
+ LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
449
+
450
+ do_defrag = true;
451
+ }
452
+ }
453
+
454
+ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
455
+ const uint32_t n_tokens = ubatch.n_tokens;
456
+
457
+ uint32_t head_cur = this->head;
458
+
459
+ // if we have enough unused cells before the current head ->
460
+ // better to start searching from the beginning of the cache, hoping to fill it
461
+ if (head_cur > cells.get_used() + 2*ubatch.n_tokens) {
462
+ head_cur = 0;
463
+ }
464
+
465
+ // otherwise, one cell per token.
466
+
467
+ if (n_tokens > cells.size()) {
468
+ LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
469
+ return -1;
470
+ }
471
+
472
+ //#define FIND_SLOT_DEBUG 1
473
+ #if FIND_SLOT_DEBUG
474
+ LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", cells.used_max_p1(), cells.get_used(), head, n_swa);
475
+
476
+ // for debugging
477
+ {
478
+ std::string ss;
479
+ if (n_swa > 0) {
480
+ for (uint32_t i = 0; i < cells.size(); ++i) {
481
+ if (cells.is_empty(i)) {
482
+ ss += '.';
483
+ } else {
484
+ ss += std::to_string(cells.seq_get(i));
485
+ }
486
+ if (i%256 == 255) {
487
+ ss += '\n';
488
+ }
489
+ }
490
+ }
491
+ LLAMA_LOG_WARN("\n%s\n", ss.c_str());
492
+ }
493
+
494
+ for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
495
+ if (cells.seq_pos_min(s) < 0) {
496
+ continue;
497
+ }
498
+
499
+ LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[%d] = %5d, max[%d] = %5d\n", n_swa, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
500
+ }
501
+ #endif
502
+
503
+ uint32_t n_tested = 0;
504
+
505
+ while (true) {
506
+ if (head_cur + n_tokens > cells.size()) {
507
+ n_tested += cells.size() - head_cur;
508
+ head_cur = 0;
509
+ continue;
510
+ }
511
+
512
+ // keep track of what the minimum sequence positions would be if we accept the ubatch
513
+ llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
514
+ for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
515
+ seq_pos_min[s] = cells.seq_pos_min(s);
516
+ }
517
+
518
+ bool found = true;
519
+ for (uint32_t i = 0; i < n_tokens; i++) {
520
+ const llama_pos pos = ubatch.pos[i];
521
+ const llama_seq_id seq_id = ubatch.seq_id[i][0];
522
+
523
+ // can we use this cell? either:
524
+ // - the cell is empty
525
+ // - the cell is occupied only by one sequence:
526
+ // - mask causally, if the sequence is the same as the one we are inserting
527
+ // - mask SWA, using current max pos for that sequence in the cache
528
+ // always insert in the cell with minimum pos
529
+ bool can_use = cells.is_empty(head_cur + i);
530
+
531
+ if (!can_use && cells.seq_count(head_cur + i) == 1) {
532
+ const llama_pos pos_cell = cells.pos_get(head_cur + i);
533
+
534
+ // causal mask
535
+ if (cells.seq_has(head_cur + i, seq_id)) {
536
+ can_use = pos_cell >= pos;
537
+ }
538
+
539
+ if (!can_use) {
540
+ const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
541
+
542
+ // SWA mask
543
+ // note: we insert only in the cell with minimum pos in order to preserve the invariant that
544
+ // all positions between [pos_min, pos_max] for each sequence will be present in the cache
545
+ // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
546
+ if (pos_cell == seq_pos_min[seq_id_cell] &&
547
+ is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
548
+ seq_pos_min[seq_id_cell]++;
549
+ can_use = true;
550
+ }
551
+ }
552
+ }
553
+
554
+ if (!can_use) {
555
+ found = false;
556
+ head_cur += i + 1;
557
+ n_tested += i + 1;
558
+ break;
559
+ }
560
+ }
561
+
562
+ if (found) {
563
+ break;
564
+ }
565
+
566
+ if (n_tested >= cells.size()) {
567
+ //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
568
+ return -1;
569
+ }
570
+ }
571
+
572
+ return head_cur;
573
+ }
574
+
575
+ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
576
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
577
+ if (!cells.is_empty(head_cur + i)) {
578
+ cells.rm(head_cur + i);
579
+ }
580
+
581
+ cells.pos_set(head_cur + i, ubatch.pos[i]);
582
+
583
+ for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
584
+ cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
585
+ }
586
+ }
587
+
588
+ // move the head at the end of the slot
589
+ head = head_cur + ubatch.n_tokens;
590
+ }
591
+
592
+ bool llama_kv_cache_unified::get_can_shift() const {
593
+ return true;
594
+ }
595
+
596
+ uint32_t llama_kv_cache_unified::get_size() const {
597
+ return cells.size();
598
+ }
599
+
600
+ uint32_t llama_kv_cache_unified::get_n_kv() const {
601
+ return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
602
+ }
603
+
604
+ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
605
+ const int32_t ikv = map_layer_ids.at(il);
606
+
607
+ auto * k = layers[ikv].k;
608
+
609
+ return ggml_view_3d(ctx, k,
610
+ hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
611
+ ggml_row_size(k->type, hparams.n_embd_head_k),
612
+ ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
613
+ 0);
614
+ }
615
+
616
+ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
617
+ const int32_t ikv = map_layer_ids.at(il);
618
+
619
+ auto * v = layers[ikv].v;
620
+
621
+ if (!v_trans) {
622
+ // note: v->nb[1] <= v->nb[2]
623
+ return ggml_view_3d(ctx, v,
624
+ hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
625
+ ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
626
+ ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
627
+ 0);
628
+ }
629
+
630
+ // note: v->nb[1] > v->nb[2]
631
+ return ggml_view_3d(ctx, v,
632
+ n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
633
+ ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
634
+ ggml_row_size(v->type, v->ne[1]), // v->nb[2]
635
+ 0);
636
+ }
637
+
638
+ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
639
+ const int32_t ikv = map_layer_ids.at(il);
640
+
641
+ auto * k = layers[ikv].k;
642
+
643
+ const int64_t n_tokens = k_cur->ne[2];
644
+
645
+ ggml_tensor * k_view = ggml_view_1d(ctx, k,
646
+ n_tokens*hparams.n_embd_k_gqa(il),
647
+ ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
648
+
649
+ return ggml_cpy(ctx, k_cur, k_view);
650
+ }
651
+
652
+ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
653
+ const int32_t ikv = map_layer_ids.at(il);
654
+
655
+ auto * v = layers[ikv].v;
656
+
657
+ const int64_t n_tokens = v_cur->ne[2];
658
+
659
+ v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
660
+
661
+ ggml_tensor * v_view = nullptr;
662
+
663
+ if (!v_trans) {
664
+ v_view = ggml_view_1d(ctx, v,
665
+ n_tokens*hparams.n_embd_v_gqa(il),
666
+ ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
667
+ } else {
668
+ // note: the V cache is transposed when not using flash attention
669
+ v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
670
+ (v->ne[1])*ggml_element_size(v),
671
+ (head_cur)*ggml_element_size(v));
672
+
673
+ v_cur = ggml_transpose(ctx, v_cur);
674
+ }
675
+
676
+ return ggml_cpy(ctx, v_cur, v_view);
677
+ }
678
+
679
+ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
680
+ const int64_t n_tokens = ubatch->n_tokens;
681
+ const int64_t n_seq_tokens = ubatch->n_seq_tokens;
682
+ const int64_t n_seqs = ubatch->n_seqs;
683
+
684
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
685
+ float * data = (float *) dst->data;
686
+
687
+ const auto n_kv = dst->ne[0];
688
+
689
+ // Use only the previous KV cells of the correct sequence for each token of the ubatch.
690
+ // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
691
+ // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
692
+ // Causal mask:
693
+ // xxx-------
694
+ // xxxx------
695
+ // xxxxx-----
696
+ // Non-causal mask:
697
+ // xxxxx-----
698
+ // xxxxx-----
699
+ // xxxxx-----
700
+ // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
701
+ for (int h = 0; h < 1; ++h) {
702
+ for (int s = 0; s < n_seqs; ++s) {
703
+ const llama_seq_id seq_id = ubatch->seq_id[s][0];
704
+
705
+ for (int j = 0; j < n_seq_tokens; ++j) {
706
+ const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
707
+
708
+ for (uint32_t i = 0; i < n_kv; ++i) {
709
+ float f = 0.0f;
710
+
711
+ bool masked = false;
712
+
713
+ if (cells.is_empty(i)) {
714
+ masked = true;
715
+ } else {
716
+ const llama_pos p0 = cells.pos_get(i);
717
+
718
+ // mask the token if not the same sequence
719
+ masked = masked || (!cells.seq_has(i, seq_id));
720
+
721
+ // mask future tokens
722
+ masked = masked || (causal_attn && p0 > p1);
723
+
724
+ // apply SWA if any
725
+ masked = masked || (is_masked_swa(p0, p1));
726
+
727
+ if (!masked && hparams.use_alibi) {
728
+ f = -std::abs(p0 - p1);
729
+ }
730
+ }
731
+
732
+ if (masked) {
733
+ f = -INFINITY;
734
+ }
735
+
736
+ data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
737
+ }
738
+ }
739
+ }
740
+
741
+ // mask padded tokens
742
+ if (data) {
743
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
744
+ for (uint32_t j = 0; j < n_kv; ++j) {
745
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
746
+ }
747
+ }
748
+ }
749
+ }
750
+ }
751
+
752
+ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
753
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
754
+
755
+ int32_t * data = (int32_t *) dst->data;
756
+
757
+ for (uint32_t i = 0; i < cells.size(); ++i) {
758
+ data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
759
+ }
760
+ }
761
+
762
+ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
763
+ const int64_t n_tokens = ubatch->n_tokens;
764
+
765
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
766
+ GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
767
+
768
+ int32_t * data = (int32_t *) dst->data;
769
+
770
+ const int32_t n_kv = dst->ne[0];
771
+
772
+ for (int h = 0; h < 1; ++h) {
773
+ for (int j = 0; j < n_tokens; ++j) {
774
+ for (int i = 0; i < n_kv; ++i) {
775
+ // the position when the cell is empty is irrelevant - it will be masked out later in the attention
776
+ const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
777
+
778
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
779
+ }
780
+ }
781
+ }
782
+ }
783
+
784
+ size_t llama_kv_cache_unified::total_size() const {
785
+ size_t size = 0;
786
+
787
+ for (const auto & buf : bufs) {
788
+ size += ggml_backend_buffer_get_size(buf.get());
789
+ }
790
+
791
+ return size;
792
+ }
793
+
794
+ size_t llama_kv_cache_unified::size_k_bytes() const {
795
+ size_t size_k_bytes = 0;
796
+
797
+ for (const auto & layer : layers) {
798
+ size_k_bytes += ggml_nbytes(layer.k);
799
+ }
800
+
801
+ return size_k_bytes;
802
+ }
803
+
804
+ size_t llama_kv_cache_unified::size_v_bytes() const {
805
+ size_t size_v_bytes = 0;
806
+
807
+ for (const auto & layer : layers) {
808
+ size_v_bytes += ggml_nbytes(layer.v);
809
+ }
810
+
811
+ return size_v_bytes;
812
+ }
813
+
814
+ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
815
+ const llama_cparams & cparams,
816
+ ggml_context * ctx,
817
+ ggml_tensor * cur,
818
+ ggml_tensor * shift,
819
+ ggml_tensor * factors,
820
+ float freq_base,
821
+ float freq_scale) const {
822
+ const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
823
+
824
+ const auto & yarn_ext_factor = cparams.yarn_ext_factor;
825
+ const auto & yarn_beta_fast = cparams.yarn_beta_fast;
826
+ const auto & yarn_beta_slow = cparams.yarn_beta_slow;
827
+
828
+ const auto & n_rot = hparams.n_rot;
829
+ const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
830
+ // @ngxson : this is a workaround
831
+ // for M-RoPE, we want to rotate the whole vector when doing KV shift
832
+ // a normal RoPE should work, we just need to use the correct ordering
833
+ // ref: https://github.com/ggml-org/llama.cpp/pull/13870
834
+ ? LLAMA_ROPE_TYPE_NEOX
835
+ : hparams.rope_type;
836
+
837
+ // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
838
+ // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
839
+ const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
840
+ ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
841
+ : cparams.yarn_attn_factor;
842
+
843
+ ggml_tensor * tmp;
844
+
845
+ if (ggml_is_quantized(cur->type)) {
846
+ // dequantize to f32 -> RoPE -> quantize back
847
+ tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
848
+
849
+ tmp = ggml_rope_ext(ctx, tmp,
850
+ shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
851
+ yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
852
+
853
+ tmp = ggml_cpy(ctx, tmp, cur);
854
+ } else {
855
+ // we rotate only the first n_rot dimensions
856
+ tmp = ggml_rope_ext_inplace(ctx, cur,
857
+ shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
858
+ yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
859
+ }
860
+
861
+ return tmp;
862
+ }
863
+
864
+ class llm_graph_input_k_shift : public llm_graph_input_i {
865
+ public:
866
+ llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
867
+ virtual ~llm_graph_input_k_shift() = default;
868
+
869
+ void set_input(const llama_ubatch * ubatch) override;
870
+
871
+ ggml_tensor * k_shift; // I32 [kv_size]
872
+
873
+ const llama_kv_cache_unified * kv_self;
874
+ };
875
+
876
+ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
877
+ GGML_UNUSED(ubatch);
878
+
879
+ if (k_shift) {
880
+ kv_self->set_input_k_shift(k_shift);
881
+ }
882
+ }
883
+
884
+ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
885
+ const llama_cparams & cparams,
886
+ ggml_context * ctx,
887
+ ggml_cgraph * gf) const {
888
+ auto res = std::make_unique<llm_graph_result>();
889
+
890
+ const auto & n_embd_head_k = hparams.n_embd_head_k;
891
+ //const auto & n_embd_head_v = hparams.n_embd_head_v;
892
+
893
+ //GGML_ASSERT(kv_self->size == n_ctx);
894
+
895
+ auto inp = std::make_unique<llm_graph_input_k_shift>(this);
896
+
897
+ inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
898
+ ggml_set_input(inp->k_shift);
899
+
900
+ for (const auto & layer : layers) {
901
+ const uint32_t il = layer.il;
902
+
903
+ const int64_t n_head_kv = hparams.n_head_kv(il);
904
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
905
+
906
+ const float freq_base_l = model.get_rope_freq_base (cparams, il);
907
+ const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
908
+
909
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
910
+
911
+ ggml_tensor * k =
912
+ ggml_view_3d(ctx, layer.k,
913
+ n_embd_head_k, n_head_kv, cells.size(),
914
+ ggml_row_size(layer.k->type, n_embd_head_k),
915
+ ggml_row_size(layer.k->type, n_embd_k_gqa),
916
+ 0);
917
+
918
+ ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
919
+
920
+ ggml_build_forward_expand(gf, cur);
921
+ }
922
+
923
+ res->add_input(std::move(inp));
924
+
925
+ return res;
926
+ }
927
+
928
+ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
929
+ const llama_cparams & cparams,
930
+ ggml_context * ctx,
931
+ ggml_cgraph * gf) const {
932
+ auto res = std::make_unique<llm_graph_result>();
933
+
934
+ const auto & ids = defrag_info.ids;
935
+
936
+ #if 0
937
+ // CPU defrag
938
+ //
939
+ // TODO: optimizations are possible:
940
+ // - multiple threads
941
+ // - avoid copying to the host memory when already there
942
+ //
943
+ // likely not worth the effort, as we have ggml_graph based defrag
944
+ //
945
+
946
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
947
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
948
+
949
+ const uint32_t kv_size = size;
950
+
951
+ std::vector<uint8_t> buf_k;
952
+ std::vector<uint8_t> buf_v;
953
+
954
+ for (uint32_t il = 0; il < n_layer; ++il) {
955
+ const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
956
+ const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
957
+
958
+ const size_t v_size_el = ggml_type_size(v_l[il]->type);
959
+ const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
960
+
961
+ buf_k.resize(k_size);
962
+ buf_v.resize(v_size);
963
+
964
+ ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
965
+ ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
966
+
967
+ // batch move [i, i+nm) to [id, id+nm)
968
+ // note: cells can move only to a lower index
969
+ for (uint32_t i = 0; i < n_kv; ++i) {
970
+ const uint32_t id = ids[i];
971
+
972
+ if (i == id || id == n_kv) {
973
+ continue;
974
+ }
975
+
976
+ uint32_t nm = 1;
977
+
978
+ while (i + nm < n_kv && ids[i + nm] == id + nm) {
979
+ nm++;
980
+ }
981
+
982
+ // move keys
983
+ {
984
+ const int64_t os = i*k_size_row;
985
+ const int64_t od = id*k_size_row;
986
+
987
+ memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
988
+ }
989
+
990
+ // move values (note: they are transposed)
991
+ {
992
+ const int64_t os = i;
993
+ const int64_t od = id;
994
+
995
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
996
+ memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
997
+ }
998
+ }
999
+
1000
+ i += nm - 1;
1001
+ }
1002
+
1003
+ ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
1004
+ ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
1005
+ }
1006
+ #else
1007
+ for (uint32_t i = 0; i < ids.size(); ++i) {
1008
+ const uint32_t id = ids[i];
1009
+
1010
+ if (i == id || id == ids.size()) {
1011
+ continue;
1012
+ }
1013
+
1014
+ uint32_t nm = 1;
1015
+
1016
+ while (i + nm < ids.size() && ids[i + nm] == id + nm) {
1017
+ nm++;
1018
+ }
1019
+
1020
+ for (const auto & layer : layers) {
1021
+ const uint32_t il = layer.il;
1022
+
1023
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1024
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1025
+
1026
+ ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k,
1027
+ n_embd_k_gqa, nm,
1028
+ ggml_row_size(layer.k->type, n_embd_k_gqa),
1029
+ ggml_row_size(layer.k->type, n_embd_k_gqa*i));
1030
+
1031
+ ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k,
1032
+ n_embd_k_gqa, nm,
1033
+ ggml_row_size(layer.k->type, n_embd_k_gqa),
1034
+ ggml_row_size(layer.k->type, n_embd_k_gqa*id));
1035
+
1036
+ ggml_tensor * view_v_src;
1037
+ ggml_tensor * view_v_dst;
1038
+
1039
+ if (cparams.flash_attn) {
1040
+ // NOTE: the V cache is not transposed when using flash attention
1041
+ view_v_src = ggml_view_2d(ctx, layer.v,
1042
+ n_embd_v_gqa, nm,
1043
+ ggml_row_size(layer.v->type, n_embd_v_gqa),
1044
+ ggml_row_size(layer.v->type, n_embd_v_gqa*i));
1045
+
1046
+ view_v_dst = ggml_view_2d(ctx, layer.v,
1047
+ n_embd_v_gqa, nm,
1048
+ ggml_row_size(layer.v->type, n_embd_v_gqa),
1049
+ ggml_row_size(layer.v->type, n_embd_v_gqa*id));
1050
+ } else {
1051
+ view_v_src = ggml_view_2d(ctx, layer.v,
1052
+ nm, n_embd_v_gqa,
1053
+ ggml_row_size(layer.v->type, cells.size()),
1054
+ ggml_row_size(layer.v->type, i));
1055
+
1056
+ view_v_dst = ggml_view_2d(ctx, layer.v,
1057
+ nm, n_embd_v_gqa,
1058
+ ggml_row_size(layer.v->type, cells.size()),
1059
+ ggml_row_size(layer.v->type, id));
1060
+ }
1061
+
1062
+ ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
1063
+ ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
1064
+ }
1065
+
1066
+ i += nm - 1;
1067
+ }
1068
+
1069
+ //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
1070
+ #endif
1071
+
1072
+ return res;
1073
+ }
1074
+
1075
+ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
1076
+ const uint32_t n_layer = layers.size();
1077
+
1078
+ const uint32_t n_kv = cells.used_max_p1();
1079
+ const uint32_t n_used = cells.get_used();
1080
+
1081
+ assert(n_used <= n_kv);
1082
+
1083
+ //const int64_t t_start = ggml_time_us();
1084
+
1085
+ // number of cells moved
1086
+ uint32_t n_moves = 0;
1087
+
1088
+ // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
1089
+ // - source view, destination view, copy operation
1090
+ // - x2 for keys and values
1091
+ //const uint32_t max_moves = max_nodes()/(6*n_layer);
1092
+ // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
1093
+ const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
1094
+
1095
+ // determine which KV cells to move where
1096
+ //
1097
+ // cell i moves to ids[i]
1098
+ //
1099
+ // if ids[i] == i || ids[i] == n_kv, then cell i is not moved
1100
+ //
1101
+ auto & ids = defrag_info.ids;
1102
+
1103
+ ids.clear();
1104
+ ids.resize(n_kv, n_kv);
1105
+
1106
+ for (uint32_t i0 = 0; i0 < n_used; ++i0) {
1107
+ if (!cells.is_empty(i0)) {
1108
+ ids[i0] = i0;
1109
+
1110
+ continue;
1111
+ }
1112
+
1113
+ // found a hole - fill it with data from the end of the cache
1114
+
1115
+ uint32_t nh = 1;
1116
+
1117
+ // determine the size of the hole
1118
+ while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
1119
+ nh++;
1120
+ }
1121
+
1122
+ uint32_t nf = 0;
1123
+ uint32_t is = n_kv - 1;
1124
+
1125
+ // starting from the end, find nh non-empty cells
1126
+ for (; is > i0; --is) {
1127
+ if (cells.is_empty(is) || ids[is] != n_kv) {
1128
+ continue;
1129
+ }
1130
+
1131
+ // non-empty cell which is not yet moved
1132
+ nf++;
1133
+
1134
+ if (nf == nh) {
1135
+ break;
1136
+ }
1137
+ }
1138
+
1139
+ // this can only happen if `n_used` is not accurate, which would be a bug
1140
+ GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
1141
+
1142
+ nf = 0;
1143
+
1144
+ uint32_t i1 = is;
1145
+
1146
+ // are we moving a continuous block of memory?
1147
+ bool cont = false;
1148
+
1149
+ // should we stop searching for the next move?
1150
+ bool stop = false;
1151
+
1152
+ // go back and move the nf cells to the hole
1153
+ for (; i1 < n_kv; ++i1) {
1154
+ if (cells.is_empty(i1) || ids[i1] != n_kv) {
1155
+ if (n_moves == max_moves) {
1156
+ stop = true;
1157
+ break;
1158
+ }
1159
+
1160
+ cont = false;
1161
+ continue;
1162
+ }
1163
+
1164
+ // this cell goes to (i0 + nf)
1165
+ ids[i1] = i0 + nf;
1166
+
1167
+ // move the cell meta data
1168
+ cells.mv(i1, i0 + nf);
1169
+
1170
+ head = n_used;
1171
+
1172
+ if (!cont) {
1173
+ n_moves++;
1174
+ cont = true;
1175
+ }
1176
+
1177
+ nf++;
1178
+
1179
+ if (nf == nh) {
1180
+ break;
1181
+ }
1182
+ }
1183
+
1184
+ if (stop || n_moves == max_moves) {
1185
+ break;
1186
+ }
1187
+
1188
+ //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
1189
+
1190
+ i0 += nh - 1;
1191
+ }
1192
+
1193
+ if (n_moves == 0) {
1194
+ return false;
1195
+ }
1196
+
1197
+ LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
1198
+
1199
+ LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
1200
+
1201
+ return true;
1202
+ }
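
The ids mapping built by defrag_prepare() can be worked through on a small example. The sketch below is standalone, with made-up occupancy, and drops the max_moves cap and the cell-metadata bookkeeping; it reproduces only the core loop. With six cells where indices 2 and 3 are empty, the tail cells 4 and 5 are pulled into the hole.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // made-up occupancy: cells 0,1,4,5 are used, cells 2,3 are empty
    std::vector<bool> used = { true, true, false, false, true, true };

    const uint32_t n_kv = used.size();              // plays the role of used_max_p1()
    uint32_t n_used = 0;
    for (const bool u : used) n_used += u ? 1 : 0;  // plays the role of get_used() == 4

    std::vector<uint32_t> ids(n_kv, n_kv);          // n_kv == "not a move source"

    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
        if (used[i0]) {
            ids[i0] = i0;                           // already in place
            continue;
        }

        // found a hole - measure its size
        uint32_t nh = 1;
        while (i0 + nh < n_used && !used[i0 + nh]) {
            nh++;
        }

        // starting from the end, find nh used cells that are not yet moved
        uint32_t nf = 0;
        uint32_t is = n_kv - 1;
        for (; is > i0; --is) {
            if (!used[is] || ids[is] != n_kv) {
                continue;
            }
            if (++nf == nh) {
                break;
            }
        }
        assert(nf == nh);

        // map those cells into the hole, left to right
        nf = 0;
        for (uint32_t i1 = is; i1 < n_kv; ++i1) {
            if (!used[i1] || ids[i1] != n_kv) {
                continue;
            }
            ids[i1]       = i0 + nf;                // cell i1 moves to i0 + nf
            used[i0 + nf] = true;                   // destination becomes occupied
            used[i1]      = false;                  // source becomes free
            if (++nf == nh) {
                break;
            }
        }

        i0 += nh - 1;                               // skip past the filled hole
    }

    for (uint32_t i = 0; i < n_kv; ++i) {
        std::printf("ids[%u] = %u\n", i, ids[i]);   // 0, 1, 6, 6, 2, 3
    }
    return 0;
}

Running the sketch prints ids = {0, 1, 6, 6, 2, 3}: cells 0 and 1 stay put, the empty cells keep the sentinel value, and the tail cells 4 and 5 are scheduled to move into positions 2 and 3 - one contiguous block move, which build_graph_defrag() above turns into two ggml_cpy operations (K and V) per layer.
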
1203
+
1204
+ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
1205
+ assert(p0 >= 0 && p1 >= 0);
1206
+
1207
+ switch (swa_type) {
1208
+ case LLAMA_SWA_TYPE_NONE:
1209
+ {
1210
+ } break;
1211
+ case LLAMA_SWA_TYPE_STANDARD:
1212
+ {
1213
+ if (p1 - p0 >= (int32_t) n_swa) {
1214
+ return true;
1215
+ }
1216
+ } break;
1217
+ case LLAMA_SWA_TYPE_CHUNKED:
1218
+ {
1219
+ const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
1220
+
1221
+ if (p0 < pos_chunk_start) {
1222
+ return true;
1223
+ }
1224
+ } break;
1225
+ }
1226
+
1227
+ return false;
1228
+ }
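
The two masking rules in is_masked_swa() can be checked with concrete numbers. The standalone sketch below (not part of this file; n_swa and the positions are made up) mirrors the STANDARD and CHUNKED branches: with n_swa = 4 and p1 = 10, the standard window keeps positions 7..10 visible, while the chunked window keeps only 8..10, since the start of p1's chunk is 8.

#include <cstdint>
#include <cstdio>

// p0 = position of the cached token, p1 = position of the token attending to it

static bool masked_standard(int32_t p0, int32_t p1, uint32_t n_swa) {
    // standard sliding window: only the last n_swa positions before p1 are visible
    return p1 - p0 >= (int32_t) n_swa;
}

static bool masked_chunked(int32_t p0, int32_t p1, uint32_t n_swa) {
    // chunked window: only positions inside p1's own chunk are visible
    const int32_t pos_chunk_start = (p1 / (int32_t) n_swa) * (int32_t) n_swa;
    return p0 < pos_chunk_start;
}

int main() {
    const uint32_t n_swa = 4;
    const int32_t  p1    = 10;

    for (int32_t p0 = 0; p0 <= p1; ++p0) {
        std::printf("p0 = %2d: standard %s, chunked %s\n", p0,
                    masked_standard(p0, p1, n_swa) ? "masked " : "visible",
                    masked_chunked (p0, p1, n_swa) ? "masked " : "visible");
    }
    // with n_swa = 4 and p1 = 10:
    //   standard keeps p0 in [7, 10]  (p1 - p0 < 4)
    //   chunked  keeps p0 in [8, 10]  (chunk start = 8)
    return 0;
}

The same predicate is used both when masking attention in set_input_kq_mask() and when find_slot() decides that an old cell has slid out of the window and can be reused.
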
1229
+
1230
+ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
1231
+ std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
1232
+ uint32_t cell_count = 0;
1233
+
1234
+ // Count the number of cells with the specified seq_id
1235
+ // Find all the ranges of cells with this seq id (or all, when -1)
1236
+ uint32_t cell_range_begin = cells.size();
1237
+
1238
+ for (uint32_t i = 0; i < cells.size(); ++i) {
1239
+ if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
1240
+ ++cell_count;
1241
+ if (cell_range_begin == cells.size()) {
1242
+ cell_range_begin = i;
1243
+ }
1244
+ } else {
1245
+ if (cell_range_begin != cells.size()) {
1246
+ cell_ranges.emplace_back(cell_range_begin, i);
1247
+ cell_range_begin = cells.size();
1248
+ }
1249
+ }
1250
+ }
1251
+
1252
+ if (cell_range_begin != cells.size()) {
1253
+ cell_ranges.emplace_back(cell_range_begin, cells.size());
1254
+ }
1255
+
1256
+ // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
1257
+ uint32_t cell_count_check = 0;
1258
+ for (const auto & range : cell_ranges) {
1259
+ cell_count_check += range.second - range.first;
1260
+ }
1261
+ GGML_ASSERT(cell_count == cell_count_check);
1262
+
1263
+ io.write(&cell_count, sizeof(cell_count));
1264
+
1265
+ state_write_meta(io, cell_ranges, seq_id);
1266
+ state_write_data(io, cell_ranges);
1267
+ }
1268
+
1269
+ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
1270
+ uint32_t cell_count;
1271
+ io.read_to(&cell_count, sizeof(cell_count));
1272
+
1273
+ bool res = true;
1274
+ res = res && state_read_meta(io, cell_count, seq_id);
1275
+ res = res && state_read_data(io, cell_count);
1276
+
1277
+ if (!res) {
1278
+ if (seq_id == -1) {
1279
+ clear();
1280
+ } else {
1281
+ seq_rm(seq_id, -1, -1);
1282
+ }
1283
+ throw std::runtime_error("failed to restore kv cache");
1284
+ }
1285
+ }
1286
+
1287
+ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
1288
+ for (const auto & range : cell_ranges) {
1289
+ for (uint32_t i = range.first; i < range.second; ++i) {
1290
+ std::vector<llama_seq_id> seq_ids;
1291
+
1292
+ for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
1293
+ if (cur == seq_id || seq_id == -1) {
1294
+ if (cells.seq_has(i, cur)) {
1295
+ seq_ids.push_back(cur);
1296
+ }
1297
+ }
1298
+ }
1299
+
1300
+ const llama_pos pos = cells.pos_get(i);
1301
+ const uint32_t n_seq_id = seq_ids.size();
1302
+
1303
+ io.write(&pos, sizeof(pos));
1304
+ io.write(&n_seq_id, sizeof(n_seq_id));
1305
+
1306
+ for (const auto & seq_id : seq_ids) {
1307
+ io.write(&seq_id, sizeof(seq_id));
1308
+ }
1309
+ }
1310
+ }
1311
+ }
1312
+
1313
+ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
1314
+ const uint32_t v_trans = this->v_trans ? 1 : 0;
1315
+ const uint32_t n_layer = layers.size();
1316
+
1317
+ io.write(&v_trans, sizeof(v_trans));
1318
+ io.write(&n_layer, sizeof(n_layer));
1319
+
1320
+ std::vector<uint8_t> tmp_buf;
1321
+
1322
+ // Iterate and write all the keys first, each row is a cell
1323
+ // Get whole range at a time
1324
+ for (const auto & layer : layers) {
1325
+ const uint32_t il = layer.il;
1326
+
1327
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1328
+
1329
+ // Write key type
1330
+ const int32_t k_type_i = (int32_t)layer.k->type;
1331
+ io.write(&k_type_i, sizeof(k_type_i));
1332
+
1333
+ // Write row size of key
1334
+ const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
1335
+ io.write(&k_size_row, sizeof(k_size_row));
1336
+
1337
+ // Read each range of cells of k_size length each into tmp_buf and write out
1338
+ for (const auto & range : cell_ranges) {
1339
+ const size_t range_size = range.second - range.first;
1340
+ const size_t buf_size = range_size * k_size_row;
1341
+ io.write_tensor(layer.k, range.first * k_size_row, buf_size);
1342
+ }
1343
+ }
1344
+
1345
+ if (!v_trans) {
1346
+ for (const auto & layer : layers) {
1347
+ const uint32_t il = layer.il;
1348
+
1349
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1350
+
1351
+ // Write value type
1352
+ const int32_t v_type_i = (int32_t)layer.v->type;
1353
+ io.write(&v_type_i, sizeof(v_type_i));
1354
+
1355
+ // Write row size of value
1356
+ const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
1357
+ io.write(&v_size_row, sizeof(v_size_row));
1358
+
1359
+ // Read each range of cells of v_size length each into tmp_buf and write out
1360
+ for (const auto & range : cell_ranges) {
1361
+ const size_t range_size = range.second - range.first;
1362
+ const size_t buf_size = range_size * v_size_row;
1363
+ io.write_tensor(layer.v, range.first * v_size_row, buf_size);
1364
+ }
1365
+ }
1366
+ } else {
1367
+ // When v is transposed, we also need the element size and get the element ranges from each row
1368
+ const uint32_t kv_size = cells.size();
1369
+
1370
+ for (const auto & layer : layers) {
1371
+ const uint32_t il = layer.il;
1372
+
1373
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1374
+
1375
+ // Write value type
1376
+ const int32_t v_type_i = (int32_t)layer.v->type;
1377
+ io.write(&v_type_i, sizeof(v_type_i));
1378
+
1379
+ // Write element size
1380
+ const uint32_t v_size_el = ggml_type_size(layer.v->type);
1381
+ io.write(&v_size_el, sizeof(v_size_el));
1382
+
1383
+ // Write GQA embedding size
1384
+ io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
1385
+
1386
+ // For each row, we get the element values of each cell
1387
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1388
+ // Write out each range of cells, v_size_el bytes per cell, directly via io.write_tensor()
1389
+ for (const auto & range : cell_ranges) {
1390
+ const size_t range_size = range.second - range.first;
1391
+ const size_t src_offset = (range.first + j * kv_size) * v_size_el;
1392
+ const size_t buf_size = range_size * v_size_el;
1393
+ io.write_tensor(layer.v, src_offset, buf_size);
1394
+ }
1395
+ }
1396
+ }
1397
+ }
1398
+ }
1399
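Taken together, state_write_meta() and state_write_data() determine the byte layout of a saved cache snapshot. The summary below is an editorial sketch derived only from the two functions above, reusing their variable names; the code itself remains the authoritative definition.

    // per saved cell (state_write_meta):
    //   llama_pos     pos
    //   uint32_t      n_seq_id
    //   llama_seq_id  seq_id[n_seq_id]
    //
    // header (state_write_data):
    //   uint32_t      v_trans              // 1 if the V tensor is stored transposed
    //   uint32_t      n_layer
    //
    // per layer: keys, one contiguous row per cell
    //   int32_t       k_type_i
    //   uint64_t      k_size_row
    //   uint8_t       k_rows[cell_count * k_size_row]
    //
    // per layer, when v_trans == 0: values, same row-per-cell layout as the keys
    //   int32_t       v_type_i
    //   uint64_t      v_size_row
    //   uint8_t       v_rows[cell_count * v_size_row]
    //
    // per layer, when v_trans == 1: values, element ranges taken from each transposed row
    //   int32_t       v_type_i
    //   uint32_t      v_size_el
    //   uint32_t      n_embd_v_gqa
    //   uint8_t       v_cols[n_embd_v_gqa][cell_count * v_size_el]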
+
1400
+ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
1401
+ if (dest_seq_id != -1) {
1402
+ // single sequence
1403
+
1404
+ seq_rm(dest_seq_id, -1, -1);
1405
+
1406
+ llama_sbatch sbatch;
1407
+ llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1408
+
1409
+ batch.n_tokens = cell_count;
1410
+
1411
+ for (uint32_t i = 0; i < cell_count; ++i) {
1412
+ llama_pos pos;
1413
+ uint32_t n_seq_id;
1414
+
1415
+ io.read_to(&pos, sizeof(pos));
1416
+ io.read_to(&n_seq_id, sizeof(n_seq_id));
1417
+
1418
+ if (n_seq_id != 1) {
1419
+ LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
1420
+ return false;
1421
+ }
1422
+
1423
+ // read the sequence id, but directly discard it - we will use dest_seq_id instead
1424
+ {
1425
+ llama_seq_id seq_id;
1426
+ io.read_to(&seq_id, sizeof(seq_id));
1427
+ }
1428
+
1429
+ batch.pos[i] = pos;
1430
+ batch.n_seq_id[i] = n_seq_id;
1431
+ batch.seq_id[i] = &dest_seq_id;
1432
+ }
1433
+
1434
+ const auto head_cur = find_slot(batch);
1435
+ if (head_cur < 0) {
1436
+ LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1437
+ return false;
1438
+ }
1439
+
1440
+ apply_ubatch(head_cur, batch);
1441
+
1442
+ // keep the head at the old position because we will read the KV data into it in state_read_data()
1443
+ head = head_cur;
1444
+
1445
+ // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
1446
+ // Assume that this is one contiguous block of cells
1447
+ GGML_ASSERT(head_cur + cell_count <= cells.size());
1448
+ GGML_ASSERT(cells.pos_get(head_cur) == batch.pos[0]);
1449
+ GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]);
1450
+ GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
1451
+ GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
1452
+ } else {
1453
+ // whole KV cache restore
1454
+
1455
+ if (cell_count > cells.size()) {
1456
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
1457
+ return false;
1458
+ }
1459
+
1460
+ clear();
1461
+
1462
+ for (uint32_t i = 0; i < cell_count; ++i) {
1463
+ llama_pos pos;
1464
+ uint32_t n_seq_id;
1465
+
1466
+ io.read_to(&pos, sizeof(pos));
1467
+ io.read_to(&n_seq_id, sizeof(n_seq_id));
1468
+
1469
+ cells.pos_set(i, pos);
1470
+
1471
+ for (uint32_t j = 0; j < n_seq_id; ++j) {
1472
+ llama_seq_id seq_id;
1473
+ io.read_to(&seq_id, sizeof(seq_id));
1474
+
1475
+ if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
1476
+ LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
1477
+ return false;
1478
+ }
1479
+
1480
+ cells.seq_add(i, seq_id);
1481
+ }
1482
+ }
1483
+
1484
+ head = 0;
1485
+ }
1486
+
1487
+ return true;
1488
+ }
1489
+
1490
+ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
1491
+ uint32_t v_trans;
1492
+ uint32_t n_layer;
1493
+
1494
+ io.read_to(&v_trans, sizeof(v_trans));
1495
+ io.read_to(&n_layer, sizeof(n_layer));
1496
+
1497
+ if (n_layer != layers.size()) {
1498
+ LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
1499
+ return false;
1500
+ }
1501
+
1502
+ if (cell_count > cells.size()) {
1503
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
1504
+ return false;
1505
+ }
1506
+
1507
+ if (this->v_trans != (bool) v_trans) {
1508
+ LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
1509
+ return false;
1510
+ }
1511
+
1512
+ // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
1513
+ for (const auto & layer : layers) {
1514
+ const uint32_t il = layer.il;
1515
+
1516
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1517
+
1518
+ // Read type of key
1519
+ int32_t k_type_i_ref;
1520
+ io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
1521
+ const int32_t k_type_i = (int32_t) layer.k->type;
1522
+ if (k_type_i != k_type_i_ref) {
1523
+ LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
1524
+ return false;
1525
+ }
1526
+
1527
+ // Read row size of key
1528
+ uint64_t k_size_row_ref;
1529
+ io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
1530
+ const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
1531
+ if (k_size_row != k_size_row_ref) {
1532
+ LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
1533
+ return false;
1534
+ }
1535
+
1536
+ if (cell_count) {
1537
+ // Read and set the keys for the whole cell range
1538
+ ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
1539
+ }
1540
+ }
1541
+
1542
+ if (!this->v_trans) {
1543
+ for (const auto & layer : layers) {
1544
+ const uint32_t il = layer.il;
1545
+
1546
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1547
+
1548
+ // Read type of value
1549
+ int32_t v_type_i_ref;
1550
+ io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1551
+ const int32_t v_type_i = (int32_t)layer.v->type;
1552
+ if (v_type_i != v_type_i_ref) {
1553
+ LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1554
+ return false;
1555
+ }
1556
+
1557
+ // Read row size of value
1558
+ uint64_t v_size_row_ref;
1559
+ io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
1560
+ const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
1561
+ if (v_size_row != v_size_row_ref) {
1562
+ LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
1563
+ return false;
1564
+ }
1565
+
1566
+ if (cell_count) {
1567
+ // Read and set the values for the whole cell range
1568
+ ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
1569
+ }
1570
+ }
1571
+ } else {
1572
+ // For each layer, read the values for each cell (transposed)
1573
+ for (const auto & layer : layers) {
1574
+ const uint32_t il = layer.il;
1575
+
1576
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1577
+
1578
+ // Read type of value
1579
+ int32_t v_type_i_ref;
1580
+ io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1581
+ const int32_t v_type_i = (int32_t)layer.v->type;
1582
+ if (v_type_i != v_type_i_ref) {
1583
+ LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1584
+ return false;
1585
+ }
1586
+
1587
+ // Read element size of value
1588
+ uint32_t v_size_el_ref;
1589
+ io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
1590
+ const size_t v_size_el = ggml_type_size(layer.v->type);
1591
+ if (v_size_el != v_size_el_ref) {
1592
+ LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
1593
+ return false;
1594
+ }
1595
+
1596
+ // Read GQA embedding size
1597
+ uint32_t n_embd_v_gqa_ref;
1598
+ io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
1599
+ if (n_embd_v_gqa != n_embd_v_gqa_ref) {
1600
+ LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
1601
+ return false;
1602
+ }
1603
+
1604
+ if (cell_count) {
1605
+ // For each row in the transposed matrix, read the values for the whole cell range
1606
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1607
+ const size_t dst_offset = (head + j * cells.size()) * v_size_el;
1608
+ ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
1609
+ }
1610
+ }
1611
+ }
1612
+ }
1613
+
1614
+ return true;
1615
+ }
1616
+
1617
+ //
1618
+ // llama_kv_cache_unified_state
1619
+ //
1620
+
1621
+ llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
1622
+
1623
+ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1624
+ llama_memory_status status,
1625
+ llama_kv_cache_unified * kv) : status(status), kv(kv) {
1626
+ n_kv = kv->get_size();
1627
+ head = 0;
1628
+ }
1629
+
1630
+ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1631
+ llama_memory_status status,
1632
+ llama_kv_cache_unified * kv,
1633
+ llama_sbatch sbatch,
1634
+ std::vector<uint32_t> heads,
1635
+ std::vector<llama_ubatch> ubatches)
1636
+ : status(status),
1637
+ kv(kv),
1638
+ sbatch(std::move(sbatch)),
1639
+ heads(std::move(heads)),
1640
+ ubatches(std::move(ubatches)) {
1641
+ }
1642
+
1643
+ llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
1644
+
1645
+ bool llama_kv_cache_unified_state::next() {
1646
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1647
+
1648
+ if (++i_next >= ubatches.size()) {
1649
+ return false;
1650
+ }
1651
+
1652
+ return true;
1653
+ }
1654
+
1655
+ bool llama_kv_cache_unified_state::apply() {
1656
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1657
+
1658
+ kv->apply_ubatch(heads[i_next], ubatches[i_next]);
1659
+
1660
+ n_kv = kv->get_n_kv();
1661
+ head = heads[i_next];
1662
+
1663
+ return true;
1664
+ }
1665
+
1666
+ std::vector<int64_t> & llama_kv_cache_unified_state::out_ids() {
1667
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1668
+
1669
+ return sbatch.out_ids;
1670
+ }
1671
+
1672
+ llama_memory_status llama_kv_cache_unified_state::get_status() const {
1673
+ return status;
1674
+ }
1675
+
1676
+ const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const {
1677
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1678
+
1679
+ return ubatches[i_next];
1680
+ }
1681
+
1682
+ uint32_t llama_kv_cache_unified_state::get_n_kv() const {
1683
+ return n_kv;
1684
+ }
1685
+
1686
+ ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const {
1687
+ return kv->get_k(ctx, il, n_kv);
1688
+ }
1689
+
1690
+ ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const {
1691
+ return kv->get_v(ctx, il, n_kv);
1692
+ }
1693
+
1694
+ ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
1695
+ return kv->cpy_k(ctx, k_cur, il, head);
1696
+ }
1697
+
1698
+ ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
1699
+ return kv->cpy_v(ctx, v_cur, il, head);
1700
+ }
1701
+
1702
+ void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const {
1703
+ kv->set_input_k_shift(dst);
1704
+ }
1705
+
1706
+ void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
1707
+ kv->set_input_kq_mask(dst, ubatch, causal_attn);
1708
+ }
1709
+
1710
+ void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1711
+ kv->set_input_pos_bucket(dst, ubatch);
1712
+ }
1713
+
1714
+ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
1715
+ // the FA kernels require padding to avoid extra runtime boundary checks
1716
+ return cparams.flash_attn ? 256u : 32u;
1717
+ }
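As a small arithmetic illustration of this padding (the 8000-cell figure and the `cparams` variable are assumed for the example, not taken from the commit): with flash attention enabled get_padding() returns 256, so a caller requesting roughly 8000 cells would typically round the cache size up to a multiple of the padding before constructing the cache.

    // illustrative only - GGML_PAD rounds up to the next multiple of the padding
    const uint32_t n_pad   = llama_kv_cache_unified::get_padding(cparams); // 256 with flash_attn, else 32
    const uint32_t kv_size = GGML_PAD(8000, n_pad);                        // -> 8192 when n_pad == 256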
examples/talk-llama/llama-kv-cache-unified.h ADDED
@@ -0,0 +1,278 @@
1
+ #pragma once
2
+
3
+ #include "llama-batch.h"
4
+ #include "llama-graph.h"
5
+ #include "llama-kv-cache.h"
6
+ #include "llama-kv-cells.h"
7
+
8
+ #include <unordered_map>
9
+ #include <vector>
10
+
11
+ struct llama_cparams;
12
+ struct llama_hparams;
13
+ struct llama_model;
14
+ struct llama_context;
15
+
16
+ //
17
+ // llama_kv_cache_unified
18
+ //
19
+
20
+ class llama_kv_cache_unified : public llama_kv_cache {
21
+ public:
22
+ static uint32_t get_padding(const llama_cparams & cparams);
23
+
24
+ // this callback is used to filter out layers that should not be included in the cache
25
+ using layer_filter_cb = std::function<bool(int32_t il)>;
26
+
27
+ llama_kv_cache_unified(
28
+ const llama_model & model,
29
+ layer_filter_cb && filter,
30
+ ggml_type type_k,
31
+ ggml_type type_v,
32
+ bool v_trans,
33
+ bool offload,
34
+ uint32_t kv_size,
35
+ uint32_t n_seq_max,
36
+ uint32_t n_pad,
37
+ uint32_t n_swa,
38
+ llama_swa_type swa_type);
39
+
40
+ ~llama_kv_cache_unified() = default;
41
+
42
+ //
43
+ // llama_memory_i
44
+ //
45
+
46
+ void clear() override;
47
+
48
+ bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
49
+ void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
50
+ void seq_keep(llama_seq_id seq_id) override;
51
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
52
+ void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
53
+
54
+ llama_pos seq_pos_min(llama_seq_id seq_id) const override;
55
+ llama_pos seq_pos_max(llama_seq_id seq_id) const override;
56
+
57
+ //
58
+ // llama_kv_cache
59
+ //
60
+
61
+ llama_memory_state_ptr init_batch(
62
+ const llama_batch & batch,
63
+ uint32_t n_ubatch,
64
+ bool embd_pooled,
65
+ bool logits_all) override;
66
+
67
+ llama_memory_state_ptr init_full() override;
68
+
69
+ bool update(llama_context & lctx) override;
70
+
71
+ void defrag_sched(float thold) override;
72
+
73
+ bool get_can_shift() const override;
74
+
75
+ // state write/load
76
+
77
+ void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
78
+ void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
79
+
80
+ //
81
+ // llama_kv_cache_unified specific API
82
+ //
83
+
84
+ uint32_t get_size() const;
85
+
86
+ //
87
+ // graph_build API
88
+ //
89
+
90
+ uint32_t get_n_kv() const;
91
+
92
+ // get views of the current state of the cache
93
+ ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
94
+ ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
95
+
96
+ // store k_cur and v_cur in the cache based on the provided head location
97
+ ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
98
+ ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
99
+
100
+ //
101
+ // preparation API
102
+ //
103
+
104
+ // find places for the provided ubatches in the cache, returns the head locations
105
+ // return empty vector on failure
106
+ std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
107
+
108
+ // return the cell position where we can insert the ubatch
109
+ // return -1 on failure to find a contiguous slot of kv cells
110
+ int32_t find_slot(const llama_ubatch & ubatch) const;
111
+
112
+ // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
113
+ void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
114
+
115
+ //
116
+ // set_input API
117
+ //
118
+
119
+ void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
120
+ void set_input_k_shift (ggml_tensor * dst) const;
121
+ void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
122
+
123
+ private:
124
+ const llama_model & model;
125
+ const llama_hparams & hparams;
126
+
127
+ struct kv_layer {
128
+ // layer index in the model
129
+ // note: can be different from the layer index in the KV cache
130
+ uint32_t il;
131
+
132
+ ggml_tensor * k;
133
+ ggml_tensor * v;
134
+ };
135
+
136
+ bool do_defrag = false;
137
+ bool v_trans = true; // the value tensor is transposed
138
+
139
+ // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
140
+ // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
141
+ uint32_t head = 0;
142
+
143
+ const uint32_t n_seq_max = 1;
144
+
145
+ // required padding
146
+ const uint32_t n_pad = 1;
147
+
148
+ // SWA
149
+ const uint32_t n_swa = 0;
150
+
151
+ const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
152
+
153
+ std::vector<ggml_context_ptr> ctxs;
154
+ std::vector<ggml_backend_buffer_ptr> bufs;
155
+
156
+ llama_kv_cells_unified cells;
157
+
158
+ std::vector<kv_layer> layers;
159
+
160
+ // model layer id -> KV cache layer id
161
+ std::unordered_map<int32_t, int32_t> map_layer_ids;
162
+
163
+ // defrag
164
+ struct {
165
+ std::vector<uint32_t> ids;
166
+ } defrag_info;
167
+
168
+ // return true if cells have been moved
169
+ bool defrag_prepare(int32_t n_max_nodes);
170
+
171
+ size_t total_size() const;
172
+
173
+ size_t size_k_bytes() const;
174
+ size_t size_v_bytes() const;
175
+
176
+ bool is_masked_swa(llama_pos p0, llama_pos p1) const;
177
+
178
+ ggml_tensor * build_rope_shift(
179
+ const llama_cparams & cparams,
180
+ ggml_context * ctx,
181
+ ggml_tensor * cur,
182
+ ggml_tensor * shift,
183
+ ggml_tensor * factors,
184
+ float freq_base,
185
+ float freq_scale) const;
186
+
187
+ llm_graph_result_ptr build_graph_shift(
188
+ const llama_cparams & cparams,
189
+ ggml_context * ctx,
190
+ ggml_cgraph * gf) const;
191
+
192
+ llm_graph_result_ptr build_graph_defrag(
193
+ const llama_cparams & cparams,
194
+ ggml_context * ctx,
195
+ ggml_cgraph * gf) const;
196
+
197
+ void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
198
+ void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
199
+
200
+ bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
201
+ bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
202
+ };
203
+
204
+ class llama_kv_cache_unified_state : public llama_memory_state_i {
205
+ public:
206
+ // used for errors
207
+ llama_kv_cache_unified_state(llama_memory_status status);
208
+
209
+ // used to create a full-cache state
210
+ llama_kv_cache_unified_state(
211
+ llama_memory_status status,
212
+ llama_kv_cache_unified * kv);
213
+
214
+ // used to create a state from a batch
215
+ llama_kv_cache_unified_state(
216
+ llama_memory_status status,
217
+ llama_kv_cache_unified * kv,
218
+ llama_sbatch sbatch,
219
+ std::vector<uint32_t> heads,
220
+ std::vector<llama_ubatch> ubatches);
221
+
222
+ virtual ~llama_kv_cache_unified_state();
223
+
224
+ //
225
+ // llama_memory_state_i
226
+ //
227
+
228
+ bool next() override;
229
+ bool apply() override;
230
+
231
+ std::vector<int64_t> & out_ids() override;
232
+
233
+ llama_memory_status get_status() const override;
234
+ const llama_ubatch & get_ubatch() const override;
235
+
236
+ //
237
+ // llama_kv_cache_unified_state specific API
238
+ //
239
+
240
+ uint32_t get_n_kv() const;
241
+
242
+ // get views of the current state of the cache
243
+ ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
244
+ ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
245
+
246
+ // store k_cur and v_cur in the cache based on the provided head location
247
+ ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
248
+ ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
249
+
250
+ void set_input_k_shift(ggml_tensor * dst) const;
251
+
252
+ void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
253
+ void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
254
+
255
+ private:
256
+ const llama_memory_status status;
257
+
258
+ llama_kv_cache_unified * kv;
259
+
260
+ llama_sbatch sbatch;
261
+
262
+ // the index of the next ubatch to process
263
+ size_t i_next = 0;
264
+
265
+ std::vector<uint32_t> heads;
266
+ std::vector<llama_ubatch> ubatches;
267
+
268
+ //
269
+ // data needed for building the compute graph for the current ubatch:
270
+ //
271
+
272
+ // a heuristic, to avoid attending the full cache if it is not yet utilized
273
+ // as the cache gets filled, the benefit from this heuristic disappears
274
+ int32_t n_kv;
275
+
276
+ // the beginning of the current slot in which the ubatch will be inserted
277
+ int32_t head;
278
+ };
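A minimal, hypothetical sketch of how a caller could drive the state object declared above; only the member functions shown in this header are assumed, while `kv`, `batch` and `n_ubatch` are placeholder variables for the example.

    // sketch: iterate the ubatches of a memory state produced by init_batch()
    llama_memory_state_ptr mstate = kv->init_batch(batch, n_ubatch, /*embd_pooled=*/false, /*logits_all=*/false);

    if (mstate && mstate->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
        do {
            mstate->apply();                               // assign the current ubatch to its cache slot, update n_kv/head
            const llama_ubatch & ub = mstate->get_ubatch();

            // ... build the attention graph for `ub` via get_k()/get_v()/cpy_k()/cpy_v() and compute it ...
        } while (mstate->next());                          // advance; returns false after the last ubatch
    }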
examples/talk-llama/llama-kv-cache.cpp CHANGED
@@ -1,2739 +1 @@
1
  #include "llama-kv-cache.h"
2
-
3
- #include "llama-impl.h"
4
- #include "llama-batch.h"
5
- #include "llama-cparams.h"
6
- #include "llama-model.h"
7
- #include "llama-context.h"
8
-
9
- #include <algorithm>
10
- #include <cassert>
11
- #include <cmath>
12
- #include <limits>
13
- #include <map>
14
- #include <stdexcept>
15
-
16
- //
17
- // llama_kv_cache_unified
18
- //
19
-
20
- uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
21
- // the FA kernels require padding to avoid extra runtime boundary checks
22
- return cparams.flash_attn ? 256u : 32u;
23
- }
24
-
25
- llama_kv_cache_unified::llama_kv_cache_unified(
26
- const llama_model & model,
27
- layer_filter_cb && filter,
28
- ggml_type type_k,
29
- ggml_type type_v,
30
- bool v_trans,
31
- bool offload,
32
- uint32_t kv_size,
33
- uint32_t n_seq_max,
34
- uint32_t n_pad,
35
- uint32_t n_swa,
36
- llama_swa_type swa_type) :
37
- model(model), hparams(model.hparams), v_trans(v_trans),
38
- n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
39
-
40
- GGML_ASSERT(kv_size % n_pad == 0);
41
-
42
- // create a context for each buffer type
43
- std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
44
- auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
45
- auto it = ctx_map.find(buft);
46
- if (it == ctx_map.end()) {
47
- ggml_init_params params = {
48
- /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
49
- /*.mem_buffer =*/ NULL,
50
- /*.no_alloc =*/ true,
51
- };
52
-
53
- ggml_context * ctx = ggml_init(params);
54
- if (!ctx) {
55
- return nullptr;
56
- }
57
-
58
- ctx_map[buft] = ctx;
59
- ctxs.emplace_back(ctx);
60
-
61
- return ctx;
62
- }
63
-
64
- return it->second;
65
- };
66
-
67
- head = 0;
68
-
69
- cells.resize(kv_size);
70
-
71
- for (uint32_t il = 0; il < hparams.n_layer; il++) {
72
- if (filter && !filter(il)) {
73
- LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
74
- continue;
75
- }
76
-
77
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
78
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
79
-
80
- const char * dev_name = "CPU";
81
-
82
- ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
83
-
84
- if (offload) {
85
- auto * dev = model.dev_layer(il);
86
- buft = ggml_backend_dev_buffer_type(dev);
87
-
88
- dev_name = ggml_backend_dev_name(dev);
89
- }
90
-
91
- LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);
92
-
93
- ggml_context * ctx = ctx_for_buft(buft);
94
- if (!ctx) {
95
- throw std::runtime_error("failed to create ggml context for kv cache");
96
- }
97
-
98
- ggml_tensor * k;
99
- ggml_tensor * v;
100
-
101
- k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
102
- v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
103
-
104
- ggml_format_name(k, "cache_k_l%d", il);
105
- ggml_format_name(v, "cache_v_l%d", il);
106
-
107
- map_layer_ids[il] = layers.size();
108
- layers.push_back({ il, k, v });
109
- }
110
-
111
- // allocate tensors and initialize the buffers to avoid NaNs in the padding
112
- for (auto it : ctx_map) {
113
- auto * buft = it.first;
114
- auto * ctx = it.second;
115
-
116
- ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
117
- if (!buf) {
118
- throw std::runtime_error("failed to allocate buffer for kv cache");
119
- }
120
-
121
- LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
122
-
123
- ggml_backend_buffer_clear(buf, 0);
124
- bufs.emplace_back(buf);
125
- }
126
-
127
- {
128
- const size_t memory_size_k = size_k_bytes();
129
- const size_t memory_size_v = size_v_bytes();
130
-
131
- LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
132
- (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
133
- ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
134
- ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
135
- }
136
- }
137
-
138
- void llama_kv_cache_unified::clear() {
139
- cells.reset();
140
-
141
- head = 0;
142
-
143
- for (auto & buf : bufs) {
144
- ggml_backend_buffer_clear(buf.get(), 0);
145
- }
146
- }
147
-
148
- bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
149
- uint32_t new_head = cells.size();
150
-
151
- if (p0 < 0) {
152
- p0 = 0;
153
- }
154
-
155
- if (p1 < 0) {
156
- p1 = std::numeric_limits<llama_pos>::max();
157
- }
158
-
159
- for (uint32_t i = 0; i < cells.size(); ++i) {
160
- if (!cells.pos_in(i, p0, p1)) {
161
- continue;
162
- }
163
-
164
- if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
165
- if (new_head == cells.size()) {
166
- new_head = i;
167
- }
168
- }
169
- }
170
-
171
- // If we freed up a slot, set head to it so searching can start there.
172
- if (new_head != cells.size() && new_head < head) {
173
- head = new_head;
174
- }
175
-
176
- return true;
177
- }
178
-
179
- void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
180
- if (seq_id_src == seq_id_dst) {
181
- return;
182
- }
183
-
184
- if (p0 < 0) {
185
- p0 = 0;
186
- }
187
-
188
- if (p1 < 0) {
189
- p1 = std::numeric_limits<llama_pos>::max();
190
- }
191
-
192
- for (uint32_t i = 0; i < cells.size(); ++i) {
193
- if (!cells.pos_in(i, p0, p1)) {
194
- continue;
195
- }
196
-
197
- if (cells.seq_has(i, seq_id_src)) {
198
- cells.seq_add(i, seq_id_dst);
199
- }
200
- }
201
- }
202
-
203
- void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
204
- uint32_t new_head = cells.size();
205
-
206
- for (uint32_t i = 0; i < cells.size(); ++i) {
207
- if (cells.seq_keep(i, seq_id)) {
208
- if (new_head == cells.size()) {
209
- new_head = i;
210
- }
211
- }
212
- }
213
-
214
- // If we freed up a slot, set head to it so searching can start there.
215
- if (new_head != cells.size() && new_head < head) {
216
- head = new_head;
217
- }
218
- }
219
-
220
- void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
221
- if (shift == 0) {
222
- return;
223
- }
224
-
225
- uint32_t new_head = cells.size();
226
-
227
- if (p0 < 0) {
228
- p0 = 0;
229
- }
230
-
231
- if (p1 < 0) {
232
- p1 = std::numeric_limits<llama_pos>::max();
233
- }
234
-
235
- // If there is no range then return early to avoid looping over all cells.
236
- if (p0 == p1) {
237
- return;
238
- }
239
-
240
- for (uint32_t i = 0; i < cells.size(); ++i) {
241
- if (!cells.pos_in(i, p0, p1)) {
242
- continue;
243
- }
244
-
245
- if (cells.seq_has(i, seq_id)) {
246
- if (cells.pos_add(i, shift)) {
247
- if (new_head == cells.size()) {
248
- new_head = i;
249
- }
250
- }
251
- }
252
- }
253
-
254
- // If we freed up a slot, set head to it so searching can start there.
255
- // Otherwise we just start the next search from the beginning.
256
- head = new_head != cells.size() ? new_head : 0;
257
- }
258
-
259
- void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
260
- if (d == 1) {
261
- return;
262
- }
263
-
264
- if (p0 < 0) {
265
- p0 = 0;
266
- }
267
-
268
- if (p1 < 0) {
269
- p1 = std::numeric_limits<llama_pos>::max();
270
- }
271
-
272
- // If there is no range then return early to avoid looping over the cache.
273
- if (p0 == p1) {
274
- return;
275
- }
276
-
277
- for (uint32_t i = 0; i < cells.size(); ++i) {
278
- if (!cells.pos_in(i, p0, p1)) {
279
- continue;
280
- }
281
-
282
- if (cells.seq_has(i, seq_id)) {
283
- cells.pos_div(i, d);
284
- }
285
- }
286
- }
287
-
288
- llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
289
- return cells.seq_pos_min(seq_id);
290
- }
291
-
292
- llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
293
- return cells.seq_pos_max(seq_id);
294
- }
295
-
296
- void llama_kv_cache_unified::restore() {
297
- for (auto & state : recovery.states) {
298
- cells.set(state.i, state.cells);
299
- }
300
-
301
- recovery.clear();
302
- }
303
-
304
- void llama_kv_cache_unified::commit() {
305
- if (recovery.states.empty()) {
306
- LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
307
- __func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
308
- return;
309
- }
310
-
311
- recovery.clear();
312
- }
313
-
314
- bool llama_kv_cache_unified::update(llama_context & lctx) {
315
- bool need_reserve = false;
316
-
317
- auto * sched = lctx.get_sched();
318
-
319
- if (cells.get_has_shift()) {
320
- if (!get_can_shift()) {
321
- GGML_ABORT("The current KV cache / model configuration does not support K-shift");
322
- }
323
-
324
- LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
325
-
326
- // apply K-shift if needed
327
- if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
328
- ggml_backend_sched_reset(sched);
329
-
330
- auto * gf = lctx.graph_init();
331
-
332
- auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
333
-
334
- ggml_backend_sched_alloc_graph(sched, gf);
335
-
336
- res->set_inputs(nullptr);
337
-
338
- lctx.graph_compute(gf, false);
339
-
340
- need_reserve = true;
341
- }
342
-
343
- cells.reset_shift();
344
- }
345
-
346
- if (do_defrag) {
347
- LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
348
-
349
- if (defrag_prepare(lctx.graph_max_nodes())) {
350
- ggml_backend_sched_reset(sched);
351
-
352
- auto * gf = lctx.graph_init();
353
-
354
- auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
355
-
356
- ggml_backend_sched_alloc_graph(sched, gf);
357
-
358
- res->set_inputs(nullptr);
359
-
360
- lctx.graph_compute(gf, false);
361
-
362
- need_reserve = true;
363
- }
364
-
365
- do_defrag = false;
366
- }
367
-
368
- return need_reserve;
369
- }
370
-
371
- void llama_kv_cache_unified::defrag_sched(float thold) {
372
- // - do not defrag small contexts (i.e. < 2048 tokens)
373
- // - count the padding towards the number of used tokens
374
- const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f;
375
-
376
- // queue defragmentation for next llama_kv_cache_update
377
- if (fragmentation > thold) {
378
- LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
379
-
380
- do_defrag = true;
381
- }
382
- }
383
-
384
- void llama_kv_cache_unified::set_full() {
385
- n = cells.size();
386
-
387
- // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
388
- // affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
389
- // we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so
390
- // setting it to 0 is the simplest way to achieve that
391
- // ref: https://github.com/ggml-org/llama.cpp/issues/13359
392
- head = 0;
393
- }
394
-
395
- llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
396
- return llama_sbatch(batch, hparams.n_embd, true, logits_all);
397
- }
398
-
399
- llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
400
- GGML_UNUSED(embd_pooled);
401
- return sbatch.split_simple(n_ubatch);
402
- }
403
-
404
- bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
405
- const uint32_t n_tokens = ubatch.n_tokens;
406
-
407
- // if we have enough unused cells before the current head ->
408
- // better to start searching from the beginning of the cache, hoping to fill it
409
- if (head > cells.get_used() + 2*ubatch.n_tokens) {
410
- head = 0;
411
- }
412
-
413
- // otherwise, one cell per token.
414
-
415
- if (n_tokens > cells.size()) {
416
- LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
417
- return false;
418
- }
419
-
420
- //#define FIND_SLOT_DEBUG 1
421
- #if FIND_SLOT_DEBUG
422
- LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
423
-
424
- // for debugging
425
- {
426
- std::string ss;
427
- if (n_swa > 0) {
428
- for (uint32_t i = 0; i < size; ++i) {
429
- if (cells.is_empty(i)) {
430
- ss += '.';
431
- } else {
432
- ss += 'x';
433
- }
434
- if (i%256 == 255) {
435
- ss += '\n';
436
- }
437
- }
438
- }
439
- LLAMA_LOG_WARN("\n%s\n", ss.c_str());
440
- }
441
- #endif
442
-
443
- uint32_t n_tested = 0;
444
-
445
- while (true) {
446
- if (head + n_tokens > cells.size()) {
447
- n_tested += cells.size() - head;
448
- head = 0;
449
- continue;
450
- }
451
-
452
- bool found = true;
453
- for (uint32_t i = 0; i < n_tokens; i++) {
454
- // TODO: improve to accept cells that are masked by the SWA
455
- if (!cells.is_empty(head + i)) {
456
- found = false;
457
- head += i + 1;
458
- n_tested += i + 1;
459
- break;
460
- }
461
- }
462
-
463
- if (found) {
464
- break;
465
- }
466
-
467
- if (n_tested >= cells.size()) {
468
- //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
469
- return false;
470
- }
471
- }
472
-
473
- // store the old state of the cells in the recovery stack
474
- recovery.states.push_back({head, cells.cp(head, n_tokens)});
475
-
476
- for (uint32_t i = 0; i < n_tokens; ++i) {
477
- cells.pos_set(head + i, ubatch.pos[i]);
478
-
479
- for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
480
- cells.seq_add(head + i, ubatch.seq_id[i][j]);
481
- }
482
- }
483
-
484
- // a heuristic, to avoid attending the full cache if it is not yet utilized
485
- // after enough generations, the benefit from this heuristic disappears
486
- // if we start defragmenting the cache, the benefit from this will be more important
487
- n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
488
-
489
- #ifdef FIND_SLOT_DEBUG
490
- LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
491
- #endif
492
-
493
- return true;
494
- }
495
-
496
- bool llama_kv_cache_unified::get_can_shift() const {
497
- return true;
498
- }
499
-
500
- uint32_t llama_kv_cache_unified::get_n() const {
501
- return n;
502
- }
503
-
504
- uint32_t llama_kv_cache_unified::get_size() const {
505
- return cells.size();
506
- }
507
-
508
- ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
509
- const int32_t ikv = map_layer_ids.at(il);
510
-
511
- auto * k = layers[ikv].k;
512
-
513
- return ggml_view_3d(ctx, k,
514
- hparams.n_embd_head_k, hparams.n_head_kv(il), n,
515
- ggml_row_size(k->type, hparams.n_embd_head_k),
516
- ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
517
- 0);
518
- }
519
-
520
- ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) const {
521
- const int32_t ikv = map_layer_ids.at(il);
522
-
523
- auto * v = layers[ikv].v;
524
-
525
- if (!v_trans) {
526
- // note: v->nb[1] <= v->nb[2]
527
- return ggml_view_3d(ctx, v,
528
- hparams.n_embd_head_v, hparams.n_head_kv(il), n,
529
- ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
530
- ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
531
- 0);
532
- }
533
-
534
- // note: v->nb[1] > v->nb[2]
535
- return ggml_view_3d(ctx, v,
536
- n, hparams.n_head_kv(il), hparams.n_embd_head_v,
537
- ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
538
- ggml_row_size(v->type, v->ne[1]), // v->nb[2]
539
- 0);
540
- }
541
-
542
- ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
543
- const int32_t ikv = map_layer_ids.at(il);
544
-
545
- auto * k = layers[ikv].k;
546
-
547
- const int64_t n_tokens = k_cur->ne[2];
548
-
549
- ggml_tensor * k_view = ggml_view_1d(ctx, k,
550
- n_tokens*hparams.n_embd_k_gqa(il),
551
- ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head);
552
-
553
- return ggml_cpy(ctx, k_cur, k_view);
554
- }
555
-
556
- ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
557
- const int32_t ikv = map_layer_ids.at(il);
558
-
559
- auto * v = layers[ikv].v;
560
-
561
- const int64_t n_tokens = v_cur->ne[2];
562
-
563
- v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
564
-
565
- ggml_tensor * v_view = nullptr;
566
-
567
- if (!v_trans) {
568
- v_view = ggml_view_1d(ctx, v,
569
- n_tokens*hparams.n_embd_v_gqa(il),
570
- ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head);
571
- } else {
572
- // note: the V cache is transposed when not using flash attention
573
- v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
574
- (v->ne[1])*ggml_element_size(v),
575
- ( head)*ggml_element_size(v));
576
-
577
- v_cur = ggml_transpose(ctx, v_cur);
578
- }
579
-
580
- return ggml_cpy(ctx, v_cur, v_view);
581
- }
582
-
583
- void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax) {
584
- // no pruning is needed when the cache does not use SWA
585
- GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache");
586
-
587
- int n_attended = 0;
588
-
589
- for (uint32_t i = 0; i < cells.size(); ++i) {
590
- if (!cells.seq_has(i, seq_id)) {
591
- continue;
592
- }
593
-
594
- const llama_pos p0 = cells.pos_get(i);
595
-
596
- if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
597
- n_attended++;
598
- }
599
-
600
- if (is_masked_swa(p0, pmax)) {
601
- cells.seq_rm(i, seq_id);
602
- }
603
- }
604
-
605
- if (n_attended < std::min<int>(n_swa, pmin)) {
606
- LLAMA_LOG_WARN("%s: partial SWA cache detected - possible loss of information, pmin = %d, n_attended = %d, n_swa = %d\n", __func__, pmin, n_attended, n_swa);
607
- }
608
- }
609
-
610
- void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
611
- const int64_t n_tokens = ubatch->n_tokens;
612
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
613
- const int64_t n_seqs = ubatch->n_seqs;
614
-
615
- GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
616
- float * data = (float *) dst->data;
617
-
618
- const int64_t n_kv = n;
619
-
620
- // Use only the previous KV cells of the correct sequence for each token of the ubatch.
621
- // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
622
- // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
623
- // Causal mask:
624
- // xxx-------
625
- // xxxx------
626
- // xxxxx-----
627
- // Non-causal mask:
628
- // xxxxx-----
629
- // xxxxx-----
630
- // xxxxx-----
631
- // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
632
- for (int h = 0; h < 1; ++h) {
633
- for (int s = 0; s < n_seqs; ++s) {
634
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
635
-
636
- for (int j = 0; j < n_seq_tokens; ++j) {
637
- const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
638
-
639
- for (int i = 0; i < n_kv; ++i) {
640
- float f = 0.0f;
641
-
642
- bool masked = false;
643
-
644
- if (cells.is_empty(i)) {
645
- masked = true;
646
- } else {
647
- const llama_pos p0 = cells.pos_get(i);
648
-
649
- // mask the token if not the same sequence
650
- masked = masked || (!cells.seq_has(i, seq_id));
651
-
652
- // mask future tokens
653
- masked = masked || (causal_attn && p0 > p1);
654
-
655
- // apply SWA if any
656
- masked = masked || (is_masked_swa(p0, p1));
657
-
658
- if (!masked && hparams.use_alibi) {
659
- f = -std::abs(p0 - p1);
660
- }
661
- }
662
-
663
- if (masked) {
664
- f = -INFINITY;
665
- }
666
-
667
- data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
668
- }
669
- }
670
- }
671
-
672
- // mask padded tokens
673
- if (data) {
674
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
675
- for (int j = 0; j < n_kv; ++j) {
676
- data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
677
- }
678
- }
679
- }
680
- }
681
- }
682
-
683
- void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
684
- GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
685
-
686
- int32_t * data = (int32_t *) dst->data;
687
-
688
- for (uint32_t i = 0; i < cells.size(); ++i) {
689
- data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
690
- }
691
- }
692
-
693
- void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
694
- const int64_t n_tokens = ubatch->n_tokens;
695
-
696
- GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
697
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
698
-
699
- int32_t * data = (int32_t *) dst->data;
700
-
701
- const int64_t n_kv = n;
702
-
703
- for (int h = 0; h < 1; ++h) {
704
- for (int j = 0; j < n_tokens; ++j) {
705
- for (int i = 0; i < n_kv; ++i) {
706
- // the position when the cells is empty is irrelevant - it will be masked out later in the attention
707
- const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
708
-
709
- data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
710
- }
711
- }
712
- }
713
- }
714
-
715
- size_t llama_kv_cache_unified::total_size() const {
716
- size_t size = 0;
717
-
718
- for (const auto & buf : bufs) {
719
- size += ggml_backend_buffer_get_size(buf.get());
720
- }
721
-
722
- return size;
723
- }
724
-
725
- size_t llama_kv_cache_unified::size_k_bytes() const {
726
- size_t size_k_bytes = 0;
727
-
728
- for (const auto & layer : layers) {
729
- size_k_bytes += ggml_nbytes(layer.k);
730
- }
731
-
732
- return size_k_bytes;
733
- }
734
-
735
- size_t llama_kv_cache_unified::size_v_bytes() const {
736
- size_t size_v_bytes = 0;
737
-
738
- for (const auto & layer : layers) {
739
- size_v_bytes += ggml_nbytes(layer.v);
740
- }
741
-
742
- return size_v_bytes;
743
- }
744
-
745
- ggml_tensor * llama_kv_cache_unified::build_rope_shift(
746
- const llama_cparams & cparams,
747
- ggml_context * ctx,
748
- ggml_tensor * cur,
749
- ggml_tensor * shift,
750
- ggml_tensor * factors,
751
- float freq_base,
752
- float freq_scale) const {
753
- const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
754
-
755
- const auto & yarn_ext_factor = cparams.yarn_ext_factor;
756
- const auto & yarn_beta_fast = cparams.yarn_beta_fast;
757
- const auto & yarn_beta_slow = cparams.yarn_beta_slow;
758
-
759
- const auto & n_rot = hparams.n_rot;
760
- const auto & rope_type = hparams.rope_type;
761
-
762
- // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
763
- // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
764
- const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
765
-
766
- ggml_tensor * tmp;
767
-
768
- if (ggml_is_quantized(cur->type)) {
769
- // dequantize to f32 -> RoPE -> quantize back
770
- tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
771
-
772
- tmp = ggml_rope_ext(ctx, tmp,
773
- shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
774
- yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
775
-
776
- tmp = ggml_cpy(ctx, tmp, cur);
777
- } else {
778
- // we rotate only the first n_rot dimensions
779
- tmp = ggml_rope_ext_inplace(ctx, cur,
780
- shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
781
- yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
782
- }
783
-
784
- return tmp;
785
- }
786
-
787
- class llm_graph_input_k_shift : public llm_graph_input_i {
788
- public:
789
- llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
790
- virtual ~llm_graph_input_k_shift() = default;
791
-
792
- void set_input(const llama_ubatch * ubatch) override;
793
-
794
- ggml_tensor * k_shift; // I32 [kv_size]
795
-
796
- const llama_kv_cache_unified * kv_self;
797
- };
798
-
799
- void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
800
- GGML_UNUSED(ubatch);
801
-
802
- if (k_shift) {
803
- kv_self->set_input_k_shift(k_shift);
804
- }
805
- }
806
-
807
- llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
808
- const llama_cparams & cparams,
809
- ggml_context * ctx,
810
- ggml_cgraph * gf) const {
811
- auto res = std::make_unique<llm_graph_result>();
812
-
813
- const auto & n_embd_head_k = hparams.n_embd_head_k;
814
- //const auto & n_embd_head_v = hparams.n_embd_head_v;
815
-
816
- //GGML_ASSERT(kv_self->size == n_ctx);
817
-
818
- auto inp = std::make_unique<llm_graph_input_k_shift>(this);
819
-
820
- inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
821
- ggml_set_input(inp->k_shift);
822
-
823
- for (const auto & layer : layers) {
824
- const uint32_t il = layer.il;
825
-
826
- const int64_t n_head_kv = hparams.n_head_kv(il);
827
- const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
828
-
829
- const float freq_base_l = model.get_rope_freq_base (cparams, il);
830
- const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
831
-
832
- ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
833
-
834
- ggml_tensor * k =
835
- ggml_view_3d(ctx, layer.k,
836
- n_embd_head_k, n_head_kv, cells.size(),
837
- ggml_row_size(layer.k->type, n_embd_head_k),
838
- ggml_row_size(layer.k->type, n_embd_k_gqa),
839
- 0);
840
-
841
- ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
842
-
843
- ggml_build_forward_expand(gf, cur);
844
- }
845
-
846
- res->add_input(std::move(inp));
847
-
848
- return res;
849
- }
850
-
851
- llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
852
- const llama_cparams & cparams,
853
- ggml_context * ctx,
854
- ggml_cgraph * gf) const {
855
- auto res = std::make_unique<llm_graph_result>();
856
-
857
- const auto & ids = defrag_info.ids;
858
-
859
- #if 0
860
- // CPU defrag
861
- //
862
- // TODO: optimizations are possible:
863
- // - multiple threads
864
- // - avoid copying to the host memory when already there
865
- //
866
- // likely not worth the effort, as we have ggml_graph based defrag
867
- //
868
-
869
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
870
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
871
-
872
- const uint32_t kv_size = size;
873
-
874
- std::vector<uint8_t> buf_k;
875
- std::vector<uint8_t> buf_v;
876
-
877
- for (uint32_t il = 0; il < n_layer; ++il) {
878
- const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
879
- const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
880
-
881
- const size_t v_size_el = ggml_type_size(v_l[il]->type);
882
- const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
883
-
884
- buf_k.resize(k_size);
885
- buf_v.resize(v_size);
886
-
887
- ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
888
- ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
889
-
890
- // batch move [i, i+nm) to [id, id+nm)
891
- // note: cells can move only to a lower index
892
- for (uint32_t i = 0; i < n_kv; ++i) {
893
- const uint32_t id = ids[i];
894
-
895
- if (i == id || id == n_kv) {
896
- continue;
897
- }
898
-
899
- uint32_t nm = 1;
900
-
901
- while (i + nm < n_kv && ids[i + nm] == id + nm) {
902
- nm++;
903
- }
904
-
905
- // move keys
906
- {
907
- const int64_t os = i*k_size_row;
908
- const int64_t od = id*k_size_row;
909
-
910
- memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
911
- }
912
-
913
- // move values (note: they are transposed)
914
- {
915
- const int64_t os = i;
916
- const int64_t od = id;
917
-
918
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
919
- memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
920
- }
921
- }
922
-
923
- i += nm - 1;
924
- }
925
-
926
- ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
927
- ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
928
- }
929
- #else
930
- for (uint32_t i = 0; i < ids.size(); ++i) {
931
- const uint32_t id = ids[i];
932
-
933
- if (i == id || id == ids.size()) {
934
- continue;
935
- }
936
-
937
- uint32_t nm = 1;
938
-
939
- while (i + nm < ids.size() && ids[i + nm] == id + nm) {
940
- nm++;
941
- }
942
-
943
- for (const auto & layer : layers) {
944
- const uint32_t il = layer.il;
945
-
946
- const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
947
- const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
948
-
949
- ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k,
950
- n_embd_k_gqa, nm,
951
- ggml_row_size(layer.k->type, n_embd_k_gqa),
952
- ggml_row_size(layer.k->type, n_embd_k_gqa*i));
953
-
954
- ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k,
955
- n_embd_k_gqa, nm,
956
- ggml_row_size(layer.k->type, n_embd_k_gqa),
957
- ggml_row_size(layer.k->type, n_embd_k_gqa*id));
958
-
959
- ggml_tensor * view_v_src;
960
- ggml_tensor * view_v_dst;
961
-
962
- if (cparams.flash_attn) {
963
- // NOTE: the V cache is not transposed when using flash attention
964
- view_v_src = ggml_view_2d(ctx, layer.v,
965
- n_embd_v_gqa, nm,
966
- ggml_row_size(layer.v->type, n_embd_v_gqa),
967
- ggml_row_size(layer.v->type, n_embd_v_gqa*i));
968
-
969
- view_v_dst = ggml_view_2d(ctx, layer.v,
970
- n_embd_v_gqa, nm,
971
- ggml_row_size(layer.v->type, n_embd_v_gqa),
972
- ggml_row_size(layer.v->type, n_embd_v_gqa*id));
973
- } else {
974
- view_v_src = ggml_view_2d(ctx, layer.v,
975
- nm, n_embd_v_gqa,
976
- ggml_row_size(layer.v->type, cells.size()),
977
- ggml_row_size(layer.v->type, i));
978
-
979
- view_v_dst = ggml_view_2d(ctx, layer.v,
980
- nm, n_embd_v_gqa,
981
- ggml_row_size(layer.v->type, cells.size()),
982
- ggml_row_size(layer.v->type, id));
983
- }
984
-
985
- ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
986
- ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
987
- }
988
-
989
- i += nm - 1;
990
- }
991
-
992
- //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
993
- #endif
994
-
995
- return res;
996
- }
997
-
998
- bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
999
- const uint32_t n_layer = layers.size();
1000
-
1001
- const uint32_t n_kv = cells.used_max_p1();
1002
- const uint32_t n_used = cells.get_used();
1003
-
1004
- assert(n_used <= n_kv);
1005
-
1006
- //const int64_t t_start = ggml_time_us();
1007
-
1008
- // number of cells moved
1009
- uint32_t n_moves = 0;
1010
-
1011
- // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
1012
- // - source view, destination view, copy operation
1013
- // - x2 for keys and values
1014
- //const uint32_t max_moves = max_nodes()/(6*n_layer);
1015
- // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
1016
- const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
1017
-
1018
- // determine which KV cells to move where
1019
- //
1020
- // cell i moves to ids[i]
1021
- //
1022
- // if ids[i] == i || ids[i] == n_kv, then cell i is not moved
1023
- //
1024
- auto & ids = defrag_info.ids;
1025
-
1026
- ids.clear();
1027
- ids.resize(n_kv, n_kv);
1028
-
1029
- for (uint32_t i0 = 0; i0 < n_used; ++i0) {
1030
- if (!cells.is_empty(i0)) {
1031
- ids[i0] = i0;
1032
-
1033
- continue;
1034
- }
1035
-
1036
- // found a hole - fill it with data from the end of the cache
1037
-
1038
- uint32_t nh = 1;
1039
-
1040
- // determine the size of the hole
1041
- while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
1042
- nh++;
1043
- }
1044
-
1045
- uint32_t nf = 0;
1046
- uint32_t is = n_kv - 1;
1047
-
1048
- // starting from the end, find nh non-empty cells
1049
- for (; is > i0; --is) {
1050
- if (cells.is_empty(is) || ids[is] != n_kv) {
1051
- continue;
1052
- }
1053
-
1054
- // non-empty cell which is not yet moved
1055
- nf++;
1056
-
1057
- if (nf == nh) {
1058
- break;
1059
- }
1060
- }
1061
-
1062
- // this can only happen if `n_used` is not accurate, which would be a bug
1063
- GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
1064
-
1065
- nf = 0;
1066
-
1067
- uint32_t i1 = is;
1068
-
1069
- // are we moving a continuous block of memory?
1070
- bool cont = false;
1071
-
1072
- // should we stop searching for the next move?
1073
- bool stop = false;
1074
-
1075
- // go back and move the nf cells to the hole
1076
- for (; i1 < n_kv; ++i1) {
1077
- if (cells.is_empty(i1) || ids[i1] != n_kv) {
1078
- if (n_moves == max_moves) {
1079
- stop = true;
1080
- break;
1081
- }
1082
-
1083
- cont = false;
1084
- continue;
1085
- }
1086
-
1087
- // this cell goes to (i0 + nf)
1088
- ids[i1] = i0 + nf;
1089
-
1090
- // move the cell meta data
1091
- cells.mv(i1, i0 + nf);
1092
-
1093
- head = n_used;
1094
-
1095
- if (!cont) {
1096
- n_moves++;
1097
- cont = true;
1098
- }
1099
-
1100
- nf++;
1101
-
1102
- if (nf == nh) {
1103
- break;
1104
- }
1105
- }
1106
-
1107
- if (stop || n_moves == max_moves) {
1108
- break;
1109
- }
1110
-
1111
- //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
1112
-
1113
- i0 += nh - 1;
1114
- }
1115
-
1116
- if (n_moves == 0) {
1117
- return false;
1118
- }
1119
-
1120
- LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
1121
-
1122
- LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
1123
-
1124
- return true;
1125
- }
1126
-
1127
- bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
1128
- assert(p0 >= 0 && p1 >= 0);
1129
-
1130
- switch (swa_type) {
1131
- case LLAMA_SWA_TYPE_NONE:
1132
- {
1133
- } break;
1134
- case LLAMA_SWA_TYPE_STANDARD:
1135
- {
1136
- if (p1 - p0 >= (int32_t) n_swa) {
1137
- return true;
1138
- }
1139
- } break;
1140
- case LLAMA_SWA_TYPE_CHUNKED:
1141
- {
1142
- const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
1143
-
1144
- if (p0 < pos_chunk_start) {
1145
- return true;
1146
- }
1147
- } break;
1148
- }
1149
-
1150
- return false;
1151
- }
1152
-
1153
- void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
1154
- std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
1155
- uint32_t cell_count = 0;
1156
-
1157
- // Count the number of cells with the specified seq_id
1158
- // Find all the ranges of cells with this seq id (or all, when -1)
1159
- uint32_t cell_range_begin = cells.size();
1160
-
1161
- for (uint32_t i = 0; i < cells.size(); ++i) {
1162
- if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
1163
- ++cell_count;
1164
- if (cell_range_begin == cells.size()) {
1165
- cell_range_begin = i;
1166
- }
1167
- } else {
1168
- if (cell_range_begin != cells.size()) {
1169
- cell_ranges.emplace_back(cell_range_begin, i);
1170
- cell_range_begin = cells.size();
1171
- }
1172
- }
1173
- }
1174
-
1175
- if (cell_range_begin != cells.size()) {
1176
- cell_ranges.emplace_back(cell_range_begin, cells.size());
1177
- }
1178
-
1179
- // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
1180
- uint32_t cell_count_check = 0;
1181
- for (const auto & range : cell_ranges) {
1182
- cell_count_check += range.second - range.first;
1183
- }
1184
- GGML_ASSERT(cell_count == cell_count_check);
1185
-
1186
- io.write(&cell_count, sizeof(cell_count));
1187
-
1188
- state_write_meta(io, cell_ranges, seq_id);
1189
- state_write_data(io, cell_ranges);
1190
- }
1191
-
1192
- void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
1193
- uint32_t cell_count;
1194
- io.read_to(&cell_count, sizeof(cell_count));
1195
-
1196
- bool res = true;
1197
- res = res && state_read_meta(io, cell_count, seq_id);
1198
- res = res && state_read_data(io, cell_count);
1199
-
1200
- if (!res) {
1201
- if (seq_id == -1) {
1202
- clear();
1203
- } else {
1204
- seq_rm(seq_id, -1, -1);
1205
- }
1206
- throw std::runtime_error("failed to restore kv cache");
1207
- }
1208
- }
1209
-
1210
- void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
1211
- for (const auto & range : cell_ranges) {
1212
- for (uint32_t i = range.first; i < range.second; ++i) {
1213
- std::vector<llama_seq_id> seq_ids;
1214
-
1215
- for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
1216
- if (cur == seq_id || seq_id == -1) {
1217
- if (cells.seq_has(i, cur)) {
1218
- seq_ids.push_back(cur);
1219
- }
1220
- }
1221
- }
1222
-
1223
- const llama_pos pos = cells.pos_get(i);
1224
- const uint32_t n_seq_id = seq_ids.size();
1225
-
1226
- io.write(&pos, sizeof(pos));
1227
- io.write(&n_seq_id, sizeof(n_seq_id));
1228
-
1229
- for (const auto & seq_id : seq_ids) {
1230
- io.write(&seq_id, sizeof(seq_id));
1231
- }
1232
- }
1233
- }
1234
- }
1235
-
1236
- void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
1237
- const uint32_t v_trans = this->v_trans ? 1 : 0;
1238
- const uint32_t n_layer = layers.size();
1239
-
1240
- io.write(&v_trans, sizeof(v_trans));
1241
- io.write(&n_layer, sizeof(n_layer));
1242
-
1243
- std::vector<uint8_t> tmp_buf;
1244
-
1245
- // Iterate and write all the keys first, each row is a cell
1246
- // Get whole range at a time
1247
- for (const auto & layer : layers) {
1248
- const uint32_t il = layer.il;
1249
-
1250
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1251
-
1252
- // Write key type
1253
- const int32_t k_type_i = (int32_t)layer.k->type;
1254
- io.write(&k_type_i, sizeof(k_type_i));
1255
-
1256
- // Write row size of key
1257
- const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
1258
- io.write(&k_size_row, sizeof(k_size_row));
1259
-
1260
- // Read each range of cells of k_size length each into tmp_buf and write out
1261
- for (const auto & range : cell_ranges) {
1262
- const size_t range_size = range.second - range.first;
1263
- const size_t buf_size = range_size * k_size_row;
1264
- io.write_tensor(layer.k, range.first * k_size_row, buf_size);
1265
- }
1266
- }
1267
-
1268
- if (!v_trans) {
1269
- for (const auto & layer : layers) {
1270
- const uint32_t il = layer.il;
1271
-
1272
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1273
-
1274
- // Write value type
1275
- const int32_t v_type_i = (int32_t)layer.v->type;
1276
- io.write(&v_type_i, sizeof(v_type_i));
1277
-
1278
- // Write row size of value
1279
- const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
1280
- io.write(&v_size_row, sizeof(v_size_row));
1281
-
1282
- // Read each range of cells of v_size length each into tmp_buf and write out
1283
- for (const auto & range : cell_ranges) {
1284
- const size_t range_size = range.second - range.first;
1285
- const size_t buf_size = range_size * v_size_row;
1286
- io.write_tensor(layer.v, range.first * v_size_row, buf_size);
1287
- }
1288
- }
1289
- } else {
1290
- // When v is transposed, we also need the element size and get the element ranges from each row
1291
- const uint32_t kv_size = cells.size();
1292
-
1293
- for (const auto & layer : layers) {
1294
- const uint32_t il = layer.il;
1295
-
1296
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1297
-
1298
- // Write value type
1299
- const int32_t v_type_i = (int32_t)layer.v->type;
1300
- io.write(&v_type_i, sizeof(v_type_i));
1301
-
1302
- // Write element size
1303
- const uint32_t v_size_el = ggml_type_size(layer.v->type);
1304
- io.write(&v_size_el, sizeof(v_size_el));
1305
-
1306
- // Write GQA embedding size
1307
- io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
1308
-
1309
- // For each row, we get the element values of each cell
1310
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1311
- // Read each range of cells of v_size_el length each into tmp_buf and write out
1312
- for (const auto & range : cell_ranges) {
1313
- const size_t range_size = range.second - range.first;
1314
- const size_t src_offset = (range.first + j * kv_size) * v_size_el;
1315
- const size_t buf_size = range_size * v_size_el;
1316
- io.write_tensor(layer.v, src_offset, buf_size);
1317
- }
1318
- }
1319
- }
1320
- }
1321
- }
1322
-
1323
- bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
1324
- if (dest_seq_id != -1) {
1325
- // single sequence
1326
-
1327
- seq_rm(dest_seq_id, -1, -1);
1328
-
1329
- llama_sbatch sbatch;
1330
- llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1331
-
1332
- batch.n_tokens = cell_count;
1333
-
1334
- for (uint32_t i = 0; i < cell_count; ++i) {
1335
- llama_pos pos;
1336
- uint32_t n_seq_id;
1337
-
1338
- io.read_to(&pos, sizeof(pos));
1339
- io.read_to(&n_seq_id, sizeof(n_seq_id));
1340
-
1341
- if (n_seq_id != 1) {
1342
- LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
1343
- return false;
1344
- }
1345
-
1346
- // read the sequence id, but directly discard it - we will use dest_seq_id instead
1347
- {
1348
- llama_seq_id seq_id;
1349
- io.read_to(&seq_id, sizeof(seq_id));
1350
- }
1351
-
1352
- batch.pos[i] = pos;
1353
- batch.n_seq_id[i] = n_seq_id;
1354
- batch.seq_id[i] = &dest_seq_id;
1355
- }
1356
-
1357
- if (!find_slot(batch)) {
1358
- LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1359
- return false;
1360
- }
1361
-
1362
- commit();
1363
-
1364
- // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
1365
- // Assume that this is one contiguous block of cells
1366
- GGML_ASSERT(head + cell_count <= cells.size());
1367
- GGML_ASSERT(cells.pos_get(head) == batch.pos[0]);
1368
- GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]);
1369
- GGML_ASSERT(cells.seq_has(head, dest_seq_id));
1370
- GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id));
1371
- } else {
1372
- // whole KV cache restore
1373
-
1374
- if (cell_count > cells.size()) {
1375
- LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
1376
- return false;
1377
- }
1378
-
1379
- clear();
1380
-
1381
- for (uint32_t i = 0; i < cell_count; ++i) {
1382
- llama_pos pos;
1383
- uint32_t n_seq_id;
1384
-
1385
- io.read_to(&pos, sizeof(pos));
1386
- io.read_to(&n_seq_id, sizeof(n_seq_id));
1387
-
1388
- cells.pos_set(i, pos);
1389
-
1390
- for (uint32_t j = 0; j < n_seq_id; ++j) {
1391
- llama_seq_id seq_id;
1392
- io.read_to(&seq_id, sizeof(seq_id));
1393
-
1394
- if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
1395
- LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
1396
- return false;
1397
- }
1398
-
1399
- cells.seq_add(i, seq_id);
1400
- }
1401
- }
1402
-
1403
- head = 0;
1404
- }
1405
-
1406
- return true;
1407
- }
1408
-
1409
- bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
1410
- uint32_t v_trans;
1411
- uint32_t n_layer;
1412
-
1413
- io.read_to(&v_trans, sizeof(v_trans));
1414
- io.read_to(&n_layer, sizeof(n_layer));
1415
-
1416
- if (n_layer != layers.size()) {
1417
- LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
1418
- return false;
1419
- }
1420
- if (cell_count > cells.size()) {
1421
- LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
1422
- return false;
1423
- }
1424
- if (this->v_trans != (bool) v_trans) {
1425
- LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
1426
- return false;
1427
- }
1428
-
1429
- // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
1430
- for (const auto & layer : layers) {
1431
- const uint32_t il = layer.il;
1432
-
1433
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1434
-
1435
- // Read type of key
1436
- int32_t k_type_i_ref;
1437
- io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
1438
- const int32_t k_type_i = (int32_t) layer.k->type;
1439
- if (k_type_i != k_type_i_ref) {
1440
- LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
1441
- return false;
1442
- }
1443
-
1444
- // Read row size of key
1445
- uint64_t k_size_row_ref;
1446
- io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
1447
- const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
1448
- if (k_size_row != k_size_row_ref) {
1449
- LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
1450
- return false;
1451
- }
1452
-
1453
- if (cell_count) {
1454
- // Read and set the keys for the whole cell range
1455
- ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
1456
- }
1457
- }
1458
-
1459
- if (!this->v_trans) {
1460
- for (const auto & layer : layers) {
1461
- const uint32_t il = layer.il;
1462
-
1463
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1464
-
1465
- // Read type of value
1466
- int32_t v_type_i_ref;
1467
- io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1468
- const int32_t v_type_i = (int32_t)layer.v->type;
1469
- if (v_type_i != v_type_i_ref) {
1470
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1471
- return false;
1472
- }
1473
-
1474
- // Read row size of value
1475
- uint64_t v_size_row_ref;
1476
- io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
1477
- const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
1478
- if (v_size_row != v_size_row_ref) {
1479
- LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
1480
- return false;
1481
- }
1482
-
1483
- if (cell_count) {
1484
- // Read and set the values for the whole cell range
1485
- ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
1486
- }
1487
- }
1488
- } else {
1489
- // For each layer, read the values for each cell (transposed)
1490
- for (const auto & layer : layers) {
1491
- const uint32_t il = layer.il;
1492
-
1493
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1494
-
1495
- // Read type of value
1496
- int32_t v_type_i_ref;
1497
- io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1498
- const int32_t v_type_i = (int32_t)layer.v->type;
1499
- if (v_type_i != v_type_i_ref) {
1500
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1501
- return false;
1502
- }
1503
-
1504
- // Read element size of value
1505
- uint32_t v_size_el_ref;
1506
- io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
1507
- const size_t v_size_el = ggml_type_size(layer.v->type);
1508
- if (v_size_el != v_size_el_ref) {
1509
- LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
1510
- return false;
1511
- }
1512
-
1513
- // Read GQA embedding size
1514
- uint32_t n_embd_v_gqa_ref;
1515
- io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
1516
- if (n_embd_v_gqa != n_embd_v_gqa_ref) {
1517
- LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
1518
- return false;
1519
- }
1520
-
1521
- if (cell_count) {
1522
- // For each row in the transposed matrix, read the values for the whole cell range
1523
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1524
- const size_t dst_offset = (head + j * cells.size()) * v_size_el;
1525
- ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
1526
- }
1527
- }
1528
- }
1529
- }
1530
-
1531
- return true;
1532
- }
1533
-
1534
- //
1535
- // llama_kv_cache_unified_iswa
1536
- //
1537
-
1538
- llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
1539
- const llama_model & model,
1540
- ggml_type type_k,
1541
- ggml_type type_v,
1542
- bool v_trans,
1543
- bool offload,
1544
- bool swa_full,
1545
- uint32_t kv_size,
1546
- uint32_t n_seq_max,
1547
- uint32_t n_batch,
1548
- uint32_t n_pad) : hparams(model.hparams) {
1549
- llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
1550
- llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
1551
-
1552
- const uint32_t size_base = kv_size;
1553
-
1554
- uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));
1555
-
1556
- // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning
1557
- if (swa_full) {
1558
- LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
1559
- __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
1560
-
1561
- size_swa = size_base;
1562
- do_prune = false;
1563
- }
1564
-
1565
- LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
1566
-
1567
- kv_base = std::make_unique<llama_kv_cache_unified>(
1568
- model, std::move(filter_base), type_k, type_v,
1569
- v_trans, offload, size_base, n_seq_max, n_pad,
1570
- 0, LLAMA_SWA_TYPE_NONE);
1571
-
1572
- LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
1573
-
1574
- kv_swa = std::make_unique<llama_kv_cache_unified>(
1575
- model, std::move(filter_swa), type_k, type_v,
1576
- v_trans, offload, size_swa, n_seq_max, n_pad,
1577
- hparams.n_swa, hparams.swa_type);
1578
- }
1579
-
1580
- void llama_kv_cache_unified_iswa::clear() {
1581
- kv_base->clear();
1582
- kv_swa ->clear();
1583
- }
1584
-
1585
- bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
1586
- bool res = true;
1587
-
1588
- res = res & kv_base->seq_rm(seq_id, p0, p1);
1589
- res = res & kv_swa ->seq_rm(seq_id, p0, p1);
1590
-
1591
- return res;
1592
- }
1593
-
1594
- void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
1595
- kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
1596
- kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
1597
- }
1598
-
1599
- void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
1600
- kv_base->seq_keep(seq_id);
1601
- kv_swa ->seq_keep(seq_id);
1602
- }
1603
-
1604
- void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
1605
- kv_base->seq_add(seq_id, p0, p1, shift);
1606
- kv_swa ->seq_add(seq_id, p0, p1, shift);
1607
- }
1608
-
1609
- void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
1610
- kv_base->seq_div(seq_id, p0, p1, d);
1611
- kv_swa ->seq_div(seq_id, p0, p1, d);
1612
- }
1613
-
1614
- llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
1615
- // the base cache is a superset of the SWA cache, so we can just check the SWA cache
1616
- return kv_swa->seq_pos_min(seq_id);
1617
- }
1618
-
1619
- llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
1620
- return kv_swa->seq_pos_max(seq_id);
1621
- }
1622
-
1623
- void llama_kv_cache_unified_iswa::restore() {
1624
- kv_base->restore();
1625
- kv_swa ->restore();
1626
- }
1627
-
1628
- void llama_kv_cache_unified_iswa::commit() {
1629
- kv_base->commit();
1630
- kv_swa ->commit();
1631
-
1632
- // slide the attention window, forgetting/pruning old tokens that are outside the window
1633
- if (do_prune) {
1634
- for (const auto & [seq_id, entry] : pending.pos) {
1635
- kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax);
1636
- }
1637
-
1638
- }
1639
-
1640
- pending.clear();
1641
- }
1642
-
1643
- bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
1644
- bool res = true;
1645
-
1646
- res = res & kv_base->update(lctx);
1647
- res = res & kv_swa ->update(lctx);
1648
-
1649
- return res;
1650
- }
1651
-
1652
- void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
1653
- kv_base->defrag_sched(thold);
1654
- kv_swa ->defrag_sched(thold);
1655
- }
1656
-
1657
- void llama_kv_cache_unified_iswa::set_full() {
1658
- kv_base->set_full();
1659
- kv_swa ->set_full();
1660
- }
1661
-
1662
- llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
1663
- pending.clear();
1664
-
1665
- if (do_prune) {
1666
- for (int i = 0; i < batch.n_tokens; ++i) {
1667
- for (int s = 0; s < batch.n_seq_id[i]; ++s) {
1668
- const llama_seq_id seq_id = batch.seq_id[i][s];
1669
- const llama_pos pos = batch.pos[i];
1670
-
1671
- if (pending.pos.find(seq_id) == pending.pos.end()) {
1672
- pending.pos[seq_id].pmin = pos;
1673
- pending.pos[seq_id].pmax = pos;
1674
- } else {
1675
- pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos);
1676
- pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos);
1677
- }
1678
- }
1679
- }
1680
- }
1681
-
1682
- return llama_sbatch(batch, hparams.n_embd, true, logits_all);
1683
- }
1684
-
1685
- llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
1686
- GGML_UNUSED(embd_pooled);
1687
- return sbatch.split_simple(n_ubatch);
1688
- }
1689
-
1690
- bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
1691
- bool res = true;
1692
-
1693
- res = res & kv_base->find_slot(batch);
1694
- res = res & kv_swa ->find_slot(batch);
1695
-
1696
- return res;
1697
- }
1698
-
1699
- bool llama_kv_cache_unified_iswa::get_can_shift() const {
1700
- return kv_base->get_size() == kv_swa->get_size();
1701
- }
1702
-
1703
- void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
1704
- kv_base->state_write(io, seq_id);
1705
- kv_swa ->state_write(io, seq_id);
1706
- }
1707
-
1708
- void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
1709
- kv_base->state_read(io, seq_id);
1710
- kv_swa ->state_read(io, seq_id);
1711
- }
1712
-
1713
- llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_base() const {
1714
- return kv_base.get();
1715
- }
1716
-
1717
- llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const {
1718
- return kv_swa.get();
1719
- }
1720
-
1721
- //
1722
- // llama_kv_cache_recurrent
1723
- //
1724
-
1725
- llama_kv_cache_recurrent::llama_kv_cache_recurrent(
1726
- const llama_model & model,
1727
- ggml_type type_k,
1728
- ggml_type type_v,
1729
- bool offload,
1730
- uint32_t kv_size,
1731
- uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
1732
- const int32_t n_layer = hparams.n_layer;
1733
-
1734
- LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
1735
- __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
1736
-
1737
- head = 0;
1738
- size = kv_size;
1739
- used = 0;
1740
-
1741
- cells.clear();
1742
- cells.resize(kv_size);
1743
-
1744
- // create a context for each buffer type
1745
- std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
1746
- auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
1747
- auto it = ctx_map.find(buft);
1748
- if (it == ctx_map.end()) {
1749
- ggml_init_params params = {
1750
- /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
1751
- /*.mem_buffer =*/ NULL,
1752
- /*.no_alloc =*/ true,
1753
- };
1754
-
1755
- ggml_context * ctx = ggml_init(params);
1756
- if (!ctx) {
1757
- return nullptr;
1758
- }
1759
-
1760
- ctx_map[buft] = ctx;
1761
- ctxs.emplace_back(ctx);
1762
-
1763
- return ctx;
1764
- }
1765
-
1766
- return it->second;
1767
- };
1768
-
1769
- k_l.reserve(n_layer);
1770
- v_l.reserve(n_layer);
1771
-
1772
- for (int i = 0; i < n_layer; i++) {
1773
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
1774
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
1775
-
1776
- const char * dev_name = "CPU";
1777
-
1778
- ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
1779
-
1780
- if (offload) {
1781
- auto * dev = model.dev_layer(i);
1782
- buft = ggml_backend_dev_buffer_type(dev);
1783
-
1784
- dev_name = ggml_backend_dev_name(dev);
1785
- }
1786
-
1787
- LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name);
1788
-
1789
- ggml_context * ctx = ctx_for_buft(buft);
1790
- if (!ctx) {
1791
- throw std::runtime_error("failed to create ggml context for kv cache");
1792
- }
1793
-
1794
- ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
1795
- ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
1796
- ggml_format_name(k, "cache_k_l%d", i);
1797
- ggml_format_name(v, "cache_v_l%d", i);
1798
- k_l.push_back(k);
1799
- v_l.push_back(v);
1800
- }
1801
-
1802
- // allocate tensors and initialize the buffers to avoid NaNs in the padding
1803
- for (auto it : ctx_map) {
1804
- auto * buft = it.first;
1805
- auto * ctx = it.second;
1806
-
1807
- ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
1808
- if (!buf) {
1809
- throw std::runtime_error("failed to allocate buffer for kv cache");
1810
- }
1811
- ggml_backend_buffer_clear(buf, 0);
1812
- LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
1813
- bufs.emplace_back(buf);
1814
- }
1815
-
1816
- {
1817
- const size_t memory_size_k = size_k_bytes();
1818
- const size_t memory_size_v = size_v_bytes();
1819
-
1820
- LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
1821
- (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
1822
- ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
1823
- ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
1824
- }
1825
- }
1826
-
1827
- void llama_kv_cache_recurrent::clear() {
1828
- for (int32_t i = 0; i < (int32_t) size; ++i) {
1829
- cells[i].pos = -1;
1830
- cells[i].seq_id.clear();
1831
- cells[i].src = -1;
1832
- cells[i].tail = -1;
1833
- }
1834
- head = 0;
1835
- used = 0;
1836
-
1837
- for (auto & buf : bufs) {
1838
- ggml_backend_buffer_clear(buf.get(), 0);
1839
- }
1840
- }
1841
-
1842
- bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
1843
- uint32_t new_head = size;
1844
-
1845
- if (p0 < 0) {
1846
- p0 = 0;
1847
- }
1848
-
1849
- if (p1 < 0) {
1850
- p1 = std::numeric_limits<llama_pos>::max();
1851
- }
1852
-
1853
- // models like Mamba or RWKV can't have a state partially erased
1854
- if (seq_id >= (int64_t) size) {
1855
- // could be fatal
1856
- return false;
1857
- }
1858
- if (0 <= seq_id) {
1859
- int32_t & tail_id = cells[seq_id].tail;
1860
- if (tail_id >= 0) {
1861
- const kv_cell & cell = cells[tail_id];
1862
- // partial intersection is invalid
1863
- if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
1864
- return false;
1865
- }
1866
- // invalidate tails which will be cleared
1867
- if (p0 <= cell.pos && cell.pos < p1) {
1868
- tail_id = -1;
1869
- }
1870
- }
1871
- } else {
1872
- // when seq_id is negative, the range should include everything or nothing
1873
- if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
1874
- return false;
1875
- }
1876
- }
1877
-
1878
- for (uint32_t i = 0; i < size; ++i) {
1879
- if (cells[i].pos >= p0 && cells[i].pos < p1) {
1880
- if (seq_id < 0) {
1881
- cells[i].seq_id.clear();
1882
- } else if (cells[i].has_seq_id(seq_id)) {
1883
- cells[i].seq_id.erase(seq_id);
1884
- } else {
1885
- continue;
1886
- }
1887
- if (cells[i].is_empty()) {
1888
- // keep count of the number of used cells
1889
- if (cells[i].pos >= 0) {
1890
- used--;
1891
- }
1892
- cells[i].pos = -1;
1893
- cells[i].src = -1;
1894
- if (new_head == size) {
1895
- new_head = i;
1896
- }
1897
- }
1898
- }
1899
- }
1900
-
1901
- // If we freed up a slot, set head to it so searching can start there.
1902
- if (new_head != size && new_head < head) {
1903
- head = new_head;
1904
- }
1905
-
1906
- return true;
1907
- }
1908
-
1909
- void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
1910
- if (seq_id_src == seq_id_dst) {
1911
- return;
1912
- }
1913
-
1914
- if (p0 < 0) {
1915
- p0 = 0;
1916
- }
1917
-
1918
- if (p1 < 0) {
1919
- p1 = std::numeric_limits<llama_pos>::max();
1920
- }
1921
-
1922
- if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
1923
- kv_cell & tail_src = cells[seq_id_src];
1924
- kv_cell & tail_dst = cells[seq_id_dst];
1925
- if (tail_dst.tail >= 0) {
1926
- // clear destination seq_id if it wasn't empty
1927
- kv_cell & cell_dst = cells[tail_dst.tail];
1928
-
1929
- cell_dst.seq_id.erase(seq_id_dst);
1930
- tail_dst.tail = -1;
1931
- if (cell_dst.seq_id.empty()) {
1932
- cell_dst.pos = -1;
1933
- cell_dst.src = -1;
1934
- used -= 1;
1935
- }
1936
- }
1937
- if (tail_src.tail >= 0) {
1938
- kv_cell & cell_src = cells[tail_src.tail];
1939
-
1940
- cell_src.seq_id.insert(seq_id_dst);
1941
- tail_dst.tail = tail_src.tail;
1942
- }
1943
- }
1944
- }
1945
-
1946
- void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
1947
- uint32_t new_head = size;
1948
-
1949
- for (uint32_t i = 0; i < size; ++i) {
1950
- if ((llama_seq_id) i != seq_id) {
1951
- cells[i].tail = -1;
1952
- }
1953
-
1954
- if (!cells[i].has_seq_id(seq_id)) {
1955
- if (cells[i].pos >= 0) {
1956
- used--;
1957
- }
1958
-
1959
- cells[i].pos = -1;
1960
- cells[i].src = -1;
1961
- cells[i].seq_id.clear();
1962
-
1963
- if (new_head == size){
1964
- new_head = i;
1965
- }
1966
- } else {
1967
- cells[i].seq_id.clear();
1968
- cells[i].seq_id.insert(seq_id);
1969
- }
1970
- }
1971
-
1972
- // If we freed up a slot, set head to it so searching can start there.
1973
- if (new_head != size && new_head < head) {
1974
- head = new_head;
1975
- }
1976
- }
1977
-
1978
- void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
1979
- if (shift == 0) {
1980
- return;
1981
- }
1982
-
1983
- if (p0 < 0) {
1984
- p0 = 0;
1985
- }
1986
-
1987
- if (p1 < 0) {
1988
- p1 = std::numeric_limits<llama_pos>::max();
1989
- }
1990
-
1991
- // If there is no range then return early to avoid looping over the cache.
1992
- if (p0 == p1) {
1993
- return;
1994
- }
1995
-
1996
- // for Mamba-like or RWKV models, only the pos needs to be shifted
1997
- if (0 <= seq_id && seq_id < (int64_t) size) {
1998
- const int32_t tail_id = cells[seq_id].tail;
1999
- if (tail_id >= 0) {
2000
- kv_cell & cell = cells[tail_id];
2001
- if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
2002
- cell.pos += shift;
2003
- }
2004
- }
2005
- }
2006
- }
2007
-
2008
- void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
2009
- if (d == 1) {
2010
- return;
2011
- }
2012
-
2013
- if (p0 < 0) {
2014
- p0 = 0;
2015
- }
2016
-
2017
- if (p1 < 0) {
2018
- p1 = std::numeric_limits<llama_pos>::max();
2019
- }
2020
-
2021
- // If there is no range then return early to avoid looping over the cache.
2022
- if (p0 == p1) {
2023
- return;
2024
- }
2025
-
2026
- // for Mamba-like or RWKV models, only the pos needs to be changed
2027
- if (0 <= seq_id && seq_id < (int64_t) size) {
2028
- const int32_t tail_id = cells[seq_id].tail;
2029
- if (tail_id >= 0) {
2030
- kv_cell & cell = cells[tail_id];
2031
- if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
2032
- cell.pos /= d;
2033
- }
2034
- }
2035
- }
2036
- }
2037
-
2038
- llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
2039
- llama_pos result = std::numeric_limits<llama_pos>::max();
2040
-
2041
- for (uint32_t i = 0; i < size; ++i) {
2042
- if (cells[i].has_seq_id(seq_id)) {
2043
- result = std::min(result, cells[i].pos);
2044
- }
2045
- }
2046
-
2047
- if (result == std::numeric_limits<llama_pos>::max()) {
2048
- result = -1;
2049
- }
2050
-
2051
- return result;
2052
- }
2053
-
2054
- llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
2055
- llama_pos result = -1;
2056
-
2057
- for (uint32_t i = 0; i < size; ++i) {
2058
- if (cells[i].has_seq_id(seq_id)) {
2059
- result = std::max(result, cells[i].pos);
2060
- }
2061
- }
2062
-
2063
- return result;
2064
- }
2065
-
2066
- void llama_kv_cache_recurrent::restore() {
2067
- if (pending.ranges.empty()) {
2068
- return;
2069
- }
2070
-
2071
- seq_rm(-1, -1, -1);
2072
- }
2073
-
2074
- void llama_kv_cache_recurrent::commit() {
2075
- pending.ranges.clear();
2076
- }
2077
-
2078
- bool llama_kv_cache_recurrent::update(llama_context & ctx) {
2079
- GGML_UNUSED(ctx);
2080
- return false;
2081
- }
2082
-
2083
- void llama_kv_cache_recurrent::defrag_sched(float thold) {
2084
- GGML_UNUSED(thold);
2085
- // noop
2086
- }
2087
-
2088
- void llama_kv_cache_recurrent::set_full() {
2089
- n = size;
2090
- head = 0;
2091
- }
2092
-
2093
- llama_sbatch llama_kv_cache_recurrent::sbatch_init(
2094
- const llama_batch & batch,
2095
- bool logits_all) {
2096
- return llama_sbatch(batch, hparams.n_embd, false, logits_all);
2097
- }
2098
-
2099
- llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
2100
- if (embd_pooled) {
2101
- // Pooled embeddings cannot be split across ubatches (yet)
2102
- return sbatch.split_seq(n_ubatch);
2103
- }
2104
-
2105
- return sbatch.split_equal(n_ubatch);
2106
- }
2107
-
2108
- bool llama_kv_cache_recurrent::find_slot(
2109
- const llama_ubatch & ubatch) {
2110
- const uint32_t n_tokens = ubatch.n_tokens;
2111
- const uint32_t n_seqs = ubatch.n_seqs;
2112
-
2113
- const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
2114
-
2115
- // if we have enough unused cells before the current head ->
2116
- // better to start searching from the beginning of the cache, hoping to fill it
2117
- if (head > used + 2*n_tokens) {
2118
- head = 0;
2119
- }
2120
-
2121
- // For recurrent state architectures (like Mamba or RWKV),
2122
- // each cache cell can store the state for a whole sequence.
2123
- // A slot should be always be contiguous.
2124
-
2125
- // can only process batches with an equal number of new tokens in each sequence
2126
- GGML_ASSERT(ubatch.equal_seqs);
2127
-
2128
- int32_t min = size - 1;
2129
- int32_t max = 0;
2130
-
2131
- // everything should fit if all seq_ids are smaller than the max
2132
- for (uint32_t s = 0; s < n_seqs; ++s) {
2133
- const uint32_t n_seq_id = ubatch.n_seq_id[s];
2134
- for (uint32_t j = 0; j < n_seq_id; ++j) {
2135
- const llama_seq_id seq_id = ubatch.seq_id[s][j];
2136
-
2137
- if (seq_id < 0 || (uint32_t) seq_id >= size) {
2138
- // too big seq_id
2139
- // TODO: would it be possible to resize the cache instead?
2140
- LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
2141
- return false;
2142
- }
2143
- if (j > 0) {
2144
- kv_cell & seq = cells[seq_id];
2145
- if (seq.tail >= 0) {
2146
- kv_cell & cell = cells[seq.tail];
2147
- // clear cells from seq_ids that become shared
2148
- // (should not normally happen, but let's handle it anyway)
2149
- cell.seq_id.erase(seq_id);
2150
- seq.tail = -1;
2151
- if (cell.seq_id.empty()) {
2152
- cell.pos = -1;
2153
- cell.src = -1;
2154
- used -= 1;
2155
- }
2156
- }
2157
- }
2158
- }
2159
- }
2160
-
2161
- #ifndef NDEBUG
2162
- {
2163
- std::vector<int32_t> tails_verif;
2164
- tails_verif.assign(size, -1);
2165
- for (uint32_t i = 0; i < size; ++i) {
2166
- kv_cell & cell = cells[i];
2167
- for (llama_seq_id seq_id : cell.seq_id) {
2168
- if (tails_verif[seq_id] != -1) {
2169
- LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
2170
- }
2171
- tails_verif[seq_id] = i;
2172
- }
2173
- }
2174
- for (uint32_t i = 0; i < size; ++i) {
2175
- if (tails_verif[i] != cells[i].tail) {
2176
- LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
2177
- }
2178
- }
2179
- }
2180
- #endif
2181
-
2182
- // find next empty cell
2183
- uint32_t next_empty_cell = head;
2184
-
2185
- for (uint32_t i = 0; i < size; ++i) {
2186
- if (next_empty_cell >= size) { next_empty_cell -= size; }
2187
- kv_cell & cell = cells[next_empty_cell];
2188
- if (cell.is_empty()) { break; }
2189
- next_empty_cell += 1;
2190
- }
2191
-
2192
- // find usable cell range
2193
- for (uint32_t s = 0; s < n_seqs; ++s) {
2194
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
2195
- kv_cell & seq_meta = cells[seq_id];
2196
- bool has_cell = false;
2197
- if (seq_meta.tail >= 0) {
2198
- kv_cell & cell = cells[seq_meta.tail];
2199
- GGML_ASSERT(cell.has_seq_id(seq_id));
2200
- // does this seq_id "own" the cell?
2201
- if (cell.seq_id.size() == 1) { has_cell = true; }
2202
- }
2203
- if (!has_cell) {
2204
- kv_cell & empty_cell = cells[next_empty_cell];
2205
- GGML_ASSERT(empty_cell.is_empty());
2206
- // copy old tail into the empty cell
2207
- if (seq_meta.tail >= 0) {
2208
- kv_cell & orig_cell = cells[seq_meta.tail];
2209
- empty_cell.pos = orig_cell.pos;
2210
- empty_cell.src = orig_cell.src;
2211
- orig_cell.seq_id.erase(seq_id);
2212
- empty_cell.seq_id.insert(seq_id); // will be overwritten
2213
- }
2214
- seq_meta.tail = next_empty_cell;
2215
- // find next empty cell
2216
- if (s + 1 < n_seqs) {
2217
- next_empty_cell += 1;
2218
- for (uint32_t i = 0; i < size; ++i) {
2219
- if (next_empty_cell >= size) { next_empty_cell -= size; }
2220
- kv_cell & cell = cells[next_empty_cell];
2221
- if (cell.is_empty()) { break; }
2222
- next_empty_cell += 1;
2223
- }
2224
- }
2225
- }
2226
- if (min > seq_meta.tail) { min = seq_meta.tail; }
2227
- if (max < seq_meta.tail) { max = seq_meta.tail; }
2228
- }
2229
-
2230
- // gather and re-order
2231
- for (uint32_t s = 0; s < n_seqs; ++s) {
2232
- int32_t dst_id = s + min;
2233
- int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
2234
- if (dst_id != src_id) {
2235
- kv_cell & dst_cell = cells[dst_id];
2236
- kv_cell & src_cell = cells[src_id];
2237
-
2238
- std::swap(dst_cell.pos, src_cell.pos);
2239
- std::swap(dst_cell.src, src_cell.src);
2240
- std::swap(dst_cell.seq_id, src_cell.seq_id);
2241
-
2242
- // swap tails (assuming they NEVER overlap)
2243
- for (const llama_seq_id seq_id : src_cell.seq_id) {
2244
- cells[seq_id].tail = src_id;
2245
- }
2246
- for (const llama_seq_id seq_id : dst_cell.seq_id) {
2247
- cells[seq_id].tail = dst_id;
2248
- }
2249
- }
2250
- }
2251
-
2252
- // update the pos of the used seqs
2253
- for (uint32_t s = 0; s < n_seqs; ++s) {
2254
- const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
2255
- int32_t cell_id = s + min;
2256
- kv_cell & cell = cells[cell_id];
2257
-
2258
- if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
2259
- // What should happen when the pos backtracks or skips a value?
2260
- // Clearing the state mid-batch would require special-casing which isn't done.
2261
- LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
2262
- __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
2263
- }
2264
- cell.pos = last_pos;
2265
- cell.seq_id.clear();
2266
- for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
2267
- const llama_seq_id seq_id = ubatch.seq_id[s][j];
2268
- cell.seq_id.insert(seq_id);
2269
- cells[seq_id].tail = cell_id;
2270
- }
2271
- }
2272
-
2273
- // allow getting the range of used cells, from head to head + n
2274
- head = min;
2275
- n = max - min + 1;
2276
- used = std::count_if(cells.begin(), cells.end(),
2277
- [](const kv_cell & cell){ return !cell.is_empty(); });
2278
-
2279
- // sanity check
2280
- return n >= n_seqs;
2281
- }
2282
-
2283
- bool llama_kv_cache_recurrent::get_can_shift() const {
2284
- return false;
2285
- }
2286
-
2287
- int32_t llama_kv_cache_recurrent::s_copy(int i) const {
2288
- const uint32_t cell_id = i + head;
2289
-
2290
- //////////////////////////////////////////////
2291
- // TODO: this should not mutate the KV cache !
2292
- kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
2293
-
2294
- // prevent out-of-bound sources
2295
- if (cell.src < 0 || (uint32_t) cell.src >= size) {
2296
- cell.src = cell_id;
2297
- }
2298
-
2299
- int32_t res = cell.src;
2300
-
2301
- // TODO: do not mutate the KV cache
2302
- // ensure copy only happens once
2303
- if (cell.src != (int32_t) cell_id) {
2304
- cell.src = cell_id;
2305
- }
2306
-
2307
- return res;
2308
- }
2309
-
2310
- float llama_kv_cache_recurrent::s_mask(int i) const {
2311
- const uint32_t cell_id = i + head;
2312
-
2313
- //////////////////////////////////////////////
2314
- // TODO: this should not mutate the KV cache !
2315
- kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
2316
-
2317
- float res = (float) (cell.src >= 0);
2318
-
2319
- // only clear once
2320
- if (cell.src < 0) {
2321
- cell.src = cell_id;
2322
- }
2323
-
2324
- return res;
2325
- }
2326
-
2327
- uint32_t llama_kv_cache_recurrent::cell_max() const {
2328
- for (uint32_t i = size; i > 0; --i) {
2329
- const kv_cell & cell = cells[i - 1];
2330
-
2331
- if (cell.pos >= 0 && !cell.is_empty()) {
2332
- return i;
2333
- }
2334
- }
2335
-
2336
- return 0;
2337
- }
2338
-
2339
- size_t llama_kv_cache_recurrent::total_size() const {
2340
- size_t size = 0;
2341
- for (const auto & buf : bufs) {
2342
- size += ggml_backend_buffer_get_size(buf.get());
2343
- }
2344
-
2345
- return size;
2346
- }
2347
-
2348
- size_t llama_kv_cache_recurrent::size_k_bytes() const {
2349
- size_t size_k_bytes = 0;
2350
-
2351
- for (const auto & k : k_l) {
2352
- size_k_bytes += ggml_nbytes(k);
2353
- }
2354
-
2355
- return size_k_bytes;
2356
- }
2357
-
2358
- size_t llama_kv_cache_recurrent::size_v_bytes() const {
2359
- size_t size_v_bytes = 0;
2360
-
2361
- for (const auto & v : v_l) {
2362
- size_v_bytes += ggml_nbytes(v);
2363
- }
2364
-
2365
- return size_v_bytes;
2366
- }
2367
-
2368
- void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
2369
- std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
2370
- uint32_t cell_count = 0;
2371
-
2372
- // Count the number of cells with the specified seq_id
2373
- // Find all the ranges of cells with this seq id (or all, when -1)
2374
- uint32_t cell_range_begin = size;
2375
- for (uint32_t i = 0; i < size; ++i) {
2376
- const auto & cell = cells[i];
2377
- if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
2378
- ++cell_count;
2379
- if (cell_range_begin == size) {
2380
- cell_range_begin = i;
2381
- }
2382
- } else {
2383
- if (cell_range_begin != size) {
2384
- cell_ranges.emplace_back(cell_range_begin, i);
2385
- cell_range_begin = size;
2386
- }
2387
- }
2388
- }
2389
- if (cell_range_begin != size) {
2390
- cell_ranges.emplace_back(cell_range_begin, size);
2391
- }
2392
-
2393
- // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
2394
- uint32_t cell_count_check = 0;
2395
- for (const auto & range : cell_ranges) {
2396
- cell_count_check += range.second - range.first;
2397
- }
2398
- GGML_ASSERT(cell_count == cell_count_check);
2399
-
2400
- io.write(&cell_count, sizeof(cell_count));
2401
-
2402
- state_write_meta(io, cell_ranges, seq_id);
2403
- state_write_data(io, cell_ranges);
2404
- }
2405
-
2406
- void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
2407
- uint32_t cell_count;
2408
- io.read_to(&cell_count, sizeof(cell_count));
2409
-
2410
- bool res = true;
2411
-
2412
- res = res && state_read_meta(io, cell_count, seq_id);
2413
- res = res && state_read_data(io, cell_count);
2414
-
2415
- if (!res) {
2416
- if (seq_id == -1) {
2417
- clear();
2418
- } else {
2419
- seq_rm(seq_id, -1, -1);
2420
- }
2421
- throw std::runtime_error("failed to restore kv cache");
2422
- }
2423
- }
2424
-
2425
- void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
2426
- for (const auto & range : cell_ranges) {
2427
- for (uint32_t i = range.first; i < range.second; ++i) {
2428
- const auto & cell = cells[i];
2429
- const llama_pos pos = cell.pos;
2430
- const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
2431
-
2432
- io.write(&pos, sizeof(pos));
2433
- io.write(&n_seq_id, sizeof(n_seq_id));
2434
-
2435
- if (n_seq_id) {
2436
- for (auto seq_id : cell.seq_id) {
2437
- io.write(&seq_id, sizeof(seq_id));
2438
- }
2439
- }
2440
- }
2441
- }
2442
- }
2443
-
2444
- void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
2445
- const uint32_t v_trans = 0;
2446
- const uint32_t n_layer = hparams.n_layer;
2447
-
2448
- io.write(&v_trans, sizeof(v_trans));
2449
- io.write(&n_layer, sizeof(n_layer));
2450
-
2451
- std::vector<uint8_t> tmp_buf;
2452
-
2453
- // Iterate and write all the keys first, each row is a cell
2454
- // Get whole range at a time
2455
- for (uint32_t il = 0; il < n_layer; ++il) {
2456
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
2457
-
2458
- // Write key type
2459
- const int32_t k_type_i = (int32_t)k_l[il]->type;
2460
- io.write(&k_type_i, sizeof(k_type_i));
2461
-
2462
- // Write row size of key
2463
- const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
2464
- io.write(&k_size_row, sizeof(k_size_row));
2465
-
2466
- // Read each range of cells of k_size length each into tmp_buf and write out
2467
- for (const auto & range : cell_ranges) {
2468
- const size_t range_size = range.second - range.first;
2469
- const size_t buf_size = range_size * k_size_row;
2470
- io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
2471
- }
2472
- }
2473
-
2474
- if (!v_trans) {
2475
- for (uint32_t il = 0; il < n_layer; ++il) {
2476
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
2477
-
2478
- // Write value type
2479
- const int32_t v_type_i = (int32_t)v_l[il]->type;
2480
- io.write(&v_type_i, sizeof(v_type_i));
2481
-
2482
- // Write row size of value
2483
- const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
2484
- io.write(&v_size_row, sizeof(v_size_row));
2485
-
2486
- // Read each range of cells of v_size length each into tmp_buf and write out
2487
- for (const auto & range : cell_ranges) {
2488
- const size_t range_size = range.second - range.first;
2489
- const size_t buf_size = range_size * v_size_row;
2490
- io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
2491
- }
2492
- }
2493
- } else {
2494
- // When v is transposed, we also need the element size and get the element ranges from each row
2495
- const uint32_t kv_size = size;
2496
- for (uint32_t il = 0; il < n_layer; ++il) {
2497
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
2498
-
2499
- // Write value type
2500
- const int32_t v_type_i = (int32_t)v_l[il]->type;
2501
- io.write(&v_type_i, sizeof(v_type_i));
2502
-
2503
- // Write element size
2504
- const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
2505
- io.write(&v_size_el, sizeof(v_size_el));
2506
-
2507
- // Write GQA embedding size
2508
- io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
2509
-
2510
- // For each row, we get the element values of each cell
2511
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
2512
- // Read each range of cells of v_size_el length each into tmp_buf and write out
2513
- for (const auto & range : cell_ranges) {
2514
- const size_t range_size = range.second - range.first;
2515
- const size_t src_offset = (range.first + j * kv_size) * v_size_el;
2516
- const size_t buf_size = range_size * v_size_el;
2517
- io.write_tensor(v_l[il], src_offset, buf_size);
2518
- }
2519
- }
2520
- }
2521
- }
2522
- }
2523
-
2524
- bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
2525
- if (dest_seq_id != -1) {
2526
- // single sequence
2527
-
2528
- seq_rm(dest_seq_id, -1, -1);
2529
-
2530
- llama_sbatch sbatch;
2531
- llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
2532
-
2533
- batch.n_tokens = cell_count;
2534
- batch.n_seq_tokens = cell_count;
2535
- batch.n_seqs = 1;
2536
-
2537
- for (uint32_t i = 0; i < cell_count; ++i) {
2538
- llama_pos pos;
2539
- uint32_t n_seq_id;
2540
-
2541
- io.read_to(&pos, sizeof(pos));
2542
- io.read_to(&n_seq_id, sizeof(n_seq_id));
2543
-
2544
- if (n_seq_id != 0) {
2545
- LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
2546
- return false;
2547
- }
2548
-
2549
- batch.pos[i] = pos;
2550
- }
2551
- batch.n_seq_id[0] = 1;
2552
- batch.seq_id[0] = &dest_seq_id;
2553
- if (!find_slot(batch)) {
2554
- LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
2555
- return false;
2556
- }
2557
- commit();
2558
-
2559
- // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
2560
- // Assume that this is one contiguous block of cells
2561
- GGML_ASSERT(head + cell_count <= size);
2562
- GGML_ASSERT(cells[head].pos == batch.pos[0]);
2563
- GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
2564
- GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
2565
- GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
2566
- } else {
2567
- // whole KV cache restore
2568
-
2569
- if (cell_count > size) {
2570
- LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
2571
- return false;
2572
- }
2573
-
2574
- clear();
2575
-
2576
- for (uint32_t i = 0; i < cell_count; ++i) {
2577
- kv_cell & cell = cells[i];
2578
-
2579
- llama_pos pos;
2580
- uint32_t n_seq_id;
2581
-
2582
- io.read_to(&pos, sizeof(pos));
2583
- io.read_to(&n_seq_id, sizeof(n_seq_id));
2584
-
2585
- cell.pos = pos;
2586
-
2587
- for (uint32_t j = 0; j < n_seq_id; ++j) {
2588
- llama_seq_id seq_id;
2589
- io.read_to(&seq_id, sizeof(seq_id));
2590
-
2591
- // TODO: llama_kv_cache_recurrent should have a notion of max sequences
2592
- //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
2593
- if (seq_id < 0) {
2594
- //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
2595
- LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
2596
- return false;
2597
- }
2598
-
2599
- cell.seq_id.insert(seq_id);
2600
-
2601
- int32_t & tail = cells[seq_id].tail;
2602
- if (tail != -1) {
2603
- LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
2604
- return false;
2605
- }
2606
- tail = i;
2607
- }
2608
- }
2609
-
2610
- head = 0;
2611
- used = cell_count;
2612
- }
2613
-
2614
- for (uint32_t i = 0; i < cell_count; ++i) {
2615
- uint32_t cell_id = head + i;
2616
- // make sure the recurrent states will keep their restored state
2617
- cells[cell_id].src = cell_id;
2618
- }
2619
-
2620
- return true;
2621
- }
2622
-
2623
- bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
2624
- uint32_t v_trans;
2625
- uint32_t n_layer;
2626
- io.read_to(&v_trans, sizeof(v_trans));
2627
- io.read_to(&n_layer, sizeof(n_layer));
2628
-
2629
- if (n_layer != hparams.n_layer) {
2630
- LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
2631
- return false;
2632
- }
2633
- if (cell_count > size) {
2634
- LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
2635
- return false;
2636
- }
2637
- if (false != (bool) v_trans) {
2638
- LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
2639
- return false;
2640
- }
2641
-
2642
- // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
2643
- for (uint32_t il = 0; il < n_layer; ++il) {
2644
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
2645
-
2646
- // Read type of key
2647
- int32_t k_type_i_ref;
2648
- io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
2649
- const int32_t k_type_i = (int32_t) k_l[il]->type;
2650
- if (k_type_i != k_type_i_ref) {
2651
- LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
2652
- return false;
2653
- }
2654
-
2655
- // Read row size of key
2656
- uint64_t k_size_row_ref;
2657
- io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
2658
- const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
2659
- if (k_size_row != k_size_row_ref) {
2660
- LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
2661
- return false;
2662
- }
2663
-
2664
- if (cell_count) {
2665
- // Read and set the keys for the whole cell range
2666
- ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
2667
- }
2668
- }
2669
-
2670
- if (!v_trans) {
2671
- for (uint32_t il = 0; il < n_layer; ++il) {
2672
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
2673
-
2674
- // Read type of value
2675
- int32_t v_type_i_ref;
2676
- io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
2677
- const int32_t v_type_i = (int32_t)v_l[il]->type;
2678
- if (v_type_i != v_type_i_ref) {
2679
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
2680
- return false;
2681
- }
2682
-
2683
- // Read row size of value
2684
- uint64_t v_size_row_ref;
2685
- io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
2686
- const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
2687
- if (v_size_row != v_size_row_ref) {
2688
- LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
2689
- return false;
2690
- }
2691
-
2692
- if (cell_count) {
2693
- // Read and set the values for the whole cell range
2694
- ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
2695
- }
2696
- }
2697
- } else {
2698
- // For each layer, read the values for each cell (transposed)
2699
- for (uint32_t il = 0; il < n_layer; ++il) {
2700
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
2701
-
2702
- // Read type of value
2703
- int32_t v_type_i_ref;
2704
- io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
2705
- const int32_t v_type_i = (int32_t)v_l[il]->type;
2706
- if (v_type_i != v_type_i_ref) {
2707
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
2708
- return false;
2709
- }
2710
-
2711
- // Read element size of value
2712
- uint32_t v_size_el_ref;
2713
- io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
2714
- const size_t v_size_el = ggml_type_size(v_l[il]->type);
2715
- if (v_size_el != v_size_el_ref) {
2716
- LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
2717
- return false;
2718
- }
2719
-
2720
- // Read GQA embedding size
2721
- uint32_t n_embd_v_gqa_ref;
2722
- io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
2723
- if (n_embd_v_gqa != n_embd_v_gqa_ref) {
2724
- LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
2725
- return false;
2726
- }
2727
-
2728
- if (cell_count) {
2729
- // For each row in the transposed matrix, read the values for the whole cell range
2730
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
2731
- const size_t dst_offset = (head + j * size) * v_size_el;
2732
- ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
2733
- }
2734
- }
2735
- }
2736
- }
2737
-
2738
- return true;
2739
- }
 
1
  #include "llama-kv-cache.h"
 
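The removed llama_kv_cache_unified::is_masked_swa above encodes the two sliding-window attention variants: a standard window masks a key at position p0 for a query at position p1 once p1 - p0 >= n_swa, while a chunked window masks every key before the start of the chunk that contains p1. Below is a minimal standalone sketch of just that check, with a simplified enum standing in for the llama.cpp types (the names here are illustrative, not the real API):

    #include <cstdint>
    #include <cstdio>

    // illustrative stand-in for llama.cpp's swa_type values
    enum swa_type { SWA_NONE, SWA_STANDARD, SWA_CHUNKED };

    // return true if the key at position p0 is masked out for the query at position p1,
    // mirroring the branch structure of the removed is_masked_swa above
    static bool is_masked_swa(swa_type type, int32_t n_swa, int32_t p0, int32_t p1) {
        switch (type) {
            case SWA_NONE:
                return false;
            case SWA_STANDARD:
                // keys older than the window size are masked
                return p1 - p0 >= n_swa;
            case SWA_CHUNKED: {
                // keys before the start of the chunk containing p1 are masked
                const int32_t chunk_start = (p1 / n_swa) * n_swa;
                return p0 < chunk_start;
            }
        }
        return false;
    }

    int main() {
        std::printf("standard, p0=4 p1=7: %d\n", is_masked_swa(SWA_STANDARD, 4, 4, 7)); // 0 (visible)
        std::printf("standard, p0=3 p1=7: %d\n", is_masked_swa(SWA_STANDARD, 4, 3, 7)); // 1 (masked)
        std::printf("chunked,  p0=3 p1=7: %d\n", is_masked_swa(SWA_CHUNKED,  4, 3, 7)); // 1 (masked)
        return 0;
    }

With n_swa = 4, a query at position 7 under the standard window still attends to position 4 (distance 3), while under the chunked window everything before position 4, the start of its chunk, is masked.
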
examples/talk-llama/llama-kv-cache.h CHANGED
@@ -2,59 +2,33 @@
2
 
3
  #include "llama.h"
4
  #include "llama-io.h"
5
- #include "llama-graph.h"
6
  #include "llama-memory.h"
7
- #include "llama-kv-cells.h"
8
-
9
- #include "ggml-cpp.h"
10
-
11
- #include <set>
12
- #include <unordered_map>
13
- #include <vector>
14
-
15
- struct llama_cparams;
16
- struct llama_hparams;
17
- struct llama_ubatch;
18
- struct llama_sbatch;
19
- struct llama_model;
20
- struct llama_context;
21
 
22
  struct llama_kv_cache : public llama_memory_i {
23
  virtual ~llama_kv_cache() = default;
24
 
25
- // call if batch processing fails - restores the cache state
26
- virtual void restore() = 0;
 
 
 
 
 
 
27
 
28
- // call after successful batch processing - clears any pending state
29
- virtual void commit() = 0;
30
 
31
  // process any pending defrag/shift/etc. operations
32
  // optionally call once before processing a new batch
 
33
  virtual bool update(llama_context & lctx) = 0;
34
 
35
  // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
36
- virtual void defrag_sched(float thold) = 0;
37
-
38
- // simulate full cache, used for allocating worst-case compute buffers
39
- // TODO: remove
40
- virtual void set_full() = 0;
41
-
42
  //
43
- // batch processing
44
- //
45
-
46
- // =============================================================================================================
47
- // TODO: refactor and simplify this [TAG: KV_API]
48
-
49
- virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
50
-
51
- // different KV caches require different batch splitting strategies
52
- virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
53
-
54
- // find an empty slot of size "n_tokens" in the cache
55
- virtual bool find_slot(const llama_ubatch & batch) = 0;
56
-
57
- // =============================================================================================================
58
 
59
  // getters
60
  virtual bool get_can_shift() const = 0;
@@ -68,435 +42,3 @@ struct llama_kv_cache : public llama_memory_i {
68
  virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
69
  virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
70
  };
71
-
72
- //
73
- // llama_kv_cache_guard
74
- //
75
-
76
- struct llama_kv_cache_guard {
77
- llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
78
-
79
- ~llama_kv_cache_guard() {
80
- kv->restore();
81
- }
82
-
83
- void commit() {
84
- kv->commit();
85
- }
86
-
87
- private:
88
- llama_kv_cache * kv;
89
- };
90
-
91
- //
92
- // llama_kv_cache_unified
93
- //
94
-
95
- class llama_kv_cache_unified : public llama_kv_cache {
96
- public:
97
- static uint32_t get_padding(const llama_cparams & cparams);
98
-
99
- // this callback is used to filter out layers that should not be included in the cache
100
- using layer_filter_cb = std::function<bool(int32_t il)>;
101
-
102
- llama_kv_cache_unified(
103
- const llama_model & model,
104
- layer_filter_cb && filter,
105
- ggml_type type_k,
106
- ggml_type type_v,
107
- bool v_trans,
108
- bool offload,
109
- uint32_t kv_size,
110
- uint32_t n_seq_max,
111
- uint32_t n_pad,
112
- uint32_t n_swa,
113
- llama_swa_type swa_type);
114
-
115
- ~llama_kv_cache_unified() = default;
116
-
117
- //
118
- // llama_memory_i
119
- //
120
-
121
- void clear() override;
122
-
123
- bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
124
- void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
125
- void seq_keep(llama_seq_id seq_id) override;
126
- void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
127
- void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
128
-
129
- llama_pos seq_pos_min(llama_seq_id seq_id) const override;
130
- llama_pos seq_pos_max(llama_seq_id seq_id) const override;
131
-
132
- //
133
- // llama_kv_cache
134
- //
135
-
136
- void restore() override;
137
- void commit() override;
138
-
139
- bool update(llama_context & ctx) override;
140
-
141
- void defrag_sched(float thold) override;
142
-
143
- void set_full() override;
144
-
145
- llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
146
- llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
147
-
148
- // updates the cache head
149
- // Note: On success, it's important that cache.head points
150
- // to the first cell of the slot.
151
- bool find_slot(const llama_ubatch & batch) override;
152
-
153
- bool get_can_shift() const override;
154
-
155
- // state write/load
156
-
157
- void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
158
- void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
159
-
160
- //
161
- // llama_kv_cache_unified specific API
162
- //
163
-
164
- uint32_t get_n() const;
165
- uint32_t get_size() const;
166
-
167
- // get views of the current state of the cache
168
- ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
169
- ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
170
-
171
- // store k_cur and v_cur in the cache based on the current head location
172
- ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
173
- ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
174
-
175
- void prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax);
176
-
177
- void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
178
- void set_input_k_shift (ggml_tensor * dst) const;
179
- void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
180
-
181
- private:
182
- const llama_model & model;
183
- const llama_hparams & hparams;
184
-
185
- struct kv_layer {
186
- // layer index in the model
187
- // note: can be different from the layer index in the KV cache
188
- uint32_t il;
189
-
190
- ggml_tensor * k;
191
- ggml_tensor * v;
192
- };
193
-
194
- bool do_defrag = false;
195
- bool v_trans = true; // the value tensor is transposed
196
-
197
- uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
198
-
199
- // computed before each graph build
200
- // TODO: cells should start to maintain this value dynamically based on the edits
201
- uint32_t n = 0;
202
-
203
- const uint32_t n_seq_max = 1;
204
-
205
- // required padding
206
- const uint32_t n_pad = 1;
207
-
208
- // SWA
209
- const uint32_t n_swa = 0;
210
-
211
- const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
212
-
213
- std::vector<ggml_context_ptr> ctxs;
214
- std::vector<ggml_backend_buffer_ptr> bufs;
215
-
216
- llama_kv_cells_unified cells;
217
-
218
- std::vector<kv_layer> layers;
219
-
220
- // model layer id -> KV cache layer id
221
- std::unordered_map<int32_t, int32_t> map_layer_ids;
222
-
223
- // recovery information used to restore the KV cells to their original state in case of a failure
224
- // TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation
225
- // to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API]
226
- struct {
227
- void clear() {
228
- states.clear();
229
- }
230
-
231
- struct state {
232
- uint32_t i;
233
-
234
- llama_kv_cells_unified cells;
235
- };
236
-
237
- // stack with the partial states before each ubatch
238
- std::vector<state> states;
239
- } recovery;
240
-
241
- // defrag
242
- struct {
243
- std::vector<uint32_t> ids;
244
- } defrag_info;
245
-
246
- // return true if cells have been moved
247
- bool defrag_prepare(int32_t n_max_nodes);
248
-
249
- size_t total_size() const;
250
-
251
- size_t size_k_bytes() const;
252
- size_t size_v_bytes() const;
253
-
254
- bool is_masked_swa(llama_pos p0, llama_pos p1) const;
255
-
256
- ggml_tensor * build_rope_shift(
257
- const llama_cparams & cparams,
258
- ggml_context * ctx,
259
- ggml_tensor * cur,
260
- ggml_tensor * shift,
261
- ggml_tensor * factors,
262
- float freq_base,
263
- float freq_scale) const;
264
-
265
- llm_graph_result_ptr build_graph_shift(
266
- const llama_cparams & cparams,
267
- ggml_context * ctx,
268
- ggml_cgraph * gf) const;
269
-
270
- llm_graph_result_ptr build_graph_defrag(
271
- const llama_cparams & cparams,
272
- ggml_context * ctx,
273
- ggml_cgraph * gf) const;
274
-
275
- void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
276
- void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
277
-
278
- bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
279
- bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
280
- };
281
-
282
- //
283
- // llama_kv_cache_unified_iswa
284
- //
285
-
286
- // utilizes two instances of llama_kv_cache_unified
287
- // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
288
- // upon successful commit, the SWA cache removes old tokens outside the n_swa window
289
-
290
- class llama_kv_cache_unified_iswa : public llama_kv_cache {
291
- public:
292
- llama_kv_cache_unified_iswa(
293
- const llama_model & model,
294
- ggml_type type_k,
295
- ggml_type type_v,
296
- bool v_trans,
297
- bool offload,
298
- bool swa_full,
299
- uint32_t kv_size,
300
- uint32_t n_seq_max,
301
- uint32_t n_batch,
302
- uint32_t n_pad);
303
-
304
- ~llama_kv_cache_unified_iswa() = default;
305
-
306
- //
307
- // llama_memory_i
308
- //
309
-
310
- void clear() override;
311
-
312
- bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
313
- void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
314
- void seq_keep(llama_seq_id seq_id) override;
315
- void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
316
- void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
317
-
318
- llama_pos seq_pos_min(llama_seq_id seq_id) const override;
319
- llama_pos seq_pos_max(llama_seq_id seq_id) const override;
320
-
321
- //
322
- // llama_kv_cache
323
- //
324
-
325
- void restore() override;
326
- void commit() override;
327
-
328
- bool update(llama_context & ctx) override;
329
-
330
- void defrag_sched(float thold) override;
331
-
332
- void set_full() override;
333
-
334
- llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
335
- llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
336
-
337
- bool find_slot(const llama_ubatch & batch) override;
338
-
339
- bool get_can_shift() const override;
340
-
341
- // state write/load
342
-
343
- void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
344
- void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
345
-
346
- //
347
- // llama_kv_cache_unified_iswa specific API
348
- //
349
-
350
- llama_kv_cache_unified * get_kv_base() const;
351
- llama_kv_cache_unified * get_kv_swa () const;
352
-
353
- private:
354
- const llama_hparams & hparams;
355
-
356
- bool do_prune = true;
357
-
358
- struct {
359
- struct entry {
360
- llama_pos pmin;
361
- llama_pos pmax;
362
- };
363
-
364
- void clear() {
365
- pos.clear();
366
- }
367
-
368
- // used to perform SWA pruning of old tokens
369
- std::unordered_map<llama_seq_id, entry> pos;
370
- } pending;
371
-
372
- std::unique_ptr<llama_kv_cache_unified> kv_base;
373
- std::unique_ptr<llama_kv_cache_unified> kv_swa;
374
- };
375
-
376
- //
377
- // llama_kv_cache_recurrent
378
- //
379
-
380
- class llama_kv_cache_recurrent : public llama_kv_cache {
381
- public:
382
- struct kv_cell {
383
- llama_pos pos = -1;
384
- int32_t src = -1; // used to copy states
385
- int32_t tail = -1;
386
-
387
- std::set<llama_seq_id> seq_id;
388
-
389
- bool has_seq_id(const llama_seq_id & id) const {
390
- return seq_id.find(id) != seq_id.end();
391
- }
392
-
393
- bool is_empty() const {
394
- return seq_id.empty();
395
- }
396
-
397
- bool is_same_seq(const kv_cell & other) const {
398
- return seq_id == other.seq_id;
399
- }
400
- };
401
-
402
- llama_kv_cache_recurrent(
403
- const llama_model & model,
404
- ggml_type type_k,
405
- ggml_type type_v,
406
- bool offload,
407
- uint32_t kv_size,
408
- uint32_t n_seq_max);
409
-
410
- ~llama_kv_cache_recurrent() = default;
411
-
412
- //
413
- // llama_memory_i
414
- //
415
-
416
- void clear() override;
417
-
418
- bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
419
- void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
420
- void seq_keep(llama_seq_id seq_id) override;
421
- void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
422
- void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
423
-
424
- llama_pos seq_pos_min(llama_seq_id seq_id) const override;
425
- llama_pos seq_pos_max(llama_seq_id seq_id) const override;
426
-
427
- //
428
- // llama_kv_cache
429
- //
430
-
431
- void restore() override;
432
- void commit() override;
433
-
434
- bool update(llama_context & ctx) override;
435
-
436
- void defrag_sched(float thold) override;
437
-
438
- void set_full() override;
439
-
440
- llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
441
- llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
442
-
443
- bool find_slot(const llama_ubatch & batch) override;
444
-
445
- bool get_can_shift() const override;
446
-
447
- // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
448
- int32_t s_copy(int i) const;
449
- float s_mask(int i) const;
450
-
451
- // state write/load
452
-
453
- void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
454
- void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
455
-
456
- uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
457
- uint32_t size = 0; // total number of cells, shared across all sequences
458
- uint32_t used = 0; // used cells (i.e. at least one seq_id)
459
-
460
- // computed before each graph build
461
- uint32_t n = 0;
462
-
463
- std::vector<kv_cell> cells;
464
-
465
- std::vector<ggml_tensor *> k_l; // per layer
466
- std::vector<ggml_tensor *> v_l;
467
-
468
- private:
469
- //const llama_model & model;
470
- const llama_hparams & hparams;
471
-
472
- // commit/restore cache
473
- // TODO: rework for recurrent cache
474
- struct slot_range {
475
- uint32_t c0 = 0; // note: these are cell indices, not sequence positions
476
- uint32_t c1 = 0;
477
- };
478
-
479
- // pending cell updates that are not yet committed
480
- struct {
481
- std::vector<slot_range> ranges;
482
- } pending;
483
-
484
- const uint32_t n_seq_max = 1;
485
-
486
- std::vector<ggml_context_ptr> ctxs;
487
- std::vector<ggml_backend_buffer_ptr> bufs;
488
-
489
- // find how many cells are currently in use
490
- uint32_t cell_max() const;
491
-
492
- size_t total_size() const;
493
-
494
- size_t size_k_bytes() const;
495
- size_t size_v_bytes() const;
496
-
497
- void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
498
- void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
499
-
500
- bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
501
- bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
502
- };
 
2
 
3
  #include "llama.h"
4
  #include "llama-io.h"
 
5
  #include "llama-memory.h"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  struct llama_kv_cache : public llama_memory_i {
8
  virtual ~llama_kv_cache() = default;
9
 
10
+ // split the input batch into a set of ubatches and verify that they can fit into the cache
11
+ // return a state object containing the ubatches and KV cache state required to process them
12
+ // check the llama_memory_state_i::get_status() for the result
13
+ virtual llama_memory_state_ptr init_batch(
14
+ const llama_batch & batch,
15
+ uint32_t n_ubatch,
16
+ bool embd_pooled,
17
+ bool logits_all) = 0;
18
 
19
+ // simulate full cache, used for allocating worst-case compute buffers
20
+ virtual llama_memory_state_ptr init_full() = 0;
21
 
22
  // process any pending defrag/shift/etc. operations
23
  // optionally call once before processing a new batch
24
+ // return true if any operations were performed
25
  virtual bool update(llama_context & lctx) = 0;
26
 
27
  // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
28
+ // TODO: change to
29
+ // llama_memory_state_ptr init_defrag(float thold) = 0;
 
 
 
 
30
  //
31
+ virtual void defrag_sched(float thold) = 0;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  // getters
34
  virtual bool get_can_shift() const = 0;
 
42
  virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
43
  virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
44
  };
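The interface above replaces the old restore()/commit() and sbatch_init()/ubatch_next()/find_slot() entry points with a state object returned by init_batch()/init_full(). A minimal sketch of the intended call flow, assuming the internal headers are on the include path; process_ubatch() is a hypothetical placeholder for the graph build and compute step, and error handling is reduced to the bare minimum.

    // sketch only: drive the state-based batch API of llama_kv_cache
    #include "llama-kv-cache.h"

    static bool decode_with_cache(llama_kv_cache & kv, const llama_batch & batch,
                                  uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
        llama_memory_state_ptr state = kv.init_batch(batch, n_ubatch, embd_pooled, logits_all);
        if (!state || state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
            return false; // the ubatches do not fit into the cache
        }

        do {
            if (!state->apply()) {             // write the cache state for this ubatch
                return false;
            }
            const llama_ubatch & ub = state->get_ubatch();
            (void) ub;                         // process_ubatch(ub): hypothetical build + compute
        } while (state->next());               // advance until all ubatches are consumed

        return true;
    }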
 
examples/talk-llama/llama-kv-cells.h CHANGED
@@ -68,12 +68,6 @@ public:
68
  // the index of the last cell that is used + 1
69
  // return 0 if no cells are used
70
  uint32_t used_max_p1() const {
71
- #if 0
72
- if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin());
73
- if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin());
74
- if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin());
75
- #endif
76
-
77
  return used.empty() ? 0 : *used.rbegin() + 1;
78
  }
79
 
@@ -144,6 +138,19 @@ public:
144
  }
145
  }
146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  // note: call only if the cell has seq_id
148
  // return true if the cell becomes empty
149
  bool seq_rm(uint32_t i, llama_seq_id seq_id) {
@@ -196,6 +203,15 @@ public:
196
  return false;
197
  }
198
 
 
 
 
 
 
 
 
 
 
199
  bool seq_has(uint32_t i, llama_seq_id seq_id) const {
200
  assert(i < pos.size());
201
  assert(seq_id >= 0);
@@ -213,6 +229,20 @@ public:
213
  seq_pos[seq_id].insert(pos[i]);
214
  }
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  // the minimum position of sequence seq_id currently present in any of the cells
217
  // return -1 if the sequence is not present
218
  llama_pos seq_pos_min(llama_seq_id seq_id) const {
@@ -268,6 +298,7 @@ public:
268
  void pos_set(uint32_t i, llama_pos p) {
269
  assert(i < pos.size());
270
  assert(pos[i] == -1);
 
271
 
272
  pos[i] = p;
273
 
 
68
  // the index of the last cell that is used + 1
69
  // return 0 if no cells are used
70
  uint32_t used_max_p1() const {
 
 
 
 
 
 
71
  return used.empty() ? 0 : *used.rbegin() + 1;
72
  }
73
 
 
138
  }
139
  }
140
 
141
+ // clear a non-empty cell
142
+ void rm(uint32_t i) {
143
+ assert(i < pos.size());
144
+ assert(pos[i] != -1);
145
+
146
+ seq_pos_rm(i);
147
+
148
+ pos[i] = -1;
149
+ seq[i].reset();
150
+
151
+ used.erase(i);
152
+ }
153
+
154
  // note: call only if the cell has seq_id
155
  // return true if the cell becomes empty
156
  bool seq_rm(uint32_t i, llama_seq_id seq_id) {
 
203
  return false;
204
  }
205
 
206
+ // number of different sequences in the cell
207
+ int seq_count(uint32_t i) const {
208
+ assert(i < pos.size());
209
+ assert(pos[i] != -1);
210
+
211
+ return seq[i].count();
212
+ }
213
+
214
+ // check if the cell contains seq_id
215
  bool seq_has(uint32_t i, llama_seq_id seq_id) const {
216
  assert(i < pos.size());
217
  assert(seq_id >= 0);
 
229
  seq_pos[seq_id].insert(pos[i]);
230
  }
231
 
232
+ // return the sequence id of this cell
233
+ // note: call only for cells with exactly one sequence
234
+ llama_seq_id seq_get(uint32_t i) const {
235
+ assert(seq[i].count() == 1);
236
+
237
+ for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
238
+ if (seq[i].test(s)) {
239
+ return s;
240
+ }
241
+ }
242
+
243
+ return -1;
244
+ }
245
+
246
  // the minimum position of sequence seq_id currently present in any of the cells
247
  // return -1 if the sequence is not present
248
  llama_pos seq_pos_min(llama_seq_id seq_id) const {
 
298
  void pos_set(uint32_t i, llama_pos p) {
299
  assert(i < pos.size());
300
  assert(pos[i] == -1);
301
+ assert(seq[i].none());
302
 
303
  pos[i] = p;
304
 
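The new rm(), seq_count() and seq_get() helpers above, together with the extra assert in pos_set(), pin down the cell life cycle: a cell is positioned while empty, gains sequences, and is cleared as a whole. A small sketch of how they compose; the resize() and seq_add() calls are assumptions of this sketch (they do not appear in the hunks above), as is having the internal header on the include path.

    // sketch only: exercise the new llama_kv_cells_unified helpers
    #include "llama-kv-cells.h"

    static void cells_demo() {
        llama_kv_cells_unified cells;
        cells.resize(8);                               // assumption: set the cell capacity

        cells.pos_set(0, 42);                          // new assert: the cell must have no seq yet
        cells.seq_add(0, 1);                           // assumption: attach seq_id 1 to cell 0

        if (cells.seq_count(0) == 1) {                 // new: number of sequences in the cell
            const llama_seq_id sid = cells.seq_get(0); // new: valid only for single-seq cells
            (void) sid;
        }

        cells.rm(0);                                   // new: clear the non-empty cell in one call
    }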
examples/talk-llama/llama-memory.h CHANGED
@@ -2,6 +2,11 @@
2
 
3
  #include "llama.h"
4
 
 
 
 
 
 
5
  struct llama_memory_params {
6
  // kv cache
7
  ggml_type type_k;
@@ -30,3 +35,42 @@ public:
30
 
31
  virtual bool get_can_edit() const = 0;
32
  };
 
2
 
3
  #include "llama.h"
4
 
5
+ #include <memory>
6
+ #include <vector>
7
+
8
+ struct llama_ubatch;
9
+
10
  struct llama_memory_params {
11
  // kv cache
12
  ggml_type type_k;
 
35
 
36
  virtual bool get_can_edit() const = 0;
37
  };
38
+
39
+ enum llama_memory_status {
40
+ LLAMA_MEMORY_STATUS_SUCCESS = 0,
41
+ LLAMA_MEMORY_STATUS_FAILED_PREPARE,
42
+ LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
43
+ };
44
+
45
+ // the interface for managing the memory state during batch processing
46
+ // this interface is implemented per memory type. see:
47
+ // - llama_kv_cache_unified_state
48
+ // - llama_kv_cache_unified_iswa_state
49
+ // ...
50
+ //
51
+ // the only method that can mutate the memory and the memory state is llama_memory_i::apply()
52
+ //
53
+ // TODO: rename to llama_memory_context_i ?
54
+ class llama_memory_state_i {
55
+ public:
56
+ virtual ~llama_memory_state_i() = default;
57
+
58
+ // consume the current ubatch from the state and proceed to the next one
59
+ // return false if we are done
60
+ virtual bool next() = 0;
61
+
62
+ // apply the memory state for the current ubatch to the memory object
63
+ // return false on failure
64
+ virtual bool apply() = 0;
65
+
66
+ // TODO: this might get reworked in the future when refactoring llama_batch
67
+ virtual std::vector<int64_t> & out_ids() = 0;
68
+
69
+ // get the current ubatch
70
+ virtual const llama_ubatch & get_ubatch() const = 0;
71
+
72
+ // get the status of the memory state
73
+ virtual llama_memory_status get_status() const = 0;
74
+ };
75
+
76
+ using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
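For reference, the smallest conceivable implementation of the interface above, wrapping a single pre-built ubatch; the real states (llama_kv_cache_unified_state, ...) additionally carry the cache bookkeeping that apply() has to write. The llama-batch.h include is an assumption about where llama_ubatch is defined.

    // sketch only: a trivial llama_memory_state_i holding one ubatch
    #include "llama-memory.h"
    #include "llama-batch.h" // assumption: provides the llama_ubatch definition

    class single_ubatch_state : public llama_memory_state_i {
    public:
        explicit single_ubatch_state(const llama_ubatch & ub) : ubatch(ub) {}

        bool next()  override { return false; } // only one ubatch, so we are done
        bool apply() override { return true;  } // nothing to write in this stub

        std::vector<int64_t> & out_ids() override { return ids; }

        const llama_ubatch & get_ubatch() const override { return ubatch; }

        llama_memory_status get_status() const override { return LLAMA_MEMORY_STATUS_SUCCESS; }

    private:
        llama_ubatch ubatch;
        std::vector<int64_t> ids;
    };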
examples/talk-llama/llama-model.cpp CHANGED
@@ -5,7 +5,10 @@
5
  #include "llama-batch.h"
6
  #include "llama-cparams.h"
7
  #include "llama-model-loader.h"
8
- #include "llama-kv-cache.h"
 
 
 
9
 
10
  #include "ggml-cpp.h"
11
 
@@ -683,6 +686,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
683
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
684
  ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
685
  ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
 
686
 
687
  switch (hparams.n_layer) {
688
  case 3:
@@ -2113,7 +2117,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
2113
  case LLM_ARCH_NOMIC_BERT_MOE:
2114
  {
2115
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
2116
- type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0);
2117
 
2118
  if (arch == LLM_ARCH_BERT) {
2119
  pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0);
@@ -2121,8 +2125,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
2121
  cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
2122
  cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED);
2123
 
2124
- cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED);
2125
- cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {1}, TENSOR_NOT_REQUIRED);
2126
  }
2127
 
2128
  tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
@@ -2131,7 +2135,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
2131
  for (int i = 0; i < n_layer; ++i) {
2132
  auto & layer = layers[i];
2133
 
2134
- if (arch == LLM_ARCH_BERT) {
 
 
 
2135
  layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
2136
  layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
2137
 
@@ -2140,12 +2147,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
2140
 
2141
  layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
2142
  layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
2143
- } else {
2144
- layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
2145
- }
2146
-
2147
- if (arch == LLM_ARCH_NOMIC_BERT_MOE) {
2148
- layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
2149
  }
2150
 
2151
  layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
@@ -5887,8 +5888,10 @@ struct llm_build_bert : public llm_graph_context {
5887
  inpL = build_inp_embd(model.tok_embd);
5888
 
5889
  // token types are hardcoded to zero ("Sentence A")
5890
- ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
5891
- inpL = ggml_add(ctx0, inpL, type_row0);
 
 
5892
  if (model.arch == LLM_ARCH_BERT) {
5893
  inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
5894
  }
@@ -5909,36 +5912,11 @@ struct llm_build_bert : public llm_graph_context {
5909
  ggml_tensor * Vcur;
5910
 
5911
  // self-attention
5912
- if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
5913
- Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
5914
-
5915
- if (model.layers[il].attn_q_norm) {
5916
- Qcur = build_norm(Qcur,
5917
- model.layers[il].attn_q_norm,
5918
- model.layers[il].attn_q_norm_b,
5919
- LLM_NORM, il);
5920
- }
5921
-
5922
- Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
5923
-
5924
- if (model.layers[il].attn_k_norm) {
5925
- Kcur = build_norm(Kcur,
5926
- model.layers[il].attn_k_norm,
5927
- model.layers[il].attn_k_norm_b,
5928
- LLM_NORM, il);
5929
- }
5930
-
5931
- Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
5932
-
5933
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
5934
- Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
5935
- Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
5936
- } else {
5937
- // compute Q and K and RoPE them
5938
  cur = build_lora_mm(model.layers[il].wqkv, cur);
5939
  cb(cur, "wqkv", il);
5940
 
5941
- if (model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
5942
  cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
5943
  cb(cur, "bqkv", il);
5944
  }
@@ -5946,11 +5924,32 @@ struct llm_build_bert : public llm_graph_context {
5946
  Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
5947
  Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
5948
  Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
 
 
 
 
 
5949
 
5950
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
5951
- Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
5952
- Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
 
 
5953
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5954
  Qcur = ggml_rope_ext(
5955
  ctx0, Qcur, inp_pos, nullptr,
5956
  n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -8896,9 +8895,9 @@ struct llm_build_mamba : public llm_graph_context {
8896
  ggml_tensor * state_mask,
8897
  const llama_ubatch & ubatch,
8898
  int il) const {
8899
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
8900
 
8901
- const auto kv_head = kv_self->head;
8902
 
8903
  const int64_t d_conv = hparams.ssm_d_conv;
8904
  const int64_t d_inner = hparams.ssm_d_inner;
@@ -8916,8 +8915,8 @@ struct llm_build_mamba : public llm_graph_context {
8916
  GGML_ASSERT(ubatch.equal_seqs);
8917
  GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
8918
 
8919
- ggml_tensor * conv_states_all = kv_self->k_l[il];
8920
- ggml_tensor * ssm_states_all = kv_self->v_l[il];
8921
 
8922
  // (ab)using the KV cache to store the states
8923
  ggml_tensor * conv = build_copy_mask_state(
@@ -11644,7 +11643,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11644
  ggml_tensor * state_mask,
11645
  const llama_ubatch & ubatch,
11646
  int il) const {
11647
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
11648
 
11649
  const auto n_tokens = ubatch.n_tokens;
11650
  const auto n_seqs = ubatch.n_seqs;
@@ -11654,7 +11653,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11654
  const auto n_head = n_embd / head_size;
11655
  const auto n_head_kv = hparams.n_head_kv(il);
11656
 
11657
- const auto kv_head = kv_self->head;
11658
 
11659
  const auto & layer = model.layers[il];
11660
 
@@ -11766,7 +11765,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11766
  }
11767
 
11768
  ggml_tensor * wkv_state = build_copy_mask_state(
11769
- gf, kv_self->v_l[il], state_copy, state_mask,
11770
  hparams.n_embd_v_s(), n_seqs);
11771
 
11772
  ggml_tensor * wkv_output;
@@ -11785,9 +11784,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11785
  wkv_state,
11786
  ggml_view_1d(
11787
  ctx0,
11788
- kv_self->v_l[il],
11789
  hparams.n_embd_v_s() * n_seqs,
11790
- hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
11791
  )
11792
  )
11793
  );
@@ -12040,7 +12039,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12040
  ggml_tensor *& first_layer_value,
12041
  const llama_ubatch & ubatch,
12042
  int il) const {
12043
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
12044
 
12045
  const auto n_tokens = ubatch.n_tokens;
12046
  const auto n_seqs = ubatch.n_seqs;
@@ -12049,7 +12048,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12049
  const auto head_count = n_embd / head_size;
12050
  const auto n_seq_tokens = ubatch.n_seq_tokens;
12051
 
12052
- const auto kv_head = kv_self->head;
12053
 
12054
  const auto & layer = model.layers[il];
12055
 
@@ -12120,7 +12119,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12120
  a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
12121
 
12122
  ggml_tensor * wkv_state = build_copy_mask_state(
12123
- gf, kv_self->v_l[il], state_copy, state_mask,
12124
  hparams.n_embd_v_s(), n_seqs);
12125
 
12126
  ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@@ -12134,9 +12133,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12134
  wkv_state,
12135
  ggml_view_1d(
12136
  ctx0,
12137
- kv_self->v_l[il],
12138
  hparams.n_embd_v_s() * n_seqs,
12139
- hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
12140
  )
12141
  )
12142
  );
@@ -13234,7 +13233,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
13234
  params.swa_full,
13235
  cparams.n_ctx,
13236
  cparams.n_seq_max,
13237
- cparams.n_batch,
13238
  padding);
13239
  } else {
13240
  GGML_ASSERT(!hparams.is_swa_any());
@@ -13266,7 +13265,6 @@ llm_graph_result_ptr llama_model::build_graph(
13266
 
13267
  switch (arch) {
13268
  case LLM_ARCH_LLAMA:
13269
- case LLM_ARCH_MINICPM:
13270
  {
13271
  llm = std::make_unique<llm_build_llama>(*this, params, gf);
13272
  } break;
@@ -13507,6 +13505,7 @@ llm_graph_result_ptr llama_model::build_graph(
13507
  } break;
13508
  case LLM_ARCH_GRANITE:
13509
  case LLM_ARCH_GRANITE_MOE:
 
13510
  {
13511
  llm = std::make_unique<llm_build_granite>(*this, params, gf);
13512
  } break;
@@ -13597,6 +13596,10 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
13597
  return model->hparams.n_head_kv();
13598
  }
13599
 
 
 
 
 
13600
  // deprecated
13601
  int32_t llama_n_ctx_train(const llama_model * model) {
13602
  return llama_model_n_ctx_train(model);
 
5
  #include "llama-batch.h"
6
  #include "llama-cparams.h"
7
  #include "llama-model-loader.h"
8
+
9
+ #include "llama-kv-cache-unified.h"
10
+ #include "llama-kv-cache-unified-iswa.h"
11
+ #include "llama-kv-cache-recurrent.h"
12
 
13
  #include "ggml-cpp.h"
14
 
 
686
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
687
  ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
688
  ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
689
+ ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false);
690
 
691
  switch (hparams.n_layer) {
692
  case 3:
 
2117
  case LLM_ARCH_NOMIC_BERT_MOE:
2118
  {
2119
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
2120
+ type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED);
2121
 
2122
  if (arch == LLM_ARCH_BERT) {
2123
  pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0);
 
2125
  cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
2126
  cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED);
2127
 
2128
+ cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
2129
+ cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
2130
  }
2131
 
2132
  tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
 
2135
  for (int i = 0; i < n_layer; ++i) {
2136
  auto & layer = layers[i];
2137
 
2138
+ layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
2139
+ layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
2140
+
2141
+ if (!layer.wqkv) {
2142
  layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
2143
  layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
2144
 
 
2147
 
2148
  layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
2149
  layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
 
 
 
 
 
 
2150
  }
2151
 
2152
  layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
5888
  inpL = build_inp_embd(model.tok_embd);
5889
 
5890
  // token types are hardcoded to zero ("Sentence A")
5891
+ if (model.type_embd) {
5892
+ ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
5893
+ inpL = ggml_add(ctx0, inpL, type_row0);
5894
+ }
5895
  if (model.arch == LLM_ARCH_BERT) {
5896
  inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
5897
  }
 
5912
  ggml_tensor * Vcur;
5913
 
5914
  // self-attention
5915
+ if (model.layers[il].wqkv) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5916
  cur = build_lora_mm(model.layers[il].wqkv, cur);
5917
  cb(cur, "wqkv", il);
5918
 
5919
+ if (model.layers[il].bqkv) {
5920
  cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
5921
  cb(cur, "bqkv", il);
5922
  }
 
5924
  Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
5925
  Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
5926
  Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
5927
+ } else {
5928
+ Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
5929
+ Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
5930
+ Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
5931
+ }
5932
 
5933
+ if (model.layers[il].attn_q_norm) {
5934
+ Qcur = build_norm(Qcur,
5935
+ model.layers[il].attn_q_norm,
5936
+ model.layers[il].attn_q_norm_b,
5937
+ LLM_NORM, il);
5938
+ }
5939
 
5940
+ if (model.layers[il].attn_k_norm) {
5941
+ Kcur = build_norm(Kcur,
5942
+ model.layers[il].attn_k_norm,
5943
+ model.layers[il].attn_k_norm_b,
5944
+ LLM_NORM, il);
5945
+ }
5946
+
5947
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
5948
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
5949
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
5950
+
5951
+ // RoPE
5952
+ if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
5953
  Qcur = ggml_rope_ext(
5954
  ctx0, Qcur, inp_pos, nullptr,
5955
  n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
 
8895
  ggml_tensor * state_mask,
8896
  const llama_ubatch & ubatch,
8897
  int il) const {
8898
+ const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
8899
 
8900
+ const auto kv_head = kv_state->get_head();
8901
 
8902
  const int64_t d_conv = hparams.ssm_d_conv;
8903
  const int64_t d_inner = hparams.ssm_d_inner;
 
8915
  GGML_ASSERT(ubatch.equal_seqs);
8916
  GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
8917
 
8918
+ ggml_tensor * conv_states_all = kv_state->get_k_l(il);
8919
+ ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
8920
 
8921
  // (ab)using the KV cache to store the states
8922
  ggml_tensor * conv = build_copy_mask_state(
 
11643
  ggml_tensor * state_mask,
11644
  const llama_ubatch & ubatch,
11645
  int il) const {
11646
+ const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
11647
 
11648
  const auto n_tokens = ubatch.n_tokens;
11649
  const auto n_seqs = ubatch.n_seqs;
 
11653
  const auto n_head = n_embd / head_size;
11654
  const auto n_head_kv = hparams.n_head_kv(il);
11655
 
11656
+ const auto kv_head = kv_state->get_head();
11657
 
11658
  const auto & layer = model.layers[il];
11659
 
 
11765
  }
11766
 
11767
  ggml_tensor * wkv_state = build_copy_mask_state(
11768
+ gf, kv_state->get_v_l(il), state_copy, state_mask,
11769
  hparams.n_embd_v_s(), n_seqs);
11770
 
11771
  ggml_tensor * wkv_output;
 
11784
  wkv_state,
11785
  ggml_view_1d(
11786
  ctx0,
11787
+ kv_state->get_v_l(il),
11788
  hparams.n_embd_v_s() * n_seqs,
11789
+ hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
11790
  )
11791
  )
11792
  );
 
12039
  ggml_tensor *& first_layer_value,
12040
  const llama_ubatch & ubatch,
12041
  int il) const {
12042
+ const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
12043
 
12044
  const auto n_tokens = ubatch.n_tokens;
12045
  const auto n_seqs = ubatch.n_seqs;
 
12048
  const auto head_count = n_embd / head_size;
12049
  const auto n_seq_tokens = ubatch.n_seq_tokens;
12050
 
12051
+ const auto kv_head = kv_state->get_head();
12052
 
12053
  const auto & layer = model.layers[il];
12054
 
 
12119
  a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
12120
 
12121
  ggml_tensor * wkv_state = build_copy_mask_state(
12122
+ gf, kv_state->get_v_l(il), state_copy, state_mask,
12123
  hparams.n_embd_v_s(), n_seqs);
12124
 
12125
  ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
 
12133
  wkv_state,
12134
  ggml_view_1d(
12135
  ctx0,
12136
+ kv_state->get_v_l(il),
12137
  hparams.n_embd_v_s() * n_seqs,
12138
+ hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
12139
  )
12140
  )
12141
  );
 
13233
  params.swa_full,
13234
  cparams.n_ctx,
13235
  cparams.n_seq_max,
13236
+ cparams.n_ubatch,
13237
  padding);
13238
  } else {
13239
  GGML_ASSERT(!hparams.is_swa_any());
 
13265
 
13266
  switch (arch) {
13267
  case LLM_ARCH_LLAMA:
 
13268
  {
13269
  llm = std::make_unique<llm_build_llama>(*this, params, gf);
13270
  } break;
 
13505
  } break;
13506
  case LLM_ARCH_GRANITE:
13507
  case LLM_ARCH_GRANITE_MOE:
13508
+ case LLM_ARCH_MINICPM:
13509
  {
13510
  llm = std::make_unique<llm_build_granite>(*this, params, gf);
13511
  } break;
 
13596
  return model->hparams.n_head_kv();
13597
  }
13598
 
13599
+ int32_t llama_model_n_swa(const llama_model * model) {
13600
+ return model->hparams.n_swa;
13601
+ }
13602
+
13603
  // deprecated
13604
  int32_t llama_n_ctx_train(const llama_model * model) {
13605
  return llama_model_n_ctx_train(model);
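The llama_model_n_swa() getter implemented above is also exported in llama.h below, so applications can check whether a loaded model uses sliding-window attention. Treating 0 as "no SWA layers" is an assumption based on the hparams default.

    // sketch only: query the new getter from the public API
    #include "llama.h"
    #include <cstdio>

    static void report_swa(const struct llama_model * model) {
        const int32_t n_swa = llama_model_n_swa(model);
        if (n_swa > 0) {
            printf("sliding attention window: %d tokens\n", n_swa);
        } else {
            printf("no SWA layers\n");
        }
    }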
examples/talk-llama/llama.h CHANGED
@@ -259,9 +259,9 @@ extern "C" {
259
  llama_token * token;
260
  float * embd;
261
  llama_pos * pos;
262
- int32_t * n_seq_id;
263
- llama_seq_id ** seq_id;
264
- int8_t * logits; // TODO: rename this to "output"
265
  } llama_batch;
266
 
267
  enum llama_model_kv_override_type {
@@ -366,6 +366,8 @@ extern "C" {
366
  bool no_perf; // measure performance timings
367
  bool op_offload; // offload host tensor operations to device
368
  bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
 
 
369
  };
370
 
371
  // model quantization parameters
@@ -502,6 +504,7 @@ extern "C" {
502
  LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
503
  LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
504
  LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
 
505
 
506
  // Get the model's RoPE frequency scaling factor
507
  LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
@@ -652,7 +655,6 @@ extern "C" {
652
  // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
653
  // If the KV cache is RoPEd, the KV data is updated accordingly:
654
  // - lazily on next llama_decode()
655
- // - explicitly with llama_kv_self_update()
656
  // p0 < 0 : [0, p1]
657
  // p1 < 0 : [p0, inf)
658
  LLAMA_API void llama_kv_self_seq_add(
@@ -665,7 +667,6 @@ extern "C" {
665
  // Integer division of the positions by factor of `d > 1`
666
  // If the KV cache is RoPEd, the KV data is updated accordingly:
667
  // - lazily on next llama_decode()
668
- // - explicitly with llama_kv_self_update()
669
  // p0 < 0 : [0, p1]
670
  // p1 < 0 : [p0, inf)
671
  LLAMA_API void llama_kv_self_seq_div(
@@ -677,12 +678,14 @@ extern "C" {
677
 
678
  // Returns the smallest position present in the KV cache for the specified sequence
679
  // This is typically non-zero only for SWA caches
 
680
  // Return -1 if the sequence is empty
681
  LLAMA_API llama_pos llama_kv_self_seq_pos_min(
682
  struct llama_context * ctx,
683
  llama_seq_id seq_id);
684
 
685
  // Returns the largest position present in the KV cache for the specified sequence
 
686
  // Return -1 if the sequence is empty
687
  LLAMA_API llama_pos llama_kv_self_seq_pos_max(
688
  struct llama_context * ctx,
@@ -691,14 +694,15 @@ extern "C" {
691
  // Defragment the KV cache
692
  // This will be applied:
693
  // - lazily on next llama_decode()
694
- // - explicitly with llama_kv_self_update()
695
- LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
696
 
697
  // Check if the context supports KV cache shifting
698
  LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
699
 
700
  // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
701
- LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
 
702
 
703
  //
704
  // State / sessions
 
259
  llama_token * token;
260
  float * embd;
261
  llama_pos * pos;
262
+ int32_t * n_seq_id; // TODO: remove, should belong to only 1 sequence
263
+ llama_seq_id ** seq_id; // TODO: become llama_seq_id * seq_id;
264
+ int8_t * logits; // TODO: rename this to "output"
265
  } llama_batch;
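The TODO comments added above point towards exactly one sequence id per token. A sketch of filling a batch that way with the existing llama_batch_init()/llama_batch_free() helpers; requesting output only for the last token is just an example policy, not something the header mandates.

    // sketch only: one sequence id per token, as the TODOs suggest
    #include "llama.h"

    static struct llama_batch make_single_seq_batch(const llama_token * tokens, int32_t n_tokens,
                                                    llama_seq_id seq_id, llama_pos p0) {
        struct llama_batch batch = llama_batch_init(n_tokens, /*embd =*/ 0, /*n_seq_max =*/ 1);

        for (int32_t i = 0; i < n_tokens; ++i) {
            batch.token   [i]    = tokens[i];
            batch.pos     [i]    = p0 + i;
            batch.n_seq_id[i]    = 1;                         // exactly one sequence per token
            batch.seq_id  [i][0] = seq_id;
            batch.logits  [i]    = i == n_tokens - 1 ? 1 : 0; // output only for the last token
        }
        batch.n_tokens = n_tokens;

        return batch; // release with llama_batch_free() when done
    }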
266
 
267
  enum llama_model_kv_override_type {
 
366
  bool no_perf; // measure performance timings
367
  bool op_offload; // offload host tensor operations to device
368
  bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
369
+ // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
370
+ // ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
371
  };
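Following the NOTE added above, a caller that decodes several sequences in parallel would keep the full-size SWA cache; n_seq_max as the companion context parameter is an assumption of this sketch.

    // sketch only: context parameters for parallel sequences, per the note above
    #include "llama.h"

    static struct llama_context_params multi_seq_params(void) {
        struct llama_context_params cparams = llama_context_default_params();
        cparams.n_seq_max = 4;    // several parallel sequences (assumed field name)
        cparams.swa_full  = true; // avoid swa_full == false when n_seq_max > 1
        return cparams;
    }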
372
 
373
  // model quantization parameters
 
504
  LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
505
  LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
506
  LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
507
+ LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
508
 
509
  // Get the model's RoPE frequency scaling factor
510
  LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
 
655
  // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
656
  // If the KV cache is RoPEd, the KV data is updated accordingly:
657
  // - lazily on next llama_decode()
 
658
  // p0 < 0 : [0, p1]
659
  // p1 < 0 : [p0, inf)
660
  LLAMA_API void llama_kv_self_seq_add(
 
667
  // Integer division of the positions by factor of `d > 1`
668
  // If the KV cache is RoPEd, the KV data is updated accordingly:
669
  // - lazily on next llama_decode()
 
670
  // p0 < 0 : [0, p1]
671
  // p1 < 0 : [p0, inf)
672
  LLAMA_API void llama_kv_self_seq_div(
 
678
 
679
  // Returns the smallest position present in the KV cache for the specified sequence
680
  // This is typically non-zero only for SWA caches
681
+ // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
682
  // Return -1 if the sequence is empty
683
  LLAMA_API llama_pos llama_kv_self_seq_pos_min(
684
  struct llama_context * ctx,
685
  llama_seq_id seq_id);
686
 
687
  // Returns the largest position present in the KV cache for the specified sequence
688
+ // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
689
  // Return -1 if the sequence is empty
690
  LLAMA_API llama_pos llama_kv_self_seq_pos_max(
691
  struct llama_context * ctx,
 
694
  // Defragment the KV cache
695
  // This will be applied:
696
  // - lazily on next llama_decode()
697
+ LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx),
698
+ "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
699
 
700
  // Check if the context supports KV cache shifting
701
  LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
702
 
703
  // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
704
+ LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx),
705
+ "simply remove this call, updates are applied lazily on the next llama_decode()");
706
 
707
  //
708
  // State / sessions
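In practical terms, the two deprecations above mean that callers drop the explicit maintenance calls and let llama_decode() apply K-shifts and defragmentation lazily, tuned by defrag_thold in llama_context_params (the parameter the deprecation message refers to). A hedged before/after sketch:

    #include "llama.h"

    static void decode_step(struct llama_context * ctx, struct llama_batch batch) {
        // before:
        //   llama_kv_self_defrag(ctx); // now deprecated
        //   llama_kv_self_update(ctx); // now deprecated

        // after: just decode; pending shifts/defrag are applied as needed
        const int32_t ret = llama_decode(ctx, batch);
        (void) ret; // 0 on success
    }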